diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 2ed158f89..3a9ddc36f 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -143,9 +143,7 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -190,9 +188,7 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index 4aab3b676..dcd76d47b 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { text->set_left(label_left_px_); text->set_baseline(label_baseline_px + i * label_height_px_); text->set_font_face(options_.font_face()); + if (options_.outline_thickness() > 0) { + text->set_outline_thickness(options_.outline_thickness()); + if (options_.outline_color_size() > 0) { + *(text->mutable_outline_color()) = + options_.outline_color(i % options_.outline_color_size()); + } else { + text->mutable_outline_color()->set_r(0); + text->mutable_outline_color()->set_g(0); + text->mutable_outline_color()->set_b(0); + } + } } cc->Outputs() .Tag(kRenderDataTag) diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.proto b/mediapipe/calculators/util/labels_to_render_data_calculator.proto index cf0ada9c2..7946ff683 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.proto @@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions { // Thickness for drawing the label(s). optional double thickness = 2 [default = 2]; + // Color of outline around each character, if any. One per label, as with + // color attribute. + repeated Color outline_color = 12; + + // Thickness of outline around each character. + optional double outline_thickness = 11; + // The font height in absolute pixels. 
optional int32 font_height_px = 3 [default = 50]; diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 69d2fab7a..7d095a5d4 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -185,7 +185,10 @@ void GlTextureBuffer::Updated(std::shared_ptr prod_token) { << "Updated existing texture which had not been marked for reuse!"; CHECK(prod_token); producer_sync_ = std::move(prod_token); - producer_context_ = producer_sync_->GetContext(); + const auto& synced_context = producer_sync_->GetContext(); + if (synced_context) { + producer_context_ = synced_context; + } } void GlTextureBuffer::DidRead(std::shared_ptr cons_token) const { diff --git a/mediapipe/java/com/google/mediapipe/framework/BUILD b/mediapipe/java/com/google/mediapipe/framework/BUILD index 6b7fb1271..dd5f8f1da 100644 --- a/mediapipe/java/com/google/mediapipe/framework/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/BUILD @@ -34,6 +34,7 @@ android_library( android_library( name = "android_framework_no_mff", proguard_specs = [":proguard.pgcfg"], + visibility = ["//visibility:public"], exports = [ ":android_framework_no_proguard", ], diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 33667d18e..33d96e9f2 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -48,6 +48,8 @@ pybind_extension( "//mediapipe/python/pybind:timestamp", "//mediapipe/python/pybind:validated_graph_config", "//mediapipe/tasks/python/core/pybind:task_runner", + "@com_google_absl//absl/strings:str_format", + "@stblib//:stb_image", # Type registration. "//mediapipe/framework:basic_types_registration", "//mediapipe/framework/formats:classification_registration", diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index c8d929e72..117d20974 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -15,6 +15,7 @@ """Tests for mediapipe.python._framework_bindings.image.""" import gc +import os import random import sys @@ -23,6 +24,7 @@ import cv2 import numpy as np import PIL.Image +# resources dependency from mediapipe.python._framework_bindings import image from mediapipe.python._framework_bindings import image_frame @@ -185,6 +187,5 @@ class ImageTest(absltest.TestCase): gc.collect() self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) - if __name__ == '__main__': absltest.main() diff --git a/mediapipe/python/pybind/BUILD b/mediapipe/python/pybind/BUILD index 45cc83b38..b26d6bb6a 100644 --- a/mediapipe/python/pybind/BUILD +++ b/mediapipe/python/pybind/BUILD @@ -45,6 +45,8 @@ pybind_library( ":util", "//mediapipe/framework:type_map", "//mediapipe/framework/formats:image", + "@com_google_absl//absl/strings:str_format", + "@stblib//:stb_image", ], ) diff --git a/mediapipe/python/pybind/image.cc b/mediapipe/python/pybind/image.cc index 651eb2ca6..5d8663143 100644 --- a/mediapipe/python/pybind/image.cc +++ b/mediapipe/python/pybind/image.cc @@ -16,9 +16,11 @@ #include +#include "absl/strings/str_format.h" #include "mediapipe/python/pybind/image_frame_util.h" #include "mediapipe/python/pybind/util.h" #include "pybind11/stl.h" +#include "stb_image.h" namespace mediapipe { namespace python { @@ -225,6 +227,62 @@ void ImageSubmodule(pybind11::module* module) { image.is_aligned(16) )doc"); + image.def_static( + "create_from_file", + [](const std::string& file_name) { + int width; + int height; + int channels; + auto* image_data = + stbi_load(file_name.c_str(), &width, &height, &channels, + 
/*desired_channels=*/0); + if (image_data == nullptr) { + throw RaisePyError(PyExc_RuntimeError, + absl::StrFormat("Image decoding failed (%s): %s", + stbi_failure_reason(), file_name) + .c_str()); + } + ImageFrameSharedPtr image_frame; + switch (channels) { + case 1: + image_frame = std::make_shared( + ImageFormat::GRAY8, width, height, width, image_data, + stbi_image_free); + break; + case 3: + image_frame = std::make_shared( + ImageFormat::SRGB, width, height, 3 * width, image_data, + stbi_image_free); + break; + case 4: + image_frame = std::make_shared( + ImageFormat::SRGBA, width, height, 4 * width, image_data, + stbi_image_free); + break; + default: + throw RaisePyError( + PyExc_RuntimeError, + absl::StrFormat( + "Expected image with 1 (grayscale), 3 (RGB) or 4 " + "(RGBA) channels, found %d channels.", + channels) + .c_str()); + } + return Image(std::move(image_frame)); + }, + R"doc(Creates `Image` object from the image file. + +Args: + file_name: Image file name. + +Returns: + `Image` object. + +Raises: + RuntimeError if the image file can't be decoded. + )doc", + py::arg("file_name")); + image.def_property_readonly("width", &Image::width) .def_property_readonly("height", &Image::height) .def_property_readonly("channels", &Image::channels) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 363dc89a9..ac238bfda 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -33,11 +33,12 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:matrix", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", @@ -60,12 +61,13 @@ cc_library( ":audio_classifier_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components:classifier_options", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", 
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index db939f341..702d802c5 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -22,10 +22,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h" +#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -33,8 +34,12 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { + namespace { +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; + constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioTag[] = "AUDIO"; constexpr char kClassificationResultStreamName[] = "classification_result_out"; @@ -42,16 +47,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kSampleRateName[] = "sample_rate_in"; constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.audio.AudioClassifierGraph"; + "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; -using AudioClassifierOptionsProto = - audio_classifier::proto::AudioClassifierOptions; - // Creates a MediaPipe graph config that only contains a single subgraph node of -// "mediapipe.tasks.audio.AudioClassifierGraph". +// type "AudioClassifierGraph". CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options_proto) { + std::unique_ptr options_proto) { api2::builder::Graph graph; auto& subgraph = graph.AddNode(kSubgraphTypeName); graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag); @@ -59,7 +61,8 @@ CalculatorGraphConfig CreateGraphConfig( graph.In(kSampleRateTag).SetName(kSampleRateName) >> subgraph.In(kSampleRateTag); } - subgraph.GetOptions().Swap(options_proto.get()); + subgraph.GetOptions().Swap( + options_proto.get()); subgraph.Out(kClassificationResultTag) .SetName(kClassificationResultStreamName) >> graph.Out(kClassificationResultTag); @@ -67,18 +70,18 @@ CalculatorGraphConfig CreateGraphConfig( } // Converts the user-facing AudioClassifierOptions struct to the internal -// AudioClassifierOptions proto. -std::unique_ptr +// AudioClassifierGraphOptions proto. 
+std::unique_ptr ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { - auto options_proto = std::make_unique(); + auto options_proto = std::make_unique(); auto base_options_proto = std::make_unique( tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode == core::RunningMode::AUDIO_STREAM); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); @@ -119,7 +122,7 @@ absl::StatusOr> AudioClassifier::Create( }; } return core::AudioTaskApiFactory::Create( + proto::AudioClassifierGraphOptions>( CreateGraphConfig(std::move(options_proto)), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)); @@ -140,6 +143,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h index 688bb60e3..200cffb8c 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h @@ -23,13 +23,14 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { // The options for configuring a mediapipe audio classifier task. struct AudioClassifierOptions { @@ -39,7 +40,7 @@ struct AudioClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The running mode of the audio classifier. Default to the audio clips mode. // Audio classifier has two running modes: @@ -58,8 +59,9 @@ struct AudioClassifierOptions { // The user-defined result callback for processing audio stream data. // The result callback should only be specified when the running mode is set // to RunningMode::AUDIO_STREAM. - std::function)> result_callback = - nullptr; + std::function)> + result_callback = nullptr; }; // Performs audio classification on audio clips or audio stream. @@ -131,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { // framed audio clip. // TODO: Use `sample_rate` in AudioClassifierOptions by default // and makes `audio_sample_rate` optional. 
- absl::StatusOr Classify(mediapipe::Matrix audio_clip, - double audio_sample_rate); + absl::StatusOr Classify( + mediapipe::Matrix audio_clip, double audio_sample_rate); // Sends audio data (a block in a continuous audio stream) to perform audio // classification. Only use this method when the AudioClassifier is created @@ -162,6 +164,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { absl::Status Close() { return runner_->Close(); } }; +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 0f40b59a4..12f8ce31a 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -28,12 +28,12 @@ limitations under the License. #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h" +#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -44,6 +44,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { namespace { @@ -52,6 +53,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAudioTag[] = "AUDIO"; @@ -60,10 +62,9 @@ constexpr char kPacketTag[] = "PACKET"; constexpr char kSampleRateTag[] = "SAMPLE_RATE"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; -using AudioClassifierOptionsProto = - audio_classifier::proto::AudioClassifierOptions; -absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) { +absl::Status SanityCheckOptions( + const proto::AudioClassifierGraphOptions& options) { if (options.base_options().use_stream_mode() && !options.has_default_input_audio_sample_rate()) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, @@ -111,7 +112,7 @@ void ConfigureAudioToTensorCalculator( } // namespace -// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification. +// An "AudioClassifierGraph" performs audio classification. // - Accepts CPU audio buffer and outputs classification results on CPU. 
// // Inputs: @@ -129,12 +130,12 @@ void ConfigureAudioToTensorCalculator( // // Example: // node { -// calculator: "mediapipe.tasks.audio.AudioClassifierGraph" +// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph" // input_stream: "AUDIO:audio_in" // input_stream: "SAMPLE_RATE:sample_rate_in" // output_stream: "CLASSIFICATION_RESULT:classification_result_out" // options { -// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext] +// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext] // { // base_options { // model_asset { @@ -152,16 +153,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); Graph graph; - const bool use_stream_mode = sc->Options() - .base_options() - .use_stream_mode(); + const bool use_stream_mode = + sc->Options() + .base_options() + .use_stream_mode(); ASSIGN_OR_RETURN( auto classification_result_out, BuildAudioClassificationTask( - sc->Options(), *model_resources, + sc->Options(), *model_resources, graph[Input(kAudioTag)], use_stream_mode ? absl::nullopt @@ -178,14 +181,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // buffer (mediapipe::Matrix) and the corresponding sample rate (double) as // the inputs and returns one classification result per input audio buffer. // - // task_options: the mediapipe tasks AudioClassifierOptions proto. + // task_options: the mediapipe tasks AudioClassifierGraphOptions proto. // model_resources: the ModelSources object initialized from an audio // classifier model file with model metadata. // audio_in: (mediapipe::Matrix) stream to run audio classification on. // sample_rate_in: (double) optional stream of the input audio sample rate. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr> BuildAudioClassificationTask( - const AudioClassifierOptionsProto& task_options, + const proto::AudioClassifierGraphOptions& task_options, const core::ModelResources& model_resources, Source audio_in, absl::optional> sample_rate_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); @@ -236,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." 
+ "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio classification on @@ -257,8 +263,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph { } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph); +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index dd56c4ff1..4b64d2231 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -37,17 +37,19 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" namespace mediapipe { namespace tasks { namespace audio { +namespace audio_classifier { namespace { using ::absl::StatusOr; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; @@ -557,6 +559,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { } } // namespace +} // namespace audio_classifier } // namespace audio } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD index 7b1952e06..bfe37ec01 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD @@ -19,12 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "audio_classifier_options_proto", - srcs = ["audio_classifier_options.proto"], + name = "audio_classifier_graph_options_proto", + srcs = ["audio_classifier_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto similarity index 84% rename from mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto rename to mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index a76ccdcab..16aa86aeb 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto +++ 
b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,12 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message AudioClassifierOptions { +message AudioClassifierGraphOptions { extend mediapipe.CalculatorOptions { - optional AudioClassifierOptions ext = 451755788; + optional AudioClassifierGraphOptions ext = 451755788; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. @@ -31,7 +31,7 @@ message AudioClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; // The default sample rate of the input audio. Must be set when the // AudioClassifier is configured to process audio stream data. diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 4de32ce9b..7939e4e39 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -58,65 +58,6 @@ cc_library( # TODO: Enable this test -cc_library( - name = "classifier_options", - srcs = ["classifier_options.cc"], - hdrs = ["classifier_options.h"], - deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"], -) - -mediapipe_proto_library( - name = "classification_postprocessing_options_proto", - srcs = ["classification_postprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", - ], -) - -cc_library( - name = "classification_postprocessing", - srcs = ["classification_postprocessing.cc"], - hdrs = ["classification_postprocessing.h"], - deps = [ - ":classification_postprocessing_options_cc_proto", - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/tensor:tensors_dequantization_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:packet", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - 
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 13ca6b496..7d01e4dfe 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,8 +37,8 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers:category_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], alwayslink = 1, @@ -128,7 +128,7 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index b2848bc3f..e1f69e607 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,15 +25,15 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions; -using ::mediapipe::tasks::ClassificationResult; -using ::mediapipe::tasks::Classifications; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into a single ClassificationResult that has // 3 dimensions: (classification head, classification timestamp, classification diff --git a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc index b688cda91..10eb962dd 100644 --- a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc +++ 
b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc @@ -17,12 +17,13 @@ limitations under the License. #include -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" // Specialized EndLoopCalculator for Tasks specific types. namespace mediapipe::tasks { -typedef EndLoopCalculator> +typedef EndLoopCalculator< + std::vector> EndLoopClassificationResultCalculator; REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator); diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 9c6402e64..633b5b369 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "category_proto", + srcs = ["category.proto"], +) + +mediapipe_proto_library( + name = "classifications_proto", + srcs = ["classifications.proto"], + deps = [ + ":category_proto", + ], +) + +mediapipe_proto_library( + name = "embeddings_proto", + srcs = ["embeddings.proto"], +) + mediapipe_proto_library( name = "landmarks_detection_result_proto", srcs = [ @@ -29,8 +47,3 @@ mediapipe_proto_library( "//mediapipe/framework/formats:rect_proto", ], ) - -mediapipe_proto_library( - name = "embeddings_proto", - srcs = ["embeddings.proto"], -) diff --git a/mediapipe/tasks/cc/components/containers/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto similarity index 96% rename from mediapipe/tasks/cc/components/containers/category.proto rename to mediapipe/tasks/cc/components/containers/proto/category.proto index 47f38b75a..a44fb5b15 100644 --- a/mediapipe/tasks/cc/components/containers/category.proto +++ b/mediapipe/tasks/cc/components/containers/proto/category.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; // A single classification result. message Category { diff --git a/mediapipe/tasks/cc/components/containers/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto similarity index 93% rename from mediapipe/tasks/cc/components/containers/classifications.proto rename to mediapipe/tasks/cc/components/containers/proto/classifications.proto index 469c67fc9..e0ccad7a1 100644 --- a/mediapipe/tasks/cc/components/containers/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -15,9 +15,9 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; -import "mediapipe/tasks/cc/components/containers/category.proto"; +import "mediapipe/tasks/cc/components/containers/proto/category.proto"; // List of predicted categories with an optional timestamp. message ClassificationEntry { diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD new file mode 100644 index 000000000..62f04dcb7 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -0,0 +1,64 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "classifier_options", + srcs = ["classifier_options.cc"], + hdrs = ["classifier_options.h"], + deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], +) + +cc_library( + name = "classification_postprocessing_graph", + srcs = ["classification_postprocessing_graph.cc"], + hdrs = ["classification_postprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", + "//mediapipe/calculators/tensor:tensors_dequantization_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:label_map_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc similarity index 92% rename from mediapipe/tasks/cc/components/classification_postprocessing.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 871476e8f..35adab687 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include @@ -37,9 +37,9 @@ limitations under the License. #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" @@ -51,6 +51,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; @@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; // Performs sanity checks on provided ClassifierOptions. -absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) { +absl::Status SanityCheckClassifierOptions( + const proto::ClassifierOptions& options) { if (options.max_results() == 0) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -203,7 +205,7 @@ absl::StatusOr GetScoreThreshold( // Gets the category allowlist or denylist (if any) as a set of indices. absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( - const ClassifierOptions& options, const LabelItems& label_items) { + const proto::ClassifierOptions& options, const LabelItems& label_items) { absl::flat_hash_set category_indices; // Exit early if no denylist/allowlist. if (options.category_denylist_size() == 0 && @@ -239,7 +241,7 @@ absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( absl::Status ConfigureScoreCalibrationIfAny( const ModelMetadataExtractor& metadata_extractor, int tensor_index, - ClassificationPostprocessingOptions* options) { + proto::ClassificationPostprocessingGraphOptions* options) { const auto* tensor_metadata = metadata_extractor.GetOutputTensorMetadata(tensor_index); if (tensor_metadata == nullptr) { @@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny( // Fills in the TensorsToClassificationCalculatorOptions based on the // classifier options and the (optional) output tensor metadata. 
absl::Status ConfigureTensorsToClassificationCalculator( - const ClassifierOptions& options, + const proto::ClassifierOptions& options, const ModelMetadataExtractor& metadata_extractor, int tensor_index, TensorsToClassificationCalculatorOptions* calculator_options) { const auto* tensor_metadata = @@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator( } // namespace -absl::Status ConfigureClassificationPostprocessing( +absl::Status ConfigureClassificationPostprocessingGraph( const ModelResources& model_resources, - const ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options) { + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options) { MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options)); ASSIGN_OR_RETURN(const auto heads_properties, GetClassificationHeadsProperties(model_resources)); @@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing( return absl::OkStatus(); } -// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts -// raw tensors into ClassificationResult objects. +// A "ClassificationPostprocessingGraph" converts raw tensors into +// ClassificationResult objects. // - Accepts CPU input tensors. // // Inputs: @@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing( // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. // -// The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureClassificationPostprocessing()' function. See header file -// for more details. -class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { +// The recommended way of using this graph is through the GraphBuilder API +// using the 'ConfigureClassificationPostprocessingGraph()' function. See header +// file for more details. +class ClassificationPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { ASSIGN_OR_RETURN( auto classification_result_out, BuildClassificationPostprocessing( - sc->Options(), + sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); classification_result_out >> @@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { } private: - // Adds an on-device classification postprocessing subgraph into the provided - // builder::Graph instance. The classification postprocessing subgraph takes + // Adds an on-device classification postprocessing graph into the provided + // builder::Graph instance. The classification postprocessing graph takes // tensors (std::vector) as input and returns one output // stream containing the output classification results (ClassificationResult). // - // options: the on-device ClassificationPostprocessingOptions. + // options: the on-device ClassificationPostprocessingGraphOptions. // tensors_in: (std::vector>) tensors to postprocess. // timestamps_in: (std::vector) optional collection of // timestamps that a single ClassificationResult should aggregate. // graph: the mediapipe builder::Graph instance to be updated. 
absl::StatusOr> BuildClassificationPostprocessing( - const ClassificationPostprocessingOptions& options, + const proto::ClassificationPostprocessingGraphOptions& options, Source> tensors_in, Source> timestamps_in, Graph& graph) { const int num_heads = options.tensors_to_classifications_options_size(); @@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { kClassificationResultTag)]; } }; -REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph); +REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors:: + ClassificationPostprocessingGraph); // NOLINT + +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h similarity index 59% rename from mediapipe/tasks/cc/components/classification_postprocessing.h rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index eb638bd60..8aedad46d 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -13,32 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -// Configures a ClassificationPostprocessing subgraph using the provided model +// Configures a ClassificationPostprocessingGraph using the provided model // resources and ClassifierOptions. // - Accepts CPU input tensors. // // Example usage: // // auto& postprocessing = -// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); -// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( +// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph"); +// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( // model_resources, // classifier_options, -// &preprocessing.GetOptions())); +// &preprocessing.GetOptions())); // -// The resulting ClassificationPostprocessing subgraph has the following I/O: +// The resulting ClassificationPostprocessingGraph has the following I/O: // Inputs: // TENSORS - std::vector // The output tensors of an InferenceCalculator. @@ -49,13 +50,14 @@ namespace components { // Outputs: // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. 
-absl::Status ConfigureClassificationPostprocessing( +absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, - const tasks::components::proto::ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options); + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc similarity index 88% rename from mediapipe/tasks/cc/components/classification_postprocessing_test.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index 67223050f..bb03e2530 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include #include @@ -42,9 +42,9 @@ limitations under the License. #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/util/label_map.pb.h" @@ -53,6 +53,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::api2::Input; @@ -60,7 +61,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::testing::HasSubstr; using ::testing::proto::Approximately; @@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(0); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); @@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); options_in.add_category_denylist("bar"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); @@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( @@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -171,12 +172,12 @@ TEST_F(ConfigureTest, 
SucceedsWithMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(3); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_score_threshold(0.5); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. EXPECT_EQ( @@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("tench"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) @@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_denylist("background"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. 
options_out.mutable_tensors_to_classifications_options(0) @@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) { auto model_resources, CreateModelResourcesForModel( kQuantizedImageClassifierWithDummyScoreCalibration)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. EXPECT_EQ( @@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label maps sizes and first two elements. EXPECT_EQ( options_out.tensors_to_classifications_options(0).label_items_size(), @@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { class PostprocessingTest : public tflite_shims::testing::Test { protected: absl::StatusOr BuildGraph( - absl::string_view model_name, const ClassifierOptions& options, + absl::string_view model_name, const proto::ClassifierOptions& options, bool connect_timestamps = false) { ASSIGN_OR_RETURN(auto model_resources, CreateModelResourcesForModel(model_name)); Graph graph; auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( *model_resources, options, - &postprocessing.GetOptions())); + &postprocessing + .GetOptions())); graph[Input>(kTensorsTag)].SetName(kTensorsName) >> postprocessing.In(kTensorsTag); if (connect_timestamps) { @@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); options.set_score_threshold(0.5); MP_ASSERT_OK_AND_ASSIGN( @@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { TEST_F(PostprocessingTest, SucceedsWithMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); @@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Build graph. 
- ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { TEST_F(PostprocessingTest, SucceedsWithTimestamps) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, @@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { } } // namespace +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.cc b/mediapipe/tasks/cc/components/processors/classifier_options.cc similarity index 81% rename from mediapipe/tasks/cc/components/classifier_options.cc rename to mediapipe/tasks/cc/components/processors/classifier_options.cc index c54db5f88..349bb569d 100644 --- a/mediapipe/tasks/cc/components/classifier_options.cc +++ b/mediapipe/tasks/cc/components/processors/classifier_options.cc @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* options) { - tasks::components::proto::ClassifierOptions options_proto; + proto::ClassifierOptions options_proto; options_proto.set_display_names_locale(options->display_names_locale); options_proto.set_max_results(options->max_results); options_proto.set_score_threshold(options->score_threshold); @@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( return options_proto; } +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.h b/mediapipe/tasks/cc/components/processors/classifier_options.h similarity index 83% rename from mediapipe/tasks/cc/components/classifier_options.h rename to mediapipe/tasks/cc/components/processors/classifier_options.h index e15bf5e69..189b42e60 100644 --- a/mediapipe/tasks/cc/components/classifier_options.h +++ b/mediapipe/tasks/cc/components/processors/classifier_options.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { // Classifier options for MediaPipe C++ classification Tasks. struct ClassifierOptions { @@ -49,11 +50,12 @@ struct ClassifierOptions { }; // Converts a ClassifierOptions to a ClassifierOptionsProto. -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* classifier_options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD similarity index 58% rename from mediapipe/tasks/cc/components/containers/BUILD rename to mediapipe/tasks/cc/components/processors/proto/BUILD index 701f84824..d7cbe47ff 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], + name = "classifier_options_proto", + srcs = ["classifier_options.proto"], ) mediapipe_proto_library( - name = "classifications_proto", - srcs = ["classifications.proto"], + name = "classification_postprocessing_graph_options_proto", + srcs = ["classification_postprocessing_graph_options.proto"], deps = [ - ":category_proto", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", ], ) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto similarity index 91% rename from mediapipe/tasks/cc/components/classification_postprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto index 9b67e2f75..1de788eab 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto @@ -15,16 +15,16 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; -message ClassificationPostprocessingOptions { +message ClassificationPostprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ClassificationPostprocessingOptions ext = 460416950; + optional ClassificationPostprocessingGraphOptions ext = 460416950; } // Optional mapping between output tensor index and corresponding score diff --git a/mediapipe/tasks/cc/components/proto/classifier_options.proto b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto similarity index 97% rename from mediapipe/tasks/cc/components/proto/classifier_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classifier_options.proto index ea1491bb8..7afbfc14e 100644 --- a/mediapipe/tasks/cc/components/proto/classifier_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; // Shared options used by all classification tasks. message ClassifierOptions { diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 8c4dcdad9..c11d6f95a 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -23,11 +23,6 @@ mediapipe_proto_library( srcs = ["segmenter_options.proto"], ) -mediapipe_proto_library( - name = "classifier_options_proto", - srcs = ["classifier_options.proto"], -) - mediapipe_proto_library( name = "embedder_options_proto", srcs = ["embedder_options.proto"], diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index 0ec7ac945..d16e2fbc4 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -42,3 +42,16 @@ cc_test( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "gate", + hdrs = ["gate.h"], + deps = [ + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:gate_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + ], +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/components/utils/gate.h b/mediapipe/tasks/cc/components/utils/gate.h new file mode 100644 index 000000000..139205fc5 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/gate.h @@ -0,0 +1,160 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ + +#include + +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { + +// Utility class that simplifies allowing (gating) multiple streams. +class AllowGate { + public: + AllowGate(api2::builder::Source allow, api2::builder::Graph& graph) + : node_(AddSourceGate(allow, graph)) {} + AllowGate(api2::builder::SideSource allow, api2::builder::Graph& graph) + : node_(AddSideSourceGate(allow, graph)) {} + + // Move-only + AllowGate(AllowGate&& allow_gate) = default; + AllowGate& operator=(AllowGate&& allow_gate) = default; + + template + api2::builder::Source Allow(api2::builder::Source source) { + source >> node_.In(index_); + return node_.Out(index_++).Cast(); + } + + private: + template + static api2::builder::GenericNode& AddSourceGate( + T allow, api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + allow >> gate_node.In("ALLOW"); + return gate_node; + } + + template + static api2::builder::GenericNode& AddSideSourceGate( + T allow, api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + allow >> gate_node.SideIn("ALLOW"); + return gate_node; + } + + api2::builder::GenericNode& node_; + int index_ = 0; +}; + +// Utility class that simplifies disallowing (gating) multiple streams. +class DisallowGate { + public: + DisallowGate(api2::builder::Source disallow, + api2::builder::Graph& graph) + : node_(AddSourceGate(disallow, graph)) {} + DisallowGate(api2::builder::SideSource disallow, + api2::builder::Graph& graph) + : node_(AddSideSourceGate(disallow, graph)) {} + + // Move-only + DisallowGate(DisallowGate&& disallow_gate) = default; + DisallowGate& operator=(DisallowGate&& disallow_gate) = default; + + template + api2::builder::Source Disallow(api2::builder::Source source) { + source >> node_.In(index_); + return node_.Out(index_++).Cast(); + } + + private: + template + static api2::builder::GenericNode& AddSourceGate( + T disallow, api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + auto& gate_node_opts = + gate_node.GetOptions(); + // Supposedly, the most popular configuration for MediaPipe Tasks team + // graphs. Hence, intentionally hard coded to catch and verify any other use + // case (should help to workout a common approach and have a recommended way + // of blocking streams). + gate_node_opts.set_empty_packets_as_allow(true); + disallow >> gate_node.In("DISALLOW"); + return gate_node; + } + + template + static api2::builder::GenericNode& AddSideSourceGate( + T disallow, api2::builder::Graph& graph) { + auto& gate_node = graph.AddNode("GateCalculator"); + auto& gate_node_opts = + gate_node.GetOptions(); + gate_node_opts.set_empty_packets_as_allow(true); + disallow >> gate_node.SideIn("DISALLOW"); + return gate_node; + } + + api2::builder::GenericNode& node_; + int index_ = 0; +}; + +// Updates graph to drop @value stream packet if corresponding @condition stream +// packet holds true. 
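+// For example (stream names are illustrative), `DisallowIf(detections,
+// skip_detection, graph)` returns a stream that forwards a detection packet
+// only when the corresponding condition packet is absent or holds false.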
+template +api2::builder::Source DisallowIf(api2::builder::Source value, + api2::builder::Source condition, + api2::builder::Graph& graph) { + return DisallowGate(condition, graph).Disallow(value); +} + +// Updates graph to drop @value stream packet if corresponding @condition stream +// packet holds true. +template +api2::builder::Source DisallowIf(api2::builder::Source value, + api2::builder::SideSource condition, + api2::builder::Graph& graph) { + return DisallowGate(condition, graph).Disallow(value); +} + +// Updates graph to pass through @value stream packet if corresponding +// @allow stream packet holds true. +template +api2::builder::Source AllowIf(api2::builder::Source value, + api2::builder::Source allow, + api2::builder::Graph& graph) { + return AllowGate(allow, graph).Allow(value); +} + +// Updates graph to pass through @value stream packet if corresponding +// @allow side stream packet holds true. +template +api2::builder::Source AllowIf(api2::builder::Source value, + api2::builder::SideSource allow, + api2::builder::Graph& graph) { + return AllowGate(allow, graph).Allow(value); +} + +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ diff --git a/mediapipe/tasks/cc/components/utils/gate_test.cc b/mediapipe/tasks/cc/components/utils/gate_test.cc new file mode 100644 index 000000000..7fdca48e7 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/gate_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/components/utils/gate.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_graph.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { +namespace { + +using ::mediapipe::api2::builder::SideSource; +using ::mediapipe::api2::builder::Source; + +TEST(DisallowGate, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source condition = graph.In("CONDITION").Cast(); + Source value1 = graph.In("VALUE_1").Cast(); + Source value2 = graph.In("VALUE_2").Cast(); + Source value3 = graph.In("VALUE_3").Cast(); + + DisallowGate gate(condition, graph); + gate.Disallow(value1).SetName("gated_stream1"); + gate.Disallow(value2).SetName("gated_stream2"); + gate.Disallow(value3).SetName("gated_stream3"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "__stream_2" + input_stream: "__stream_3" + input_stream: "DISALLOW:__stream_0" + output_stream: "gated_stream1" + output_stream: "gated_stream2" + output_stream: "gated_stream3" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE_1:__stream_1" + input_stream: "VALUE_2:__stream_2" + input_stream: "VALUE_3:__stream_3" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DisallowIf, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + Source condition = graph.In("CONDITION").Cast(); + + auto gated_stream = DisallowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "DISALLOW:__stream_0" + output_stream: "gated_stream" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DisallowIf, VerifyConfigWithSideCondition) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + SideSource condition = graph.SideIn("CONDITION").Cast(); + + auto gated_stream = DisallowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_0" + output_stream: "gated_stream" + input_side_packet: "DISALLOW:__side_packet_1" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "VALUE:__stream_0" + input_side_packet: "CONDITION:__side_packet_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowGate, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source 
condition = graph.In("CONDITION").Cast(); + Source value1 = graph.In("VALUE_1").Cast(); + Source value2 = graph.In("VALUE_2").Cast(); + Source value3 = graph.In("VALUE_3").Cast(); + + AllowGate gate(condition, graph); + gate.Allow(value1).SetName("gated_stream1"); + gate.Allow(value2).SetName("gated_stream2"); + gate.Allow(value3).SetName("gated_stream3"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "__stream_2" + input_stream: "__stream_3" + input_stream: "ALLOW:__stream_0" + output_stream: "gated_stream1" + output_stream: "gated_stream2" + output_stream: "gated_stream3" + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE_1:__stream_1" + input_stream: "VALUE_2:__stream_2" + input_stream: "VALUE_3:__stream_3" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowIf, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + Source condition = graph.In("CONDITION").Cast(); + + auto gated_stream = AllowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "ALLOW:__stream_0" + output_stream: "gated_stream" + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowIf, VerifyConfigWithSideConition) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + SideSource condition = graph.SideIn("CONDITION").Cast(); + + auto gated_stream = AllowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_0" + output_stream: "gated_stream" + input_side_packet: "ALLOW:__side_packet_1" + } + input_stream: "VALUE:__stream_0" + input_side_packet: "CONDITION:__side_packet_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 38030c525..8d19227f1 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -23,6 +23,7 @@ cc_library( srcs = ["base_options.cc"], hdrs = ["base_options.h"], deps = [ + ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", @@ -50,6 +51,21 @@ cc_library( ], ) +cc_library( + name = "mediapipe_builtin_op_resolver", + srcs = ["mediapipe_builtin_op_resolver.cc"], + hdrs = ["mediapipe_builtin_op_resolver.h"], + deps = [ + "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", + "//mediapipe/util/tflite/operations:max_pool_argmax", + "//mediapipe/util/tflite/operations:max_unpooling", + "//mediapipe/util/tflite/operations:transform_landmarks", + "//mediapipe/util/tflite/operations:transform_tensor_bilinear", + 
"//mediapipe/util/tflite/operations:transpose_conv_bias", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + # TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator # supports TFLite-in-GMSCore. cc_library( diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index 67a03385b..4717ea50e 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" @@ -63,7 +64,7 @@ struct BaseOptions { // A non-default OpResolver to support custom Ops or specify a subset of // built-in Ops. std::unique_ptr op_resolver = - absl::make_unique(); + absl::make_unique(); }; // Converts a BaseOptions to a BaseOptionsProto. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc similarity index 87% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index cd3b5690f..62898a005 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" @@ -21,14 +21,11 @@ limitations under the License. 
#include "mediapipe/util/tflite/operations/transform_landmarks.h" #include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h" #include "mediapipe/util/tflite/operations/transpose_conv_bias.h" -#include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() - : BuiltinOpResolver() { +namespace core { +MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { AddCustom("MaxPoolingWithArgmax2D", mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); AddCustom("MaxUnpooling2D", @@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), /*version=*/2); } - -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h similarity index 65% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h index a0538a674..a7c28aa71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h @@ -13,25 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ +#define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ #include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -class SelfieSegmentationModelOpResolver +namespace core { +class MediaPipeBuiltinOpResolver : public tflite::ops::builtin::BuiltinOpResolver { public: - SelfieSegmentationModelOpResolver(); - SelfieSegmentationModelOpResolver( - const SelfieSegmentationModelOpResolver& r) = delete; + MediaPipeBuiltinOpResolver(); + MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete; }; -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 23cf5f72d..c87cc50a6 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -18,18 +18,6 @@ package(default_visibility = [ licenses(["notice"]) -cc_library( - name = "hand_detector_op_resolver", - srcs = ["hand_detector_op_resolver.cc"], - hdrs = ["hand_detector_op_resolver.h"], - deps = [ - "//mediapipe/util/tflite/operations:max_pool_argmax", - "//mediapipe/util/tflite/operations:max_unpooling", - "//mediapipe/util/tflite/operations:transpose_conv_bias", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - cc_library( name = "hand_detector_graph", srcs = ["hand_detector_graph.cc"], diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc 
index a2fbd7c54..3fa97664e 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -35,11 +35,11 @@ limitations under the License. #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" @@ -121,8 +121,8 @@ absl::StatusOr> CreateTaskRunner( hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> graph[Output>(kHandNormRectsTag)]; - return TaskRunner::Create(graph.GetConfig(), - absl::make_unique()); + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); } HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) { diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc deleted file mode 100644 index 262fb2c75..000000000 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" - -#include "mediapipe/util/tflite/operations/max_pool_argmax.h" -#include "mediapipe/util/tflite/operations/max_unpooling.h" -#include "mediapipe/util/tflite/operations/transpose_conv_bias.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -HandDetectorOpResolver::HandDetectorOpResolver() { - AddCustom("MaxPoolingWithArgmax2D", - mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); - AddCustom("MaxUnpooling2D", - mediapipe::tflite_operations::RegisterMaxUnpooling2D()); - AddCustom("Convolution2DTransposeBias", - mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()); -} -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD index bb5b86212..9e2d9bd17 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD @@ -54,10 +54,10 @@ cc_library( "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc index e124d3410..247d8453d 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc @@ -27,9 +27,9 @@ limitations under the License. 
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -49,6 +49,7 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: HandGestureRecognizerSubgraphOptions; using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; @@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { auto inference_output_tensors = inference.Out(kTensorsTag); auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, graph_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, graph_options.classifier_options(), + &postprocessing + .GetOptions())); inference_output_tensors >> postprocessing.In(kTensorsTag); auto classification_result = postprocessing[Output("CLASSIFICATION_RESULT")]; diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD index f3927727e..44ec611b2 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD @@ -26,7 +26,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) @@ -37,7 +37,5 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto index f73443eaf..d8ee95037 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto @@ -18,7 +18,7 @@ syntax 
= "proto2"; package mediapipe.tasks.vision.hand_gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message HandGestureRecognizerSubgraphOptions { @@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions { // Options for configuring the gesture classifier behavior, such as score // threshold, number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be // considered tracked successfully diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD new file mode 100644 index 000000000..dea81bae3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD @@ -0,0 +1,49 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/app/xeno:__subpackages__", + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "hand_association_calculator_proto", + srcs = ["hand_association_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "hand_association_calculator", + srcs = ["hand_association_calculator.cc"], + deps = [ + ":hand_association_calculator_cc_proto", + "//mediapipe/calculators/util:association_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:rectangle", + "//mediapipe/framework/port:status", + "//mediapipe/util:rectangle_util", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc new file mode 100644 index 000000000..b6df80588 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -0,0 +1,125 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/rectangle.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" +#include "mediapipe/util/rectangle_util.h" + +namespace mediapipe::api2 { + +// HandAssociationCalculator accepts multiple inputs of vectors of +// NormalizedRect. The output is a vector of NormalizedRect that contains +// rects from the input vectors that don't overlap with each other. When two +// rects overlap, the rect that comes in from an earlier input stream is +// kept in the output. If a rect has no ID (i.e. from detection stream), +// then a unique rect ID is assigned for it. + +// The rects in multiple input streams are effectively flattened to a single +// list. For example: +// Stream1 : rect 1, rect 2 +// Stream2: rect 3, rect 4 +// Stream3: rect 5, rect 6 +// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6 +// In the flattened list, if a rect with a higher index overlaps with a rect a +// lower index, beyond a specified IOU threshold, the rect with the lower +// index will be in the output, and the rect with higher index will be +// discarded. +// TODO: Upgrade this to latest API for calculators +class HandAssociationCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + // Initialize input and output streams. + for (auto& input_stream : cc->Inputs()) { + input_stream.Set>(); + } + cc->Outputs().Index(0).Set>(); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + + options_ = cc->Options(); + CHECK_GT(options_.min_similarity_threshold(), 0.0); + CHECK_LE(options_.min_similarity_threshold(), 1.0); + + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + ASSIGN_OR_RETURN(auto result, GetNonOverlappingElements(cc)); + + auto output = + std::make_unique>(std::move(result)); + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } + + private: + HandAssociationCalculatorOptions options_; + + // Return a list of non-overlapping elements from all input streams, with + // decreasing order of priority based on input stream index and indices + // within an input stream. + absl::StatusOr> GetNonOverlappingElements( + CalculatorContext* cc) { + std::vector result; + + for (const auto& input_stream : cc->Inputs()) { + if (input_stream.IsEmpty()) { + continue; + } + + for (auto rect : input_stream.Get>()) { + ASSIGN_OR_RETURN( + bool is_overlapping, + mediapipe::DoesRectOverlap(rect, result, + options_.min_similarity_threshold())); + if (!is_overlapping) { + if (!rect.has_rect_id()) { + rect.set_rect_id(GetNextRectId()); + } + result.push_back(rect); + } + } + } + + return result; + } + + private: + // Each NormalizedRect processed by the calculator will be assigned + // an unique id, if it does not already have an ID. The starting ID will be 1. + // Note: This rect_id_ is local to an instance of this calculator. 
And it is + // expected that the hand tracking graph to have only one instance of + // this association calculator. + int64 rect_id_ = 1; + + inline int GetNextRectId() { return rect_id_++; } +}; + +MEDIAPIPE_REGISTER_NODE(HandAssociationCalculator); + +} // namespace mediapipe::api2 diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto similarity index 52% rename from mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h rename to mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto index a55661fa3..e7229b4a2 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto @@ -13,22 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ +syntax = "proto2"; -#include "tensorflow/lite/kernels/register.h" +package mediapipe; -namespace mediapipe { -namespace tasks { -namespace vision { -class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver { - public: - HandDetectorOpResolver(); - HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete; -}; +import "mediapipe/framework/calculator.proto"; -} // namespace vision -} // namespace tasks -} // namespace mediapipe +message HandAssociationCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional HandAssociationCalculatorOptions ext = 408244367; + } -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ + optional float min_similarity_threshold = 1 [default = 1.0]; +} diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc new file mode 100644 index 000000000..cb3130854 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -0,0 +1,302 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe { +namespace { + +class HandAssociationCalculatorTest : public testing::Test { + protected: + HandAssociationCalculatorTest() { + // 0.4 ================ + // | | | | + // 0.3 ===================== | NR2 | | + // | | | NR1 | | | NR4 | + // 0.2 | NR0 | =========== ================ + // | | | | | | + // 0.1 =====|=============== | + // | NR3 | | | + // 0.0 ================ | + // | NR5 | + // -0.1 =========== + // 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 + + // NormalizedRect nr_0. + nr_0_.set_x_center(0.2); + nr_0_.set_y_center(0.2); + nr_0_.set_width(0.2); + nr_0_.set_height(0.2); + + // NormalizedRect nr_1. + nr_1_.set_x_center(0.4); + nr_1_.set_y_center(0.2); + nr_1_.set_width(0.2); + nr_1_.set_height(0.2); + + // NormalizedRect nr_2. + nr_2_.set_x_center(1.0); + nr_2_.set_y_center(0.3); + nr_2_.set_width(0.2); + nr_2_.set_height(0.2); + + // NormalizedRect nr_3. + nr_3_.set_x_center(0.35); + nr_3_.set_y_center(0.15); + nr_3_.set_width(0.3); + nr_3_.set_height(0.3); + + // NormalizedRect nr_4. + nr_4_.set_x_center(1.1); + nr_4_.set_y_center(0.3); + nr_4_.set_width(0.2); + nr_4_.set_height(0.2); + + // NormalizedRect nr_5. + nr_5_.set_x_center(0.5); + nr_5_.set_y_center(0.05); + nr_5_.set_width(0.3); + nr_5_.set_height(0.3); + } + + NormalizedRect nr_0_, nr_1_, nr_2_, nr_3_, nr_4_, nr_5_; +}; + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_0, nr_1, nr_2. + auto input_vec_0 = std::make_unique>(); + input_vec_0->push_back(nr_0_); + input_vec_0->push_back(nr_1_); + input_vec_0->push_back(nr_2_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_3, nr_4. + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_3_); + input_vec_1->push_back(nr_4_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + // Input Stream 2: nr_5. + auto input_vec_2 = std::make_unique>(); + input_vec_2->push_back(nr_5_); + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(input_vec_2.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_0 is added 1st. + // nr_1 is added because it does not overlap with nr_0. + // nr_2 is added because it does not overlap with nr_0 or nr_1. + // nr_3 is NOT added because it overlaps with nr_0. + // nr_4 is NOT added because it overlaps with nr_2. + // nr_5 is NOT added because it overlaps with nr_1. + EXPECT_EQ(3, assoc_rects.size()); + + // Check that IDs are filled in and contents match. 
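+  // Newly generated rect IDs start at 1 and are assigned in insertion order.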
+ EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), 2); + assoc_rects[1].clear_rect_id(); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 3); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_0, nr_1. Tracked hands. + auto input_vec_0 = std::make_unique>(); + // Setting ID to a negative number for test only, since newly generated + // ID by HandAssociationCalculator are positive numbers. + nr_0_.set_rect_id(-2); + input_vec_0->push_back(nr_0_); + nr_1_.set_rect_id(-1); + input_vec_0->push_back(nr_1_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_2, nr_3. Newly detected palms. + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_2_); + input_vec_1->push_back(nr_3_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_0 is added 1st. + // nr_1 is added because it does not overlap with nr_0. + // nr_2 is added because it does not overlap with nr_0 or nr_1. + // nr_3 is NOT added because it overlaps with nr_0. + EXPECT_EQ(3, assoc_rects.size()); + + // Check that IDs are filled in and contents match. + EXPECT_EQ(assoc_rects[0].rect_id(), -2); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), -1); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 1); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_5. + auto input_vec_0 = std::make_unique>(); + input_vec_0->push_back(nr_5_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_4, nr_3 + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_4_); + input_vec_1->push_back(nr_3_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + // Input Stream 2: nr_2, nr_1, nr_0. 
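+  // Stream 2 arrives last, so each of its rects is kept only if it does not
+  // overlap a rect already accepted from streams 0 and 1.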
+ auto input_vec_2 = std::make_unique>(); + input_vec_2->push_back(nr_2_); + input_vec_2->push_back(nr_1_); + input_vec_2->push_back(nr_0_); + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(input_vec_2.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_5 is added 1st. + // nr_4 is added because it does not overlap with nr_5. + // nr_3 is NOT added because it overlaps with nr_5. + // nr_2 is NOT added because it overlaps with nr_4. + // nr_1 is NOT added because it overlaps with nr_5. + // nr_0 is added because it does not overlap with nr_5 or nr_4. + EXPECT_EQ(3, assoc_rects.size()); + + // Outputs are in same order as inputs, and IDs are filled in. + EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), 2); + assoc_rects[1].clear_rect_id(); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 3); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Just one input stream : nr_3, nr_5. + auto input_vec = std::make_unique>(); + input_vec->push_back(nr_3_); + input_vec->push_back(nr_5_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_3 is added 1st. + // nr_5 is NOT added because it overlaps with nr_3. 
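+  // Only nr_3 remains, and it receives the first generated ID.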
+ EXPECT_EQ(1, assoc_rects.size()); + + EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index 4dcecdbbe..dfa77cb96 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -26,14 +26,14 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", "@com_google_absl//absl/status:statusor", ], alwayslink = 1, @@ -50,9 +50,9 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classifier_options", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:utils", @@ -61,7 +61,7 @@ cc_library( "//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index eb74c3d98..0338b2ee2 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -26,9 +26,9 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/timestamp.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -36,11 +36,12 @@ limitations under the License. #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" -#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { namespace { @@ -52,12 +53,11 @@ constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectName[] = "norm_rect_in"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kSubgraphTypeName[] = - "mediapipe.tasks.vision.ImageClassifierGraph"; + "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; -using ImageClassifierOptionsProto = - image_classifier::proto::ImageClassifierOptions; // Builds a NormalizedRect covering the entire image. NormalizedRect BuildFullImageNormRect() { @@ -70,17 +70,17 @@ NormalizedRect BuildFullImageNormRect() { } // Creates a MediaPipe graph config that contains a subgraph node of -// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the -// live stream mode, a "FlowLimiterCalculator" will be added to limit the number -// of frames in flight. +// type "ImageClassifierGraph". If the task is running in the live stream mode, +// a "FlowLimiterCalculator" will be added to limit the number of frames in +// flight. CalculatorGraphConfig CreateGraphConfig( - std::unique_ptr options_proto, + std::unique_ptr options_proto, bool enable_flow_limiting) { api2::builder::Graph graph; graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName); - task_subgraph.GetOptions().Swap( + task_subgraph.GetOptions().Swap( options_proto.get()); task_subgraph.Out(kClassificationResultTag) .SetName(kClassificationResultStreamName) >> @@ -98,18 +98,18 @@ CalculatorGraphConfig CreateGraphConfig( } // Converts the user-facing ImageClassifierOptions struct to the internal -// ImageClassifierOptions proto. -std::unique_ptr +// ImageClassifierGraphOptions proto. 
+std::unique_ptr ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { - auto options_proto = std::make_unique(); + auto options_proto = std::make_unique(); auto base_options_proto = std::make_unique( tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); @@ -145,7 +145,7 @@ absl::StatusOr> ImageClassifier::Create( }; } return core::VisionTaskApiFactory::Create( + proto::ImageClassifierGraphOptions>( CreateGraphConfig( std::move(options_proto), options->running_mode == core::RunningMode::LIVE_STREAM), @@ -214,6 +214,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms, .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 2fbac71b2..24f36017a 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -32,6 +32,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { // The options for configuring a Mediapipe image classifier task. struct ImageClassifierOptions { @@ -50,12 +51,14 @@ struct ImageClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function, const Image&, int64)> + std::function, + const Image&, int64)> result_callback = nullptr; }; @@ -112,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. // TODO: describe exact preprocessing steps once // YUVToImageCalculator is integrated. - absl::StatusOr Classify( + absl::StatusOr Classify( mediapipe::Image image, std::optional roi = std::nullopt); @@ -126,9 +129,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. 
It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - absl::StatusOr ClassifyForVideo( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::StatusOr + ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, + std::optional roi = std::nullopt); // Sends live image data to image classification, and the results will be // available via the "result_callback" provided in the ImageClassifierOptions. @@ -161,6 +164,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { absl::Status Close() { return runner_->Close(); } }; +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 532b7db45..9a0078c5c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -22,18 +22,19 @@ limitations under the License. #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { namespace { @@ -42,8 +43,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ImageClassifierOptionsProto = - image_classifier::proto::ImageClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); @@ -61,8 +61,7 @@ struct ImageClassifierOutputStreams { } // namespace -// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image -// classification. +// An "ImageClassifierGraph" performs image classification. // - Accepts CPU input images and outputs classifications on CPU. 
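For orientation, here is a hedged sketch of how a client drives the task API declared above, assuming `using namespace mediapipe::tasks::vision` for brevity; the model path is a placeholder and error handling is trimmed.
// Usage sketch only; "model_with_metadata.tflite" is a hypothetical asset path.
auto options = std::make_unique<image_classifier::ImageClassifierOptions>();
options->base_options.model_asset_path = "model_with_metadata.tflite";
options->classifier_options.max_results = 3;  // Field assumed from ClassifierOptions.
auto classifier_or = image_classifier::ImageClassifier::Create(std::move(options));
if (classifier_or.ok()) {
  // `image` is a mediapipe::Image; Classify() returns a StatusOr of the classification result.
  auto result_or = (*classifier_or)->Classify(image);
}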
// // Inputs: @@ -80,12 +79,12 @@ struct ImageClassifierOutputStreams { // // Example: // node { -// calculator: "mediapipe.tasks.vision.ImageClassifierGraph" +// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph" // input_stream: "IMAGE:image_in" // output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "IMAGE:image_out" // options { -// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext] +// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext] // { // base_options { // model_asset { @@ -104,13 +103,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( auto output_streams, BuildImageClassificationTask( - sc->Options(), *model_resources, + sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); output_streams.classification_result >> @@ -125,13 +125,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // (mediapipe::Image) as input and returns one classification result per input // image. // - // task_options: the mediapipe tasks ImageClassifierOptions. + // task_options: the mediapipe tasks ImageClassifierGraphOptions. // model_resources: the ModelSources object initialized from an image // classification model file with model metadata. // image_in: (mediapipe::Image) stream to run classification on. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr BuildImageClassificationTask( - const ImageClassifierOptionsProto& task_options, + const proto::ImageClassifierGraphOptions& task_options, const core::ModelResources& model_resources, Source image_in, Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image @@ -153,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." 
+ "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the aggregated classification result as the subgraph output @@ -168,8 +171,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph { /*image=*/preprocessing[Output(kImageTag)]}; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph); +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 7cf6414bf..070a5a034 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -32,8 +32,8 @@ limitations under the License. #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -44,9 +44,13 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::proto::Classifications; using ::testing::HasSubstr; using ::testing::Optional; @@ -814,6 +818,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { } } // namespace +} // namespace image_classifier } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD index dc8241799..29638bebd 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD @@ -19,12 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "image_classifier_options_proto", - srcs = ["image_classifier_options.proto"], + name = "image_classifier_graph_options_proto", + srcs = ["image_classifier_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto similarity index 82% rename from mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto rename to mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 8aa8b4615..b307a66b6 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,12 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message ImageClassifierOptions { +message ImageClassifierGraphOptions { extend mediapipe.CalculatorOptions { - optional ImageClassifierOptions ext = 456383383; + optional ImageClassifierGraphOptions ext = 456383383; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. @@ -31,5 +31,5 @@ message ImageClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. 
- optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 6af733657..6bdbf41da 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -33,7 +33,6 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) @@ -73,19 +72,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "image_segmenter_op_resolvers", - srcs = ["image_segmenter_op_resolvers.cc"], - hdrs = ["image_segmenter_op_resolvers.h"], - deps = [ - "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", - "//mediapipe/util/tflite/operations:max_pool_argmax", - "//mediapipe/util/tflite/operations:max_unpooling", - "//mediapipe/util/tflite/operations:transform_landmarks", - "//mediapipe/util/tflite/operations:transform_tensor_bilinear", - "//mediapipe/util/tflite/operations:transpose_conv_bias", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index ce9cb104c..e2734c4e4 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -26,7 +26,6 @@ limitations under the License. #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 2f1c26a79..1d3f3e786 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -31,7 +31,6 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -260,8 +259,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::SOFTMAX; @@ -290,8 +287,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 65c1214af..8f3c1539c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -11,3 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "category", + srcs = ["Category.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "detection", + srcs = ["Detection.java"], + deps = [ + ":category", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java new file mode 100644 index 000000000..3b7c41fbe --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Category.java @@ -0,0 +1,86 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Objects; + +/** + * Category is a util class, contains a category name, its display name, a float value as score, and + * the index of the label in the corresponding label file. Typically it's used as result of + * classification or detection tasks. + */ +@AutoValue +public abstract class Category { + private static final float TOLERANCE = 1e-6f; + + /** + * Creates a {@link Category} instance. + * + * @param score the probability score of this label category. + * @param index the index of the label in the corresponding label file. + * @param categoryName the label of this category object. + * @param displayName the display name of the label. + */ + public static Category create(float score, int index, String categoryName, String displayName) { + return new AutoValue_Category(score, index, categoryName, displayName); + } + + /** The probability score of this label category. */ + public abstract float score(); + + /** The index of the label in the corresponding label file. Returns -1 if the index is not set. */ + public abstract int index(); + + /** The label of this category object. */ + public abstract String categoryName(); + + /** + * The display name of the label, which may be translated for different locales. For example, a + * label, "apple", may be translated into Spanish for display purpose, so that the display name is + * "manzana". + */ + public abstract String displayName(); + + @Override + public final boolean equals(Object o) { + if (!(o instanceof Category)) { + return false; + } + Category other = (Category) o; + return Math.abs(other.score() - this.score()) < TOLERANCE + && other.index() == this.index() + && other.categoryName().equals(this.categoryName()) + && other.displayName().equals(this.displayName()); + } + + @Override + public final int hashCode() { + return Objects.hash(categoryName(), displayName(), score(), index()); + } + + @Override + public final String toString() { + return ""; + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Detection.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Detection.java new file mode 100644 index 000000000..c02c8025a --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Detection.java @@ -0,0 +1,50 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import android.graphics.RectF; +import com.google.auto.value.AutoValue; +import java.util.Collections; +import java.util.List; + +/** + * Represents one detected object in the results of {@link + * com.google.mediapipe.tasks.version.objectdetector.ObjectDetector}. + */ +@AutoValue +public abstract class Detection { + + /** + * Creates a {@link Detection} instance from a list of {@link Category} and a bounding box. 
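An aside on the Category container defined above: score equality is tolerance-based, so two instances whose scores differ by less than 1e-6 compare equal, as in this small hypothetical sketch.
// Illustrative only; the values are arbitrary.
Category a = Category.create(0.9f, 2, "apple", "manzana");
Category b = Category.create(0.9f + 1e-7f, 2, "apple", "manzana");
boolean same = a.equals(b);  // true: the score difference is below the 1e-6 tolerance.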
+ * + * @param categories a list of {@link Category} objects that contain category name, display name, + * score, and the label index. + * @param boundingBox a {@link RectF} object to represent the bounding box. + */ + public static Detection create(List categories, RectF boundingBox) { + + // As an open source project, we've been trying avoiding depending on common java libraries, + // such as Guava, because it may introduce conflicts with clients who also happen to use those + // libraries. Therefore, instead of using ImmutableList here, we convert the List into + // unmodifiableList + return new AutoValue_Detection(Collections.unmodifiableList(categories), boundingBox); + } + + /** A list of {@link Category} objects. */ + public abstract List categories(); + + /** A {@link RectF} object to represent the bounding box of the detected object. */ + public abstract RectF boundingBox(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index 65c1214af..8da7b8561 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -11,3 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +android_library( + name = "core", + srcs = glob(["*.java"]), + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", + "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + "//third_party:autovalue", + "@com_google_protobuf//:protobuf_javalite", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java new file mode 100644 index 000000000..d28946736 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BaseOptions.java @@ -0,0 +1,96 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +import com.google.auto.value.AutoValue; +import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; +import java.util.Optional; + +/** Options to configure MediaPipe Tasks in general. 
*/ +@AutoValue +public abstract class BaseOptions { + /** Builder for {@link BaseOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets the model path to a tflite model with metadata in the assets. + * + *
Note: when model path is set, both model file descriptor and model buffer should be empty. + */ + public abstract Builder setModelAssetPath(String value); + + /** + * Sets the native fd int of a tflite model with metadata. + * + *
Note: when model file descriptor is set, both model path and model buffer should be empty. + */ + public abstract Builder setModelAssetFileDescriptor(Integer value); + + /** + * Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model + * with metadata. + * + *
Note: when model buffer is set, both model file and model file descriptor should be empty. + */ + public abstract Builder setModelAssetBuffer(ByteBuffer value); + + /** + * Sets device Delegate to run the MediaPipe pipeline. If the delegate is not set, default + * delegate CPU is used. + */ + public abstract Builder setDelegate(Delegate delegate); + + abstract BaseOptions autoBuild(); + + /** + * Validates and builds the {@link BaseOptions} instance. + * + * @throws IllegalArgumentException if {@link BaseOptions} is invalid, or the provided model + * buffer is not a direct {@link ByteBuffer} or a {@link MappedByteBuffer}. + */ + public final BaseOptions build() { + BaseOptions options = autoBuild(); + int modelAssetPathPresent = options.modelAssetPath().isPresent() ? 1 : 0; + int modelAssetFileDescriptorPresent = options.modelAssetFileDescriptor().isPresent() ? 1 : 0; + int modelAssetBufferPresent = options.modelAssetBuffer().isPresent() ? 1 : 0; + + if (modelAssetPathPresent + modelAssetFileDescriptorPresent + modelAssetBufferPresent != 1) { + throw new IllegalArgumentException( + "Please specify only one of the model asset path, the model asset file descriptor, and" + + " the model asset buffer."); + } + if (options.modelAssetBuffer().isPresent() + && !(options.modelAssetBuffer().get().isDirect() + || options.modelAssetBuffer().get() instanceof MappedByteBuffer)) { + throw new IllegalArgumentException( + "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + return options; + } + } + + abstract Optional modelAssetPath(); + + abstract Optional modelAssetFileDescriptor(); + + abstract Optional modelAssetBuffer(); + + abstract Delegate delegate(); + + public static Builder builder() { + return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/Delegate.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/Delegate.java new file mode 100644 index 000000000..84bf7270f --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/Delegate.java @@ -0,0 +1,22 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +/** MediaPipe Tasks delegate. */ +// TODO implement advanced delegate setting. +public enum Delegate { + CPU, + GPU, +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ErrorListener.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ErrorListener.java new file mode 100644 index 000000000..3f62d5d4f --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ErrorListener.java @@ -0,0 +1,20 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +/** Interface for the customizable MediaPipe task error listener. */ +public interface ErrorListener { + void onError(RuntimeException e); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ModelResourcesCache.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ModelResourcesCache.java new file mode 100644 index 000000000..4e81a3805 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ModelResourcesCache.java @@ -0,0 +1,49 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** Facilitates creation and destruction of the native ModelResourcesCache. */ +class ModelResourcesCache { + private final long nativeHandle; + private final AtomicBoolean isHandleValid; + + public ModelResourcesCache() { + nativeHandle = nativeCreateModelResourcesCache(); + isHandleValid = new AtomicBoolean(true); + } + + public boolean isHandleValid() { + return isHandleValid.get(); + } + + public long getNativeHandle() { + if (isHandleValid.get()) { + return nativeHandle; + } + return 0; + } + + public void release() { + if (isHandleValid.compareAndSet(true, false)) { + nativeReleaseModelResourcesCache(nativeHandle); + } + } + + private native long nativeCreateModelResourcesCache(); + + private native void nativeReleaseModelResourcesCache(long nativeHandle); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ModelResourcesCacheService.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ModelResourcesCacheService.java new file mode 100644 index 000000000..2cf2f096b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/ModelResourcesCacheService.java @@ -0,0 +1,29 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
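Looking back at the BaseOptions builder defined above: exactly one model source may be set and the delegate defaults to CPU, so a typical configuration is just the hedged sketch below (the asset name is a placeholder).
// Illustrative only; "model_with_metadata.tflite" is a hypothetical asset name.
BaseOptions baseOptions =
    BaseOptions.builder()
        .setModelAssetPath("model_with_metadata.tflite")
        .setDelegate(Delegate.GPU)  // Optional; omitting this keeps the CPU default.
        .build();
// Additionally calling setModelAssetBuffer(...) or setModelAssetFileDescriptor(...) would make
// build() throw IllegalArgumentException, because the three model sources are mutually exclusive.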
+ +package com.google.mediapipe.tasks.core; + +import com.google.mediapipe.framework.GraphService; + +/** Java wrapper for graph service of ModelResourcesCacheService. */ +class ModelResourcesCacheService implements GraphService { + public ModelResourcesCacheService() {} + + @Override + public void installServiceObject(long context, ModelResourcesCache object) { + nativeInstallServiceObject(context, object.getNativeHandle()); + } + + public native void nativeInstallServiceObject(long context, long object); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java new file mode 100644 index 000000000..3fa7c2bcc --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java @@ -0,0 +1,130 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +import android.util.Log; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import java.util.List; + +/** Base class for handling MediaPipe task graph outputs. */ +public class OutputHandler { + /** + * Interface for converting MediaPipe graph output {@link Packet}s to task result object and task + * input object. + */ + public interface OutputPacketConverter { + OutputT convertToTaskResult(List packets); + + InputT convertToTaskInput(List packets); + } + + /** Interface for the customizable MediaPipe task result listener. */ + public interface ResultListener { + void run(OutputT result, InputT input); + } + + private static final String TAG = "OutputHandler"; + // A task-specific graph output packet converter that should be implemented per task. + private OutputPacketConverter outputPacketConverter; + // The user-defined task result listener. + private ResultListener resultListener; + // The user-defined error listener. + protected ErrorListener errorListener; + // The cached task result for non latency sensitive use cases. + protected OutputT cachedTaskResult; + // Whether the output handler should react to timestamp-bound changes by outputting empty packets. + private boolean handleTimestampBoundChanges = false; + + /** + * Sets a callback to be invoked to convert a {@link Packet} list to a task result object and a + * task input object. + * + * @param converter the task-specific {@link OutputPacketConverter} callback. + */ + public void setOutputPacketConverter(OutputPacketConverter converter) { + this.outputPacketConverter = converter; + } + + /** + * Sets a callback to be invoked when task result objects become available. + * + * @param listener the user-defined {@link ResultListener} callback. + */ + public void setResultListener(ResultListener listener) { + this.resultListener = listener; + } + + /** + * Sets a callback to be invoked when exceptions are thrown from the task graph. 
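The converter and listener hooks above combine roughly as in the sketch below; ExampleResult is hypothetical, and the handler's type parameters are assumed to be <OutputT extends TaskResult, InputT> (the generics were dropped in the listing above).
// Hypothetical result type standing in for a task-specific TaskResult implementation.
class ExampleResult implements TaskResult {
  private final long timestampMs;
  ExampleResult(long timestampMs) { this.timestampMs = timestampMs; }
  @Override public long timestampMs() { return timestampMs; }
}

OutputHandler<ExampleResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
    new OutputHandler.OutputPacketConverter<ExampleResult, Void>() {
      @Override
      public ExampleResult convertToTaskResult(List<Packet> packets) {
        // A real task would decode its output protos here; this sketch only keeps the timestamp.
        return new ExampleResult(packets.get(0).getTimestamp());
      }

      @Override
      public Void convertToTaskInput(List<Packet> packets) {
        return null;  // No echo of the task input is needed for this sketch.
      }
    });
handler.setResultListener((result, unused) -> Log.i("Example", "Result at " + result.timestampMs()));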
+ * + * @param listener The user-defined {@link ErrorListener} callback. + */ + public void setErrorListener(ErrorListener listener) { + this.errorListener = listener; + } + + /** + * Sets whether the output handler should react to the timestamp bound changes that are reprsented + * as empty output {@link Packet}s. + * + * @param handleTimestampBoundChanges A boolean value. + */ + public void setHandleTimestampBoundChanges(boolean handleTimestampBoundChanges) { + this.handleTimestampBoundChanges = handleTimestampBoundChanges; + } + + /** Returns true if the task graph is set to handle timestamp bound changes. */ + boolean handleTimestampBoundChanges() { + return handleTimestampBoundChanges; + } + + /* Returns the cached task result object. */ + public OutputT retrieveCachedTaskResult() { + OutputT taskResult = cachedTaskResult; + cachedTaskResult = null; + return taskResult; + } + + /** + * Handles a list of output {@link Packet}s. Invoked when a packet list become available. + * + * @param packets A list of output {@link Packet}s. + */ + void run(List packets) { + OutputT taskResult = null; + try { + taskResult = outputPacketConverter.convertToTaskResult(packets); + if (resultListener == null) { + cachedTaskResult = taskResult; + } else { + InputT taskInput = outputPacketConverter.convertToTaskInput(packets); + resultListener.run(taskResult, taskInput); + } + } catch (MediaPipeException e) { + if (errorListener != null) { + errorListener.onError(e); + } else { + Log.e(TAG, "Error occurs when getting MediaPipe vision task result. " + e); + } + } finally { + for (Packet packet : packets) { + if (packet != null) { + packet.release(); + } + } + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java new file mode 100644 index 000000000..12f8be8ba --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskInfo.java @@ -0,0 +1,156 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; +import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig.Node; +import com.google.mediapipe.proto.CalculatorProto.InputStreamInfo; +import com.google.mediapipe.calculator.proto.FlowLimiterCalculatorProto.FlowLimiterCalculatorOptions; +import java.util.ArrayList; +import java.util.List; + +/** + * {@link TaskInfo} contains all needed informaton to initialize a MediaPipe Task {@link + * com.google.mediapipe.framework.Graph}. + */ +@AutoValue +public abstract class TaskInfo { + /** Builder for {@link TaskInfo}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the MediaPipe task graph name. 
*/ + public abstract Builder setTaskGraphName(String value); + + /** Sets a list of task graph input stream info {@link String}s in the form TAG:name. */ + public abstract Builder setInputStreams(List value); + + /** Sets a list of task graph output stream info {@link String}s in the form TAG:name. */ + public abstract Builder setOutputStreams(List value); + + /** Sets to true if the task requires a flow limiter. */ + public abstract Builder setEnableFlowLimiting(Boolean value); + + /** + * Sets a task-specific options instance. + * + * @param value a task-specific options that is derived from {@link TaskOptions}. + */ + public abstract Builder setTaskOptions(T value); + + public abstract TaskInfo autoBuild(); + + /** + * Validates and builds the {@link TaskInfo} instance. * + * + * @throws IllegalArgumentException if the required information such as task graph name, graph + * input streams, and the graph output streams are empty. + */ + public final TaskInfo build() { + TaskInfo taskInfo = autoBuild(); + if (taskInfo.taskGraphName().isEmpty() + || taskInfo.inputStreams().isEmpty() + || taskInfo.outputStreams().isEmpty()) { + throw new IllegalArgumentException( + "Task graph's name, input streams, and output streams should be non-empty."); + } + return taskInfo; + } + } + + abstract String taskGraphName(); + + abstract T taskOptions(); + + abstract List inputStreams(); + + abstract List outputStreams(); + + abstract Boolean enableFlowLimiting(); + + public static Builder builder() { + return new AutoValue_TaskInfo.Builder(); + } + + /* Returns a list of the output stream names without the stream tags. */ + List outputStreamNames() { + List streamNames = new ArrayList<>(outputStreams().size()); + for (String stream : outputStreams()) { + streamNames.add(stream.substring(stream.lastIndexOf(':') + 1)); + } + return streamNames; + } + + /** + * Creates a MediaPipe Task {@link CalculatorGraphConfig} protobuf message from the {@link + * TaskInfo} instance. 
+ */ + CalculatorGraphConfig generateGraphConfig() { + CalculatorGraphConfig.Builder graphBuilder = CalculatorGraphConfig.newBuilder(); + Node.Builder taskSubgraphBuilder = + Node.newBuilder() + .setCalculator(taskGraphName()) + .setOptions(taskOptions().convertToCalculatorOptionsProto()); + for (String outputStream : outputStreams()) { + taskSubgraphBuilder.addOutputStream(outputStream); + graphBuilder.addOutputStream(outputStream); + } + if (!enableFlowLimiting()) { + for (String inputStream : inputStreams()) { + taskSubgraphBuilder.addInputStream(inputStream); + graphBuilder.addInputStream(inputStream); + } + graphBuilder.addNode(taskSubgraphBuilder.build()); + return graphBuilder.build(); + } + Node.Builder flowLimiterCalculatorBuilder = + Node.newBuilder() + .setCalculator("FlowLimiterCalculator") + .addInputStreamInfo( + InputStreamInfo.newBuilder().setTagIndex("FINISHED").setBackEdge(true).build()) + .setOptions( + CalculatorOptions.newBuilder() + .setExtension( + FlowLimiterCalculatorOptions.ext, + FlowLimiterCalculatorOptions.newBuilder() + .setMaxInFlight(1) + .setMaxInQueue(1) + .build()) + .build()); + for (String inputStream : inputStreams()) { + graphBuilder.addInputStream(inputStream); + flowLimiterCalculatorBuilder.addInputStream(stripTagIndex(inputStream)); + String taskInputStream = addStreamNamePrefix(inputStream); + flowLimiterCalculatorBuilder.addOutputStream(stripTagIndex(taskInputStream)); + taskSubgraphBuilder.addInputStream(taskInputStream); + } + flowLimiterCalculatorBuilder.addInputStream( + "FINISHED:" + stripTagIndex(outputStreams().get(0))); + graphBuilder.addNode(flowLimiterCalculatorBuilder.build()); + graphBuilder.addNode(taskSubgraphBuilder.build()); + return graphBuilder.build(); + } + + private String stripTagIndex(String tagIndexName) { + return tagIndexName.substring(tagIndexName.lastIndexOf(':') + 1); + } + + private String addStreamNamePrefix(String tagIndexName) { + return tagIndexName.substring(0, tagIndexName.lastIndexOf(':') + 1) + + "throttled_" + + stripTagIndex(tagIndexName); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java new file mode 100644 index 000000000..9bf600360 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskOptions.java @@ -0,0 +1,75 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.google.mediapipe.tasks.core; + +import com.google.mediapipe.calculator.proto.InferenceCalculatorProto; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.tasks.core.proto.AccelerationProto; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.core.proto.ExternalFileProto; +import com.google.protobuf.ByteString; + +/** + * MediaPipe Tasks options base class. 
Any MediaPipe task-specific options class should extend + * {@link TaskOptions}. + */ +public abstract class TaskOptions { + /** + * Converts a MediaPipe Tasks task-specific options to a {@link CalculatorOptions} protobuf + * message. + */ + public abstract CalculatorOptions convertToCalculatorOptionsProto(); + + /** + * Converts a {@link BaseOptions} instance to a {@link BaseOptionsProto.BaseOptions} protobuf + * message. + */ + protected BaseOptionsProto.BaseOptions convertBaseOptionsToProto(BaseOptions options) { + ExternalFileProto.ExternalFile.Builder externalFileBuilder = + ExternalFileProto.ExternalFile.newBuilder(); + options.modelAssetPath().ifPresent(externalFileBuilder::setFileName); + options + .modelAssetFileDescriptor() + .ifPresent( + fd -> + externalFileBuilder.setFileDescriptorMeta( + ExternalFileProto.FileDescriptorMeta.newBuilder().setFd(fd).build())); + options + .modelAssetBuffer() + .ifPresent( + modelBuffer -> { + modelBuffer.rewind(); + externalFileBuilder.setFileContent(ByteString.copyFrom(modelBuffer)); + }); + AccelerationProto.Acceleration.Builder accelerationBuilder = + AccelerationProto.Acceleration.newBuilder(); + switch (options.delegate()) { + case CPU: + accelerationBuilder.setXnnpack( + InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack + .getDefaultInstance()); + break; + case GPU: + accelerationBuilder.setGpu( + InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.getDefaultInstance()); + break; + } + return BaseOptionsProto.BaseOptions.newBuilder() + .setModelAsset(externalFileBuilder.build()) + .setAcceleration(accelerationBuilder.build()) + .build(); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskResult.java new file mode 100644 index 000000000..03b11e877 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskResult.java @@ -0,0 +1,24 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +/** + * Interface for the MediaPipe Task result. Any MediaPipe task-specific result class should + * implement {@link TaskResult}. + */ +public interface TaskResult { + /** Returns the timestamp that is associated with the task result object. */ + long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java new file mode 100644 index 000000000..5739edebe --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -0,0 +1,265 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
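Looking back at TaskInfo.generateGraphConfig() above, the flow-limited wiring is easiest to see on a concrete example. This is a hedged sketch in which ExampleTaskOptions stands in for a task-specific TaskOptions subclass; the graph and stream names are borrowed from the image classifier purely for illustration.
// Hypothetical TaskInfo; `exampleOptions` is an ExampleTaskOptions instance and imports
// (e.g. java.util.Collections) are elided.
TaskInfo<ExampleTaskOptions> taskInfo =
    TaskInfo.<ExampleTaskOptions>builder()
        .setTaskGraphName("mediapipe.tasks.vision.image_classifier.ImageClassifierGraph")
        .setInputStreams(Collections.singletonList("IMAGE:image_in"))
        .setOutputStreams(Collections.singletonList("CLASSIFICATION_RESULT:classification_result_out"))
        .setTaskOptions(exampleOptions)
        .setEnableFlowLimiting(true)
        .build();
// generateGraphConfig() then emits a FlowLimiterCalculator (max_in_flight = 1, max_in_queue = 1)
// that consumes "image_in", feeds the task subgraph through "IMAGE:throttled_image_in", and closes
// the loop with the "FINISHED:classification_result_out" back edge, so at most one frame is in flight.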
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +import android.content.Context; +import android.util.Log; +import com.google.mediapipe.framework.AndroidAssetUtil; +import com.google.mediapipe.framework.AndroidPacketCreator; +import com.google.mediapipe.framework.Graph; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +/** The runner of MediaPipe task graphs. */ +public class TaskRunner implements AutoCloseable { + private static final String TAG = TaskRunner.class.getSimpleName(); + private static final long TIMESATMP_UNITS_PER_SECOND = 1000000; + + private final OutputHandler outputHandler; + private final AtomicBoolean graphStarted = new AtomicBoolean(false); + private final Graph graph; + private final ModelResourcesCache modelResourcesCache; + private final AndroidPacketCreator packetCreator; + private long lastSeenTimestamp = Long.MIN_VALUE; + private ErrorListener errorListener; + + /** + * Create a {@link TaskRunner} instance. + * + * @param context an Android {@link Context}. + * @param taskInfo a {@link TaskInfo} instance contains task graph name, task options, and graph + * input and output stream names. + * @param outputHandler a {@link OutputHandler} instance handles task result object and runtime + * exception. + * @throws MediaPipeException for any error during {@link TaskRunner} creation. + */ + public static TaskRunner create( + Context context, + TaskInfo taskInfo, + OutputHandler outputHandler) { + AndroidAssetUtil.initializeNativeAssetManager(context); + Graph mediapipeGraph = new Graph(); + mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig()); + ModelResourcesCache graphModelResourcesCache = new ModelResourcesCache(); + mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache); + mediapipeGraph.addMultiStreamCallback( + taskInfo.outputStreamNames(), + outputHandler::run, + /*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges()); + mediapipeGraph.startRunningGraph(); + // Waits until all calculators are opened and the graph is fully started. + mediapipeGraph.waitUntilGraphIdle(); + return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler); + } + + /** + * Sets a callback to be invoked when exceptions are thrown by the {@link TaskRunner} instance. + * + * @param listener an {@link ErrorListener} callback. + */ + public void setErrorListener(ErrorListener listener) { + this.errorListener = listener; + } + + /** Returns the {@link AndroidPacketCreator} associated to the {@link TaskRunner} instance. */ + public AndroidPacketCreator getPacketCreator() { + return packetCreator; + } + + /** + * A synchronous method for processing batch data. + * + *

Note: This method is designed for processing batch data such as unrelated images and texts. + * The call blocks the current thread until a failure status or a successful result is returned. + * An internal timestamp will be assigned per invocation. This method is thread-safe and allows + * clients to call it from different threads. + * + * @param inputs a map containing (input stream {@link String}, data {@link Packet}) pairs. + */ + public synchronized TaskResult process(Map<String, Packet> inputs) { + addPackets(inputs, generateSyntheticTimestamp()); + graph.waitUntilGraphIdle(); + return outputHandler.retrieveCachedTaskResult(); + } + + /** + * A synchronous method for processing offline streaming data. + * + *

Note: This method is designed for processing offline streaming data such as the decoded + * frames from a video file or an audio file. The call blocks the current thread until a failure + * status or a successful result is returned. The caller must ensure that the input timestamp is + * greater than the timestamps of previous invocations. This method is not thread-safe; it is the + * caller's responsibility to synchronize access to this method across multiple threads and to + * ensure that the input packet timestamps are in order. + * + * @param inputs a map containing (input stream {@link String}, data {@link Packet}) pairs. + * @param inputTimestamp the timestamp of the input packets. + */ + public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) { + validateInputTimstamp(inputTimestamp); + addPackets(inputs, inputTimestamp); + graph.waitUntilGraphIdle(); + return outputHandler.retrieveCachedTaskResult(); + } + + /** + * An asynchronous method for handling live streaming data. + * + *

Note: This method that is designed for handling live streaming data such as live camera and + * microphone data. A user-defined packets callback function must be provided in the constructor + * to receive the output packets. The caller must ensure that the input packet timestamps are + * monotonically increasing. This method is thread-unsafe and it is the caller's responsibility to + * synchronize access to this method across multiple threads and to ensure that the input packet + * timestamps are in order. + * + * @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs. + * @param inputTimestamp the timestamp of the input packets. + */ + public synchronized void send(Map inputs, long inputTimestamp) { + validateInputTimstamp(inputTimestamp); + addPackets(inputs, inputTimestamp); + } + + /** + * Resets and restarts the {@link TaskRunner} instance. This can be useful for resetting a + * stateful task graph to process new data. + */ + public void restart() { + if (graphStarted.get()) { + try { + graphStarted.set(false); + graph.closeAllPacketSources(); + graph.waitUntilGraphDone(); + } catch (MediaPipeException e) { + reportError(e); + } + } + try { + graph.startRunningGraph(); + // Waits until all calculators are opened and the graph is fully restarted. + graph.waitUntilGraphIdle(); + graphStarted.set(true); + } catch (MediaPipeException e) { + reportError(e); + } + } + + /** Closes and cleans up the {@link TaskRunner} instance. */ + @Override + public void close() { + if (!graphStarted.get()) { + return; + } + try { + graphStarted.set(false); + graph.closeAllPacketSources(); + graph.waitUntilGraphDone(); + if (modelResourcesCache != null) { + modelResourcesCache.release(); + } + } catch (MediaPipeException e) { + // Note: errors during Process are reported at the earliest opportunity, + // which may be addPacket or waitUntilDone, depending on timing. For consistency, + // we want to always report them using the same async handler if installed. + reportError(e); + } + try { + graph.tearDown(); + } catch (MediaPipeException e) { + reportError(e); + } + } + + private synchronized void addPackets(Map inputs, long inputTimestamp) { + if (!graphStarted.get()) { + reportError( + new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "The task graph hasn't been successfully started or error occurs during graph" + + " initializaton.")); + } + try { + for (Map.Entry entry : inputs.entrySet()) { + // addConsumablePacketToInputStream allows the graph to take exclusive ownership of the + // packet, which may allow for more memory optimizations. + graph.addConsumablePacketToInputStream(entry.getKey(), entry.getValue(), inputTimestamp); + // If addConsumablePacket succeeded, we don't need to release the packet ourselves. + entry.setValue(null); + } + } catch (MediaPipeException e) { + // TODO: do not suppress exceptions here! + if (errorListener == null) { + Log.e(TAG, "Mediapipe error: ", e); + } else { + throw e; + } + } finally { + for (Packet packet : inputs.values()) { + // In case of error, addConsumablePacketToInputStream will not release the packet, so we + // have to release it ourselves. + if (packet != null) { + packet.release(); + } + } + } + } + + /** + * Checks if the input timestamp is strictly greater than the last timestamp that has been + * processed. + * + * @param inputTimestamp the input timestamp. 
+ */ + private void validateInputTimstamp(long inputTimestamp) { + if (lastSeenTimestamp >= inputTimestamp) { + reportError( + new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "The received packets having a smaller timestamp than the processed timestamp.")); + } + lastSeenTimestamp = inputTimestamp; + } + + /** Generates a synthetic input timestamp in the batch processing mode. */ + private long generateSyntheticTimestamp() { + long timestamp = + lastSeenTimestamp == Long.MIN_VALUE ? 0 : lastSeenTimestamp + TIMESATMP_UNITS_PER_SECOND; + lastSeenTimestamp = timestamp; + return timestamp; + } + + /** Private constructor. */ + private TaskRunner( + Graph graph, + ModelResourcesCache modelResourcesCache, + OutputHandler outputHandler) { + this.outputHandler = outputHandler; + this.graph = graph; + this.modelResourcesCache = modelResourcesCache; + this.packetCreator = new AndroidPacketCreator(graph); + graphStarted.set(true); + } + + /** Reports error. */ + private void reportError(MediaPipeException e) { + if (errorListener != null) { + errorListener.onError(e); + } else { + throw e; + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD new file mode 100644 index 000000000..97f8dfd15 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD @@ -0,0 +1,41 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library_with_tflite( + name = "model_resources_cache_jni", + srcs = [ + "model_resources_cache_jni.cc", + ], + hdrs = [ + "model_resources_cache_jni.h", + ], + tflite_deps = [ + "//mediapipe/tasks/cc/core:model_resources_cache", + "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + ], + deps = [ + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", + ] + select({ + "//conditions:default": ["//third_party/java/jdk:jni"], + "//mediapipe:android": [], + }), + alwayslink = 1, +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD.bazel b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD.bazel new file mode 100644 index 000000000..0eb74e7ff --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD.bazel @@ -0,0 +1,72 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +cc_library_with_tflite( + name = "model_resources_cache_jni", + srcs = [ + "model_resources_cache_jni.cc", + ], + hdrs = [ + "model_resources_cache_jni.h", + ] + select({ + # The Android toolchain makes "jni.h" available in the include path. + # For non-Android toolchains, generate jni.h and jni_md.h. + "//mediapipe:android": [], + "//conditions:default": [ + ":jni.h", + ":jni_md.h", + ], + }), + tflite_deps = [ + "//mediapipe/tasks/cc/core:model_resources_cache", + ], + deps = [ + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", + ] + select({ + "//conditions:default": [], + "//mediapipe:android": [], + }), + alwayslink = 1, +) + +# Silly rules to make +# #include +# in the source headers work +# (in combination with the "includes" attribute of the tf_cuda_library rule +# above. Not needed when using the Android toolchain). +# +# Inspired from: +# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD +# but hopefully there is a simpler alternative to this. +genrule( + name = "copy_jni_h", + srcs = ["@bazel_tools//tools/jdk:jni_header"], + outs = ["jni.h"], + cmd = "cp -f $< $@", +) + +genrule( + name = "copy_jni_md_h", + srcs = select({ + "//mediapipe:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], + }), + outs = ["jni_md.h"], + cmd = "cp -f $< $@", +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc new file mode 100644 index 000000000..aab022dec --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc @@ -0,0 +1,51 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
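Before the JNI sources that follow, a short illustration of the TaskRunner contract documented above may help. This is a minimal sketch and not part of the patch: it assumes a TaskRunner and an Image are already available, and the stream name "image_in" is only an example (real stream names come from each task's TaskInfo).

    import com.google.mediapipe.framework.Packet;
    import com.google.mediapipe.framework.image.Image;
    import com.google.mediapipe.tasks.core.TaskResult;
    import com.google.mediapipe.tasks.core.TaskRunner;
    import java.util.HashMap;
    import java.util.Map;

    final class TaskRunnerUsageSketch {
      // Batch mode: TaskRunner assigns an internal, monotonically increasing timestamp per call.
      static TaskResult runOnce(TaskRunner runner, Image image) {
        Map<String, Packet> inputs = new HashMap<>();
        inputs.put("image_in", runner.getPacketCreator().createImage(image));
        return runner.process(inputs);
        // The streaming variants instead take a caller-supplied, strictly increasing timestamp:
        //   runner.process(inputs, inputTimestampUs);  // blocking, offline streaming
        //   runner.send(inputs, inputTimestampUs);     // non-blocking, results via the OutputHandler
      }

      private TaskRunnerUsageSketch() {}
    }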
+ +#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h" + +#include + +#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" +#include "tensorflow/lite/core/shims/cc/kernels/register.h" + +namespace { +using ::mediapipe::tasks::core::kModelResourcesCacheService; +using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver; +using ::mediapipe::tasks::core::ModelResourcesCache; +using HandleType = std::shared_ptr*; +} // namespace + +JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD( + nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) { + auto ptr = std::make_shared( + absl::make_unique()); + HandleType handle = new std::shared_ptr(std::move(ptr)); + return reinterpret_cast(handle); +} + +JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD( + nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz, + jlong nativeHandle) { + delete reinterpret_cast(nativeHandle); +} + +JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD( + nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle, + jlong objectHandle) { + mediapipe::android::GraphServiceHelper::SetServiceObject( + contextHandle, kModelResourcesCacheService, + *reinterpret_cast(objectHandle)); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h new file mode 100644 index 000000000..9b0478939 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h @@ -0,0 +1,45 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_ +#define JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define MODEL_RESOURCES_CACHE_METHOD(METHOD_NAME) \ + Java_com_google_mediapipe_tasks_core_ModelResourcesCache_##METHOD_NAME + +#define MODEL_RESOURCES_CACHE_SERVICE_METHOD(METHOD_NAME) \ + Java_com_google_mediapipe_tasks_core_ModelResourcesCacheService_##METHOD_NAME + +JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD( + nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz); + +JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD( + nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz, + jlong nativeHandle); + +JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD( + nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle, + jlong objectHandle); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD index 65c1214af..94f77ea68 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD @@ -11,3 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +android_library( + name = "core", + srcs = glob(["*.java"]), + deps = [ + ":libmediapipe_tasks_vision_jni_lib", + "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "@maven//:com_google_guava_guava", + ], +) + +# The native library of all MediaPipe vision tasks. +cc_binary( + name = "libmediapipe_tasks_vision_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + ], +) + +cc_library( + name = "libmediapipe_tasks_vision_jni_lib", + srcs = [":libmediapipe_tasks_vision_jni.so"], + alwayslink = 1, +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java new file mode 100644 index 000000000..92f64e898 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -0,0 +1,114 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.core; + +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.core.TaskResult; +import com.google.mediapipe.tasks.core.TaskRunner; +import java.util.HashMap; +import java.util.Map; + +/** The base class of MediaPipe vision tasks. */ +public class BaseVisionTaskApi implements AutoCloseable { + private static final long MICROSECONDS_PER_MILLISECOND = 1000; + private final TaskRunner runner; + private final RunningMode runningMode; + + static { + System.loadLibrary("mediapipe_tasks_vision_jni"); + } + + /** + * Constructor to initialize an {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision + * task {@link RunningMode}. + * + * @param runner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) { + this.runner = runner; + this.runningMode = runningMode; + } + + /** + * A synchronous method to process single image inputs. The call blocks the current thread until a + * failure status or a successful result is returned. + * + * @param imageStreamName the image input stream name. + * @param image a MediaPipe {@link Image} object for processing. + * @throws MediaPipeException if the task is not in the image mode. + */ + protected TaskResult processImageData(String imageStreamName, Image image) { + if (runningMode != RunningMode.IMAGE) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the image mode. Current running mode:" + + runningMode.name()); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + return runner.process(inputPackets); + } + + /** + * A synchronous method to process continuous video frames. The call blocks the current thread + * until a failure status or a successful result is returned. + * + * @param imageStreamName the image input stream name. + * @param image a MediaPipe {@link Image} object for processing. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. + * @throws MediaPipeException if the task is not in the video mode. + */ + protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) { + if (runningMode != RunningMode.VIDEO) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the video mode. Current running mode:" + + runningMode.name()); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + + /** + * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be + * available in the user-defined result listener. + * + * @param imageStreamName the image input stream name. + * @param image a MediaPipe {@link Image} object for processing. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. + * @throws MediaPipeException if the task is not in the video mode. 
+ */ + protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) { + if (runningMode != RunningMode.LIVE_STREAM) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the live stream mode. Current running mode:" + + runningMode.name()); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + + /** Closes and cleans up the MediaPipe vision task. */ + @Override + public void close() { + runner.close(); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/RunningMode.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/RunningMode.java new file mode 100644 index 000000000..8b0ffb8fd --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/RunningMode.java @@ -0,0 +1,32 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.core; + +/** + * MediaPipe vision task running mode. A MediaPipe vision task can be run with three different + * modes: + * + *

    + *
  • IMAGE: The mode for running a mediapipe vision task on single image inputs. + *
  • VIDEO: The mode for running a mediapipe vision task on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for running a mediapipe vision task on a live stream of input data, + * such as from a camera. + *
+ */ +public enum RunningMode { + IMAGE, + VIDEO, + LIVE_STREAM +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD index 65c1214af..8ba2705eb 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD @@ -11,3 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "objectdetector", + srcs = [ + "ObjectDetectionResult.java", + "ObjectDetector.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = ":AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java new file mode 100644 index 000000000..9a0c7e8f6 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java @@ -0,0 +1,75 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.objectdetector; + +import android.graphics.RectF; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.core.TaskResult; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Represents the detection results generated by {@link ObjectDetector}. 
*/ +@AutoValue +public abstract class ObjectDetectionResult implements TaskResult { + private static final int DEFAULT_CATEGORY_INDEX = -1; + + @Override + public abstract long timestampMs(); + + public abstract List detections(); + + /** + * Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf + * messages. + * + * @param detectionList a list of {@link Detection} protobuf messages. + */ + static ObjectDetectionResult create(List detectionList, long timestampMs) { + List detections = new ArrayList<>(); + for (Detection detectionProto : detectionList) { + List categories = new ArrayList<>(); + for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) { + categories.add( + Category.create( + detectionProto.getScore(idx), + detectionProto.getLabelIdCount() > idx + ? detectionProto.getLabelId(idx) + : DEFAULT_CATEGORY_INDEX, + detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "", + detectionProto.getDisplayNameCount() > idx + ? detectionProto.getDisplayName(idx) + : "")); + } + RectF boundingBox = new RectF(); + if (detectionProto.getLocationData().hasBoundingBox()) { + BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox(); + boundingBox.set( + /*left=*/ boundingBoxProto.getXmin(), + /*top=*/ boundingBoxProto.getYmin(), + /*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(), + /*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight()); + } + detections.add( + com.google.mediapipe.tasks.components.containers.Detection.create( + categories, boundingBox)); + } + return new AutoValue_ObjectDetectionResult( + timestampMs, Collections.unmodifiableList(detections)); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java new file mode 100644 index 000000000..463ab4c43 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -0,0 +1,418 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
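Before the ObjectDetector implementation below, a small sketch of consuming an ObjectDetectionResult. The accessor names on the containers Detection and Category classes (boundingBox(), categories(), score(), categoryName()) follow the AutoValue pattern used in this change and in the test code further down, but the snippet itself is illustrative only.

    import android.graphics.RectF;
    import android.util.Log;
    import com.google.mediapipe.tasks.components.containers.Category;
    import com.google.mediapipe.tasks.components.containers.Detection;
    import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;

    final class DetectionLoggingSketch {
      static void log(ObjectDetectionResult result) {
        for (Detection detection : result.detections()) {
          if (detection.categories().isEmpty()) {
            continue;
          }
          // The bounding box was converted above from (xmin, ymin, width, height) to
          // RectF(left, top, right, bottom), i.e. right = xmin + width, bottom = ymin + height.
          RectF box = detection.boundingBox();
          Category topCategory = detection.categories().get(0);
          Log.d("ObjectDetection",
              topCategory.categoryName() + " (" + topCategory.score() + ") at " + box);
        }
      }

      private DetectionLoggingSketch() {}
    }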
+ +package com.google.mediapipe.tasks.vision.objectdetector; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs object detection on images. + * + *

The API expects a TFLite model with TFLite Model Metadata. + * + *

The API supports models with one image input tensor and four output tensors. To be more + * specific, here are the requirements. + * + *

    + *
  • Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32}) + *
      + *
    • image input of size {@code [batch x height x width x channels]}. + *
    • batch inference is not supported ({@code batch} is required to be 1). + *
    • only RGB inputs are supported ({@code channels} is required to be 3). + *
    • if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached + * to the metadata for input normalization. + *
    + *
  • Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e.: + *
      + *
    • Location tensor ({@code kTfLiteFloat32}): + *
        + *
      • tensor of size {@code [1 x num_results x 4]}, the inner array representing + * bounding boxes in the form [top, left, right, bottom]. + *
      • {@code BoundingBoxProperties} are required to be attached to the metadata and + * must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}. + *
      + *
    • Classes tensor ({@code kTfLiteFloat32}): + *
        + *
      • tensor of size {@code [1 x num_results]}, each value representing the integer + * index of a class. + *
      • if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS} + * associated files, they are used to convert the tensor values into labels. + *
      + *
    • Scores tensor ({@code kTfLiteFloat32}): + *
        + *
      • tensor of size {@code [1 x num_results]}, each value representing the score of + * the detected object. + *
      + *
    • Number of detection tensor ({@code kTfLiteFloat32}): + *
        + *
      • integer num_results as a tensor of size {@code [1]}. + *
      + *
    + *
+ * + *

An example of such model can be found on TensorFlow + * Hub.. + */ +public final class ObjectDetector extends BaseVisionTaskApi { + private static final String TAG = ObjectDetector.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); + private static final int DETECTIONS_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph"; + + /** + * Creates an {@link ObjectDetector} instance from a model file and the default {@link + * ObjectDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the detection model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation. + */ + public static ObjectDetector createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ObjectDetector} instance from a model file and the default {@link + * ObjectDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the detection model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation. + */ + public static ObjectDetector createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates an {@link ObjectDetector} instance from a model buffer and the default {@link + * ObjectDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model. + * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation. + */ + public static ObjectDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param detectorOptions a {@link ObjectDetectorOptions} instance. + * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation. + */ + public static ObjectDetector createFromOptions( + Context context, ObjectDetectorOptions detectorOptions) { + // TODO: Consolidate OutputHandler and TaskRunner. 
+ OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ObjectDetectionResult convertToTaskResult(List packets) { + return ObjectDetectionResult.create( + PacketGetter.getProtoVector( + packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()), + packets.get(DETECTIONS_OUT_STREAM_INDEX).getTimestamp()); + } + + @Override + public Image convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + detectorOptions.resultListener().ifPresent(handler::setResultListener); + detectorOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(detectorOptions) + .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + detectorOptions.errorListener().ifPresent(runner::setErrorListener); + return new ObjectDetector(runner, detectorOptions.runningMode()); + } + + /** + * Constructor to initialize an {@link ObjectDetector} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode); + } + + /** + * Performs object detection on the provided single image. Only use this method when the {@link + * ObjectDetector} is created with {@link RunningMode.IMAGE}. + * + *

{@link ObjectDetector} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detect(Image inputImage) { + return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage); + } + + /** + * Performs object detection on the provided video frame. Only use this method when the {@link + * ObjectDetector} is created with {@link RunningMode.VIDEO}. + * + *

It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

{@link ObjectDetector} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) { + return (ObjectDetectionResult) + processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs); + } + + /** + * Sends live image data to perform object detection, and the results will be available via the + * {@link ResultListener} provided in the {@link ObjectDetectorOptions}. Only use this method when + * the {@link ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}. + * + *

It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. + * + *

{@link ObjectDetector} supports the following color space types: + * + *

    + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param inputImage a MediaPipe {@link Image} object for processing. + * @param inputTimestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(Image inputImage, long inputTimestampMs) { + sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs); + } + + /** Options for setting up an {@link ObjectDetector}. */ + @AutoValue + public abstract static class ObjectDetectorOptions extends TaskOptions { + + /** Builder for {@link ObjectDetectorOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the object detector task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the object detector task. Default to the image mode. Object + * detector has three modes: + * + *
    + *
  • IMAGE: The mode for detecting objects on single image inputs. + *
  • VIDEO: The mode for detecting objects on the decoded frames of a video. + *
  • LIVE_STREAM: The mode for detecting objects on a live stream of input data, such + * as from a camera. In this mode, {@code setResultListener} must be called to set up a + * listener to receive the detection results asynchronously. + *
+ */ + public abstract Builder setRunningMode(RunningMode value); + + /** + * Sets the locale to use for display names specified through the TFLite Model Metadata, if + * any. Defaults to English. + */ + public abstract Builder setDisplayNamesLocale(String value); + + /** + * Sets the optional maximum number of top-scored detection results to return. + * + *

Overrides the one provided in the model metadata, if any. + */ + public abstract Builder setMaxResults(Integer value); + + /** + * Sets the optional score threshold that overrides the one provided in the model metadata (if + * any). Results below this value are rejected. + */ + public abstract Builder setScoreThreshold(Float value); + + /** + * Sets the optional allowlist of category names. + * + *

If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List value); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List value); + + /** + * Sets the result listener to receive the detection results asynchronously when the object + * detector is in the live stream mode. + */ + public abstract Builder setResultListener(ResultListener value); + + /** Sets an optional error listener. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract ObjectDetectorOptions autoBuild(); + + /** + * Validates and builds the {@link ObjectDetectorOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the object detector is + * in the live stream mode. + */ + public final ObjectDetectorOptions build() { + ObjectDetectorOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The object detector is in the live stream mode, a user-defined result listener" + + " must be provided in ObjectDetectorOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The object detector is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in ObjectDetectorOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_ObjectDetector_ObjectDetectorOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); + } + + /** Converts a {@link ObjectDetectorOptions} to a {@link CalculatorOptions} protobuf message. 
*/ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ObjectDetectorOptionsProto.ObjectDetectorOptions.Builder taskOptionsBuilder = + ObjectDetectorOptionsProto.ObjectDetectorOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + displayNamesLocale().ifPresent(taskOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(taskOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(taskOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + taskOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + taskOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } + return CalculatorOptions.newBuilder() + .setExtension( + ObjectDetectorOptionsProto.ObjectDetectorOptions.ext, taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD index 65c1214af..74bf48c59 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/BUILD @@ -11,3 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "test_utils", + srcs = ["TestUtils.java"], + deps = [ + "//third_party/java/android_libs/guava_jdk5:io", + ], +) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/TestUtils.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/TestUtils.java new file mode 100644 index 000000000..130c413b9 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/core/TestUtils.java @@ -0,0 +1,83 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.core; + +import android.content.Context; +import android.content.res.AssetManager; +import com.google.common.io.ByteStreams; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** Helper class for the Java test in MediaPipe Tasks. */ +public final class TestUtils { + + /** + * Loads the file and create a {@link File} object by reading a file from the asset directory. + * Simulates downloading or reading a file that's not precompiled with the app. + * + * @return a {@link File} object for the model. 
+ */ + public static File loadFile(Context context, String fileName) { + File target = new File(context.getFilesDir(), fileName); + try (InputStream is = context.getAssets().open(fileName); + FileOutputStream os = new FileOutputStream(target)) { + ByteStreams.copy(is, os); + } catch (IOException e) { + throw new AssertionError("Failed to load model file at " + fileName, e); + } + return target; + } + + /** + * Reads a file into a direct {@link ByteBuffer} object from the asset directory. + * + * @return a {@link ByteBuffer} object for the file. + */ + public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName) + throws IOException { + AssetManager assetManager = context.getAssets(); + InputStream inputStream = assetManager.open(fileName); + byte[] bytes = ByteStreams.toByteArray(inputStream); + + ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder()); + buffer.put(bytes); + return buffer; + } + + /** + * Reads a file into a non-direct {@link ByteBuffer} object from the asset directory. + * + * @return a {@link ByteBuffer} object for the file. + */ + public static ByteBuffer loadToNonDirectByteBuffer(Context context, String fileName) + throws IOException { + AssetManager assetManager = context.getAssets(); + InputStream inputStream = assetManager.open(fileName); + byte[] bytes = ByteStreams.toByteArray(inputStream); + return ByteBuffer.wrap(bytes); + } + + public enum ByteBufferType { + DIRECT, + BACK_UP_ARRAY, + OTHER // Non-direct ByteBuffer without a back-up array. + } + + private TestUtils() {} +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD index 1bec2be3e..a7f804c64 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/BUILD @@ -12,4 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + # TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java new file mode 100644 index 000000000..cdec57d76 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -0,0 +1,456 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
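The test class that follows exercises these options; as a compact reference, a live-stream setup might look like the sketch below. The model file name is taken from the test constants further down; the wrapper class and the empty listener body are placeholders, not part of the patch.

    import android.content.Context;
    import com.google.mediapipe.tasks.core.BaseOptions;
    import com.google.mediapipe.tasks.vision.core.RunningMode;
    import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
    import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;

    final class LiveStreamSetupSketch {
      static ObjectDetector create(Context context) {
        ObjectDetectorOptions options =
            ObjectDetectorOptions.builder()
                .setBaseOptions(
                    BaseOptions.builder()
                        .setModelAssetPath("coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite")
                        .build())
                .setRunningMode(RunningMode.LIVE_STREAM)
                .setMaxResults(3)
                // Required in LIVE_STREAM mode; build() throws IllegalArgumentException otherwise.
                .setResultListener((result, inputImage) -> { /* handle ObjectDetectionResult */ })
                .build();
        return ObjectDetector.createFromOptions(context, options);
      }

      private LiveStreamSetupSketch() {}
    }

Frames would then be pushed with detector.detectAsync(image, timestampMs), using monotonically increasing timestamps, mirroring the RunningModeTest cases below.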
+ +package com.google.mediapipe.tasks.vision.objectdetector; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Detection; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ObjectDetector}. */ +@RunWith(Suite.class) +@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class}) +public class ObjectDetectorTest { + private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; + private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg"; + private static final int IMAGE_WIDTH = 1200; + private static final int IMAGE_HEIGHT = 600; + private static final float CAT_SCORE = 0.69f; + private static final RectF catBoundingBox = new RectF(611, 164, 986, 596); + // TODO: Figure out why android_x86 and android_arm tests have slightly different + // scores (0.6875 vs 0.69921875). + private static final float SCORE_DIFF_TOLERANCE = 0.01f; + private static final float PIXEL_DIFF_TOLERANCE = 5.0f; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ObjectDetectorTest { + + @Test + public void detect_successWithValidModels() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMaxResults(1) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + } + + @Test + public void detect_successWithNoOptions() throws Exception { + ObjectDetector objectDetector = + ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // Check if the object with the highest score is cat. 
+ assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); + } + + @Test + public void detect_succeedsWithMaxResultsOption() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMaxResults(8) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // results should have 8 detected objects because maxResults was set to 8. + assertThat(results.detections()).hasSize(8); + } + + @Test + public void detect_succeedsWithScoreThresholdOption() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setScoreThreshold(0.68f) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // The score threshold should block all other other objects, except cat. + assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + } + + @Test + public void detect_succeedsWithAllowListOption() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("cat")) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // Because of the allowlist, results should only contain cat, and there are 6 detected + // bounding boxes of cats in CAT_AND_DOG_IMAGE. + assertThat(results.detections()).hasSize(5); + } + + @Test + public void detect_succeedsWithDenyListOption() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setCategoryDenylist(Arrays.asList("cat")) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // Because of the denylist, the highest result is not cat anymore. + assertThat(results.detections().get(0).categories().get(0).categoryName()) + .isNotEqualTo("cat"); + } + + @Test + public void detect_succeedsWithModelFileObject() throws Exception { + ObjectDetector objectDetector = + ObjectDetector.createFromFile( + ApplicationProvider.getApplicationContext(), + TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // Check if the object with the highest score is cat. 
+ assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); + } + + @Test + public void detect_succeedsWithModelBuffer() throws Exception { + ObjectDetector objectDetector = + ObjectDetector.createFromBuffer( + ApplicationProvider.getApplicationContext(), + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE)); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + // Check if the object with the highest score is cat. + assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE); + } + + @Test + public void detect_succeedsWithModelBufferAndOptions() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetBuffer( + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE)) + .build()) + .setMaxResults(1) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + } + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonexistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + ObjectDetector.createFromFile( + ApplicationProvider.getApplicationContext(), nonexistentFile)); + assertThat(exception).hasMessageThat().contains(nonexistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. + ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ObjectDetector.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void detect_failsWithBothAllowAndDenyListOption() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("cat")) + .setCategoryDenylist(Arrays.asList("dog")) + .build(); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + ObjectDetector.createFromOptions( + ApplicationProvider.getApplicationContext(), options)); + assertThat(exception) + .hasMessageThat() + .contains("`category_allowlist` and `category_denylist` are mutually exclusive options."); + } + + // TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation, + // detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions, + // detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero. 
+ } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ObjectDetectorTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(mode) + .setResultListener((objectDetectionResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void detect_failsWithCallingWrongApiInImageMode() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInVideoMode() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((objectDetectionResult, inputImage) -> {}) + .build(); + + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> 
objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void detect_successWithImageMode() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .setMaxResults(1) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); + assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + } + + @Test + public void detect_successWithVideoMode() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .setMaxResults(1) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + for (int i = 0; i < 3; i++) { + ObjectDetectionResult results = + objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i); + assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + } + } + + @Test + public void detect_failsWithOutOfOrderInputTimestamps() throws Exception { + Image image = getImageFromAsset(CAT_AND_DOG_IMAGE); + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (objectDetectionResult, inputImage) -> { + assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE); + assertImageSizeIsExpected(inputImage); + }) + .setMaxResults(1) + .build(); + try (ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + objectDetector.detectAsync(image, 1); + MediaPipeException exception = + assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void detect_successWithLiveSteamMode() throws Exception { + Image image = getImageFromAsset(CAT_AND_DOG_IMAGE); + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (objectDetectionResult, inputImage) -> { + assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE); + assertImageSizeIsExpected(inputImage); + }) + .setMaxResults(1) + .build(); + try (ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + objectDetector.detectAsync(image, i); + } + } + } + } + + private static Image getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new 
BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + // Checks if results has one and only detection result, which is a cat. + private static void assertContainsOnlyCat( + ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) { + assertThat(result.detections()).hasSize(1); + Detection catResult = result.detections().get(0); + assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox); + // We only support one category for each detected object at this point. + assertIsCat(catResult.categories().get(0), expectedScore); + } + + private static void assertIsCat(Category category, float expectedScore) { + assertThat(category.categoryName()).isEqualTo("cat"); + // coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite does not support label locale. + assertThat(category.displayName()).isEmpty(); + assertThat((double) category.score()).isWithin(SCORE_DIFF_TOLERANCE).of(expectedScore); + assertThat(category.index()).isEqualTo(-1); + } + + private static void assertApproximatelyEqualBoundingBoxes( + RectF boundingBox1, RectF boundingBox2) { + assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left); + assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top); + assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right); + assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom); + } + + private static void assertImageSizeIsExpected(Image inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); + assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); + } +} diff --git a/mediapipe/tasks/metadata/BUILD b/mediapipe/tasks/metadata/BUILD index 957bf6b74..abd948809 100644 --- a/mediapipe/tasks/metadata/BUILD +++ b/mediapipe/tasks/metadata/BUILD @@ -1,4 +1,4 @@ -load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library") package( default_visibility = [ @@ -14,3 +14,13 @@ flatbuffer_cc_library( name = "metadata_schema_cc", srcs = ["metadata_schema.fbs"], ) + +flatbuffer_py_library( + name = "schema_py", + srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"], +) + +flatbuffer_py_library( + name = "metadata_schema_py", + srcs = ["metadata_schema.fbs"], +) diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index eb3acdd97..450111161 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -31,7 +31,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers:category_py_pb2", + "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index 00f68e532..0b347fc10 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,7 +16,7 @@ import dataclasses from typing import Any -from mediapipe.tasks.cc.components.containers import category_pb2 +from mediapipe.tasks.cc.components.containers.proto import category_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls _CategoryProto = category_pb2.Category diff --git 
a/mediapipe/tasks/python/core/pybind/BUILD b/mediapipe/tasks/python/core/pybind/BUILD index fab878135..b59635dc3 100644 --- a/mediapipe/tasks/python/core/pybind/BUILD +++ b/mediapipe/tasks/python/core/pybind/BUILD @@ -27,6 +27,7 @@ pybind_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/python/pybind:util", + "//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver", "//mediapipe/tasks/cc/core:task_runner", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", diff --git a/mediapipe/tasks/python/core/pybind/task_runner.cc b/mediapipe/tasks/python/core/pybind/task_runner.cc index 52834bab2..cb13787c3 100644 --- a/mediapipe/tasks/python/core/pybind/task_runner.cc +++ b/mediapipe/tasks/python/core/pybind/task_runner.cc @@ -16,6 +16,7 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/python/pybind/util.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" @@ -75,7 +76,7 @@ mode) or not (synchronous mode).)doc"); } auto task_runner = TaskRunner::Create( std::move(graph_config), - absl::make_unique(), + absl::make_unique(), std::move(callback)); RaisePyErrorIfNotOk(task_runner.status()); return std::move(*task_runner); diff --git a/mediapipe/tasks/python/metadata/BUILD b/mediapipe/tasks/python/metadata/BUILD new file mode 100644 index 000000000..07805ec61 --- /dev/null +++ b/mediapipe/tasks/python/metadata/BUILD @@ -0,0 +1,38 @@ +load("//mediapipe/tasks/metadata:build_defs.bzl", "stamp_metadata_parser_version") + +package( + licenses = ["notice"], # Apache 2.0 +) + +stamp_metadata_parser_version( + name = "metadata_parser_py", + srcs = ["metadata_parser.py.template"], + outs = ["metadata_parser.py"], +) + +py_library( + name = "metadata", + srcs = [ + "metadata.py", + ":metadata_parser_py", + ], + data = ["//mediapipe/tasks/metadata:metadata_schema.fbs"], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version", + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/metadata:schema_py", + "//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers", + "@flatbuffers//:runtime_py", + ], +) + +py_binary( + name = "metadata_displayer_cli", + srcs = ["metadata_displayer_cli.py"], + visibility = [ + "//visibility:public", + ], + deps = [":metadata"], +) diff --git a/mediapipe/tasks/python/metadata/__init__.py b/mediapipe/tasks/python/metadata/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/tasks/python/metadata/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
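A minimal usage sketch (illustrative only, not part of the patch) of the new metadata py_library and metadata_displayer_cli binary added above; the model path below is a placeholder, and the calls mirror the MetadataDisplayer API introduced later in this patch.

    # Hypothetical example; "model_with_metadata.tflite" is a placeholder path.
    from mediapipe.tasks.python.metadata import metadata

    displayer = metadata.MetadataDisplayer.with_model_file("model_with_metadata.tflite")
    print(displayer.get_metadata_json())                # metadata as a JSON string
    print(displayer.get_packed_associated_file_list())  # e.g. ["labels.txt"]

    # Roughly equivalent CLI invocation (flags are defined in metadata_displayer_cli.py below):
    #   bazel run //mediapipe/tasks/python/metadata:metadata_displayer_cli -- \
    #     --model_path=model_with_metadata.tflite --export_json_path=metadata.json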
diff --git a/mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD b/mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD new file mode 100644 index 000000000..303ff3224 --- /dev/null +++ b/mediapipe/tasks/python/metadata/flatbuffers_lib/BUILD @@ -0,0 +1,20 @@ +load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = ["//mediapipe/tasks:internal"], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_flatbuffers", + srcs = [ + "flatbuffers_lib.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_flatbuffers", + deps = [ + "@flatbuffers", + "@local_config_python//:python_headers", + "@pybind11", + ], +) diff --git a/mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc b/mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc new file mode 100644 index 000000000..34407620c --- /dev/null +++ b/mediapipe/tasks/python/metadata/flatbuffers_lib/flatbuffers_lib.cc @@ -0,0 +1,59 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/idl.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace tflite { +namespace support { + +PYBIND11_MODULE(_pywrap_flatbuffers, m) { + pybind11::class_(m, "IDLOptions") + .def(pybind11::init<>()) + .def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json); + pybind11::class_(m, "Parser") + .def(pybind11::init()) + .def("parse", + [](flatbuffers::Parser* self, const std::string& source) { + return self->Parse(source.c_str()); + }) + .def_readonly("builder", &flatbuffers::Parser::builder_) + .def_readonly("error", &flatbuffers::Parser::error_); + pybind11::class_(m, "FlatBufferBuilder") + .def("clear", &flatbuffers::FlatBufferBuilder::Clear) + .def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self, + const std::string& contents) { + self->PushFlatBuffer(reinterpret_cast(contents.c_str()), + contents.length()); + }); + m.def("generate_text_file", &flatbuffers::GenerateTextFile); + m.def( + "generate_text", + [](const flatbuffers::Parser& parser, + const std::string& buffer) -> std::string { + std::string text; + if (!flatbuffers::GenerateText( + parser, reinterpret_cast(buffer.c_str()), &text)) { + return ""; + } + return text; + }); +} + +} // namespace support +} // namespace tflite diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py new file mode 100644 index 000000000..10a0b9b66 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -0,0 +1,865 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite metadata tools.""" + +import copy +import inspect +import io +import os +import shutil +import sys +import tempfile +import warnings +import zipfile + +import flatbuffers +from mediapipe.tasks.cc.metadata.python import _pywrap_metadata_version +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.metadata import schema_py_generated as _schema_fb +from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers + +try: + # If exists, optionally use TensorFlow to open and check files. Used to + # support more than local file systems. + # In pip requirements, we doesn't necessarily need tensorflow as a dep. + import tensorflow as tf + _open_file = tf.io.gfile.GFile + _exists_file = tf.io.gfile.exists +except ImportError as e: + # If TensorFlow package doesn't exist, fall back to original open and exists. + _open_file = open + _exists_file = os.path.exists + + +def _maybe_open_as_binary(filename, mode): + """Maybe open the binary file, and returns a file-like.""" + if hasattr(filename, "read"): # A file-like has read(). + return filename + openmode = mode if "b" in mode else mode + "b" # Add binary explicitly. + return _open_file(filename, openmode) + + +def _open_as_zipfile(filename, mode="r"): + """Open file as a zipfile. + + Args: + filename: str or file-like or path-like, to the zipfile. + mode: str, common file mode for zip. + (See: https://docs.python.org/3/library/zipfile.html) + + Returns: + A ZipFile object. + """ + file_like = _maybe_open_as_binary(filename, mode) + return zipfile.ZipFile(file_like, mode) + + +def _is_zipfile(filename): + """Checks whether it is a zipfile.""" + with _maybe_open_as_binary(filename, "r") as f: + return zipfile.is_zipfile(f) + + +def get_path_to_datafile(path): + """Gets the path to the specified file in the data dependencies. + + The path is relative to the file calling the function. + + It's a simple replacement of + "tensorflow.python.platform.resource_loader.get_path_to_datafile". + + Args: + path: a string resource path relative to the calling file. + + Returns: + The path to the specified file present in the data attribute of py_test + or py_binary. + """ + data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1))) # pylint: disable=protected-access + return os.path.join(data_files_path, path) + + +_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile( + "../../metadata/metadata_schema.fbs") + + +# TODO: add delete method for associated files. +class MetadataPopulator(object): + """Packs metadata and associated files into TensorFlow Lite model file. + + MetadataPopulator can be used to populate metadata and model associated files + into a model file or a model buffer (in bytearray). It can also help to + inspect list of files that have been packed into the model or are supposed to + be packed into the model. 
+
+  The metadata file (or buffer) should be generated based on the metadata
+  schema:
+  third_party/tensorflow/lite/schema/metadata_schema.fbs
+
+  Example usage:
+  Populate metadata and a label file into an image classifier model.
+
+  First, based on metadata_schema.fbs, generate the metadata for this image
+  classifier model using the Flatbuffers API. Attach the label file onto the
+  output tensor (the tensor of probabilities) in the metadata.
+
+  Then, pack the metadata and label file into the model as follows.
+
+    ```python
+    # Populating a metadata file (or a metadata buffer) and associated files to
+    # a model file:
+    populator = MetadataPopulator.with_model_file(model_file)
+    # For metadata buffer (bytearray read from the metadata file), use:
+    # populator.load_metadata_buffer(metadata_buf)
+    populator.load_metadata_file(metadata_file)
+    populator.load_associated_files(["label.txt"])
+    # For associated file buffer (bytearray read from the file), use:
+    # populator.load_associated_file_buffers({"label.txt": b"file content"})
+    populator.populate()
+
+    # Populating a metadata file (or a metadata buffer) and associated files to
+    # a model buffer:
+    populator = MetadataPopulator.with_model_buffer(model_buf)
+    populator.load_metadata_file(metadata_file)
+    populator.load_associated_files(["label.txt"])
+    populator.populate()
+    # Writing the updated model buffer into a file.
+    updated_model_buf = populator.get_model_buffer()
+    with open("updated_model.tflite", "wb") as f:
+      f.write(updated_model_buf)
+
+    # Transferring metadata and associated files from another TFLite model:
+    populator = MetadataPopulator.with_model_buffer(model_buf)
+    populator.load_metadata_and_associated_files(src_model_buf)
+    populator.populate()
+    updated_model_buf = populator.get_model_buffer()
+    with open("updated_model.tflite", "wb") as f:
+      f.write(updated_model_buf)
+    ```
+
+  Note that the existing metadata buffer (if any) will be overridden by the new
+  metadata buffer.
+  """
+  # As the Zip API is used to concatenate associated files after the tflite
+  # model file, the populating operation is developed based on a model file.
+  # For an in-memory model buffer, we create a tempfile to serve the populating
+  # operation. Creating and deleting such a tempfile is handled by the class
+  # _MetadataPopulatorWithBuffer.
+
+  METADATA_FIELD_NAME = "TFLITE_METADATA"
+  TFLITE_FILE_IDENTIFIER = b"TFL3"
+  METADATA_FILE_IDENTIFIER = b"M001"
+
+  def __init__(self, model_file):
+    """Constructor for MetadataPopulator.
+
+    Args:
+      model_file: valid path to a TensorFlow Lite model file.
+
+    Raises:
+      IOError: File not found.
+      ValueError: the model does not have the expected flatbuffer identifier.
+    """
+    _assert_model_file_identifier(model_file)
+    self._model_file = model_file
+    self._metadata_buf = None
+    # _associated_files is a dict of file name and file buffer.
+    self._associated_files = {}
+
+  @classmethod
+  def with_model_file(cls, model_file):
+    """Creates a MetadataPopulator object that populates data to a model file.
+
+    Args:
+      model_file: valid path to a TensorFlow Lite model file.
+
+    Returns:
+      MetadataPopulator object.
+
+    Raises:
+      IOError: File not found.
+      ValueError: the model does not have the expected flatbuffer identifier.
+    """
+    return cls(model_file)
+
+  # TODO: investigate if type check can be applied to model_buf for
+  # FB.
+  @classmethod
+  def with_model_buffer(cls, model_buf):
+    """Creates a MetadataPopulator object that populates data to a model buffer.
+
+    Args:
+      model_buf: TensorFlow Lite model buffer in bytearray.
+ + Returns: + A MetadataPopulator(_MetadataPopulatorWithBuffer) object. + + Raises: + ValueError: the model does not have the expected flatbuffer identifer. + """ + return _MetadataPopulatorWithBuffer(model_buf) + + def get_model_buffer(self): + """Gets the buffer of the model with packed metadata and associated files. + + Returns: + Model buffer (in bytearray). + """ + with _open_file(self._model_file, "rb") as f: + return f.read() + + def get_packed_associated_file_list(self): + """Gets a list of associated files packed to the model file. + + Returns: + List of packed associated files. + """ + if not _is_zipfile(self._model_file): + return [] + + with _open_as_zipfile(self._model_file, "r") as zf: + return zf.namelist() + + def get_recorded_associated_file_list(self): + """Gets a list of associated files recorded in metadata of the model file. + + Associated files may be attached to a model, a subgraph, or an input/output + tensor. + + Returns: + List of recorded associated files. + """ + if not self._metadata_buf: + return [] + + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata( + self._metadata_buf, 0)) + + return [ + file.name.decode("utf-8") + for file in self._get_recorded_associated_file_object_list(metadata) + ] + + def load_associated_file_buffers(self, associated_files): + """Loads the associated file buffers (in bytearray) to be populated. + + Args: + associated_files: a dictionary of associated file names and corresponding + file buffers, such as {"file.txt": b"file content"}. If pass in file + paths for the file name, only the basename will be populated. + """ + + self._associated_files.update({ + os.path.basename(name): buffers + for name, buffers in associated_files.items() + }) + + def load_associated_files(self, associated_files): + """Loads associated files that to be concatenated after the model file. + + Args: + associated_files: list of file paths. + + Raises: + IOError: + File not found. + """ + for af_name in associated_files: + _assert_file_exist(af_name) + with _open_file(af_name, "rb") as af: + self.load_associated_file_buffers({af_name: af.read()}) + + def load_metadata_buffer(self, metadata_buf): + """Loads the metadata buffer (in bytearray) to be populated. + + Args: + metadata_buf: metadata buffer (in bytearray) to be populated. + + Raises: + ValueError: The metadata to be populated is empty. + ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: Cannot get minimum metadata parser version. + ValueError: The number of SubgraphMetadata is not 1. + ValueError: The number of input/output tensors does not match the number + of input/output tensor metadata. + """ + if not metadata_buf: + raise ValueError("The metadata to be populated is empty.") + + self._validate_metadata(metadata_buf) + + # Gets the minimum metadata parser version of the metadata_buf. + min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion( + bytes(metadata_buf)) + + # Inserts in the minimum metadata parser version into the metadata_buf. + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) + metadata.minParserVersion = min_version + + # Remove local file directory in the `name` field of `AssociatedFileT`, and + # make it consistent with the name of the actual file packed in the model. 
+ self._use_basename_for_associated_files_in_metadata(metadata) + + b = flatbuffers.Builder(0) + b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER) + metadata_buf_with_version = b.Output() + + self._metadata_buf = metadata_buf_with_version + + def load_metadata_file(self, metadata_file): + """Loads the metadata file to be populated. + + Args: + metadata_file: path to the metadata file to be populated. + + Raises: + IOError: File not found. + ValueError: The metadata to be populated is empty. + ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: Cannot get minimum metadata parser version. + ValueError: The number of SubgraphMetadata is not 1. + ValueError: The number of input/output tensors does not match the number + of input/output tensor metadata. + """ + _assert_file_exist(metadata_file) + with _open_file(metadata_file, "rb") as f: + metadata_buf = f.read() + self.load_metadata_buffer(bytearray(metadata_buf)) + + def load_metadata_and_associated_files(self, src_model_buf): + """Loads the metadata and associated files from another model buffer. + + Args: + src_model_buf: source model buffer (in bytearray) with metadata and + associated files. + """ + # Load the model metadata from src_model_buf if exist. + metadata_buffer = get_metadata_buffer(src_model_buf) + if metadata_buffer: + self.load_metadata_buffer(metadata_buffer) + + # Load the associated files from src_model_buf if exist. + if _is_zipfile(io.BytesIO(src_model_buf)): + with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf: + self.load_associated_file_buffers( + {f: zf.read(f) for f in zf.namelist()}) + + def populate(self): + """Populates loaded metadata and associated files into the model file.""" + self._assert_validate() + self._populate_metadata_buffer() + self._populate_associated_files() + + def _assert_validate(self): + """Validates the metadata and associated files to be populated. + + Raises: + ValueError: + File is recorded in the metadata, but is not going to be populated. + File has already been packed. + """ + # Gets files that are recorded in metadata. + recorded_files = self.get_recorded_associated_file_list() + + # Gets files that have been packed to self._model_file. + packed_files = self.get_packed_associated_file_list() + + # Gets the file name of those associated files to be populated. + to_be_populated_files = self._associated_files.keys() + + # Checks all files recorded in the metadata will be populated. + for rf in recorded_files: + if rf not in to_be_populated_files and rf not in packed_files: + raise ValueError("File, '{0}', is recorded in the metadata, but has " + "not been loaded into the populator.".format(rf)) + + for f in to_be_populated_files: + if f in packed_files: + raise ValueError("File, '{0}', has already been packed.".format(f)) + + if f not in recorded_files: + warnings.warn( + "File, '{0}', does not exist in the metadata. 
But packing it to "
+            "tflite model is still allowed.".format(f))
+
+  def _copy_archived_files(self, src_zip, file_list, dst_zip):
+    """Copies archived files in file_list from src_zip to dst_zip."""
+
+    if not _is_zipfile(src_zip):
+      raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
+
+    with _open_as_zipfile(src_zip, "r") as src_zf, \
+         _open_as_zipfile(dst_zip, "a") as dst_zf:
+      src_list = src_zf.namelist()
+      for f in file_list:
+        if f not in src_list:
+          raise ValueError(
+              "File, '{0}', does not exist in the zipfile, {1}.".format(
+                  f, src_zip))
+        file_buffer = src_zf.read(f)
+        dst_zf.writestr(f, file_buffer)
+
+  def _get_associated_files_from_process_units(self, table, field_name):
+    """Gets the files that are attached to the process units field of a table.
+
+    Args:
+      table: a Flatbuffers table object that contains fields of an array of
+        ProcessUnit, such as TensorMetadata and SubGraphMetadata.
+      field_name: the name of the field in the table that represents an array
+        of ProcessUnit. If the table is TensorMetadata, field_name can be
+        "ProcessUnits". If the table is SubGraphMetadata, field_name can be
+        either "InputProcessUnits" or "OutputProcessUnits".
+
+    Returns:
+      A list of AssociatedFileT objects.
+    """
+
+    if table is None:
+      return []
+
+    file_list = []
+    process_units = getattr(table, field_name)
+    # If the process_units field is not populated, it will be None. Use an
+    # empty list to skip the check.
+    for process_unit in process_units or []:
+      options = process_unit.options
+      if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
+                              _metadata_fb.RegexTokenizerOptionsT)):
+        file_list += self._get_associated_files_from_table(options, "vocabFile")
+      elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
+        file_list += self._get_associated_files_from_table(
+            options, "sentencePieceModel")
+        file_list += self._get_associated_files_from_table(options, "vocabFile")
+    return file_list
+
+  def _get_associated_files_from_table(self, table, field_name):
+    """Gets the associated files that are attached to a table directly.
+
+    Args:
+      table: a Flatbuffers table object that contains fields of an array of
+        AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
+      field_name: the name of the field in the table that represents an array
+        of AssociatedFile. If the table is TensorMetadata, field_name can be
+        "AssociatedFiles". If the table is BertTokenizerOptions, field_name can
+        be "VocabFile".
+
+    Returns:
+      A list of AssociatedFileT objects.
+    """
+
+    if table is None:
+      return []
+
+    # If the associated file field is not populated,
+    # `getattr(table, field_name)` will be None. Return an empty list.
+    return getattr(table, field_name) or []
+
+  def _get_recorded_associated_file_object_list(self, metadata):
+    """Gets a list of AssociatedFileT objects recorded in the metadata.
+
+    Associated files may be attached to a model, a subgraph, or an input/output
+    tensor.
+
+    Args:
+      metadata: the ModelMetadataT object.
+
+    Returns:
+      List of recorded AssociatedFileT objects.
+    """
+    recorded_files = []
+
+    # Add associated files attached to ModelMetadata.
+    recorded_files += self._get_associated_files_from_table(
+        metadata, "associatedFiles")
+
+    # Add associated files attached to each SubgraphMetadata.
+    for subgraph in metadata.subgraphMetadata or []:
+      recorded_files += self._get_associated_files_from_table(
+          subgraph, "associatedFiles")
+
+      # Add associated files attached to each input tensor.
+      for tensor_metadata in subgraph.inputTensorMetadata or []:
+        recorded_files += self._get_associated_files_from_table(
+            tensor_metadata, "associatedFiles")
+        recorded_files += self._get_associated_files_from_process_units(
+            tensor_metadata, "processUnits")
+
+      # Add associated files attached to each output tensor.
+      for tensor_metadata in subgraph.outputTensorMetadata or []:
+        recorded_files += self._get_associated_files_from_table(
+            tensor_metadata, "associatedFiles")
+        recorded_files += self._get_associated_files_from_process_units(
+            tensor_metadata, "processUnits")
+
+      # Add associated files attached to the input_process_units.
+      recorded_files += self._get_associated_files_from_process_units(
+          subgraph, "inputProcessUnits")
+
+      # Add associated files attached to the output_process_units.
+      recorded_files += self._get_associated_files_from_process_units(
+          subgraph, "outputProcessUnits")
+
+    return recorded_files
+
+  def _populate_associated_files(self):
+    """Concatenates associated files after the TensorFlow Lite model file.
+
+    If the MetadataPopulator object is created using the method
+    with_model_file(model_file), the model file will be updated.
+    """
+    # Opens up the model file in "appending" mode.
+    # If self._model_file already has packed files, zipfile will concatenate
+    # additional files after self._model_file. For example, suppose we have
+    # self._model_file = old_tflite_file | label1.txt | label2.txt
+    # Then, after triggering populate() to add label3.txt, self._model_file becomes
+    # self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
+    with tempfile.SpooledTemporaryFile() as temp:
+      # (1) Copy the content of the model file to the temp file.
+      with _open_file(self._model_file, "rb") as f:
+        shutil.copyfileobj(f, temp)
+
+      # (2) Append the associated files to the temp file as a zip.
+      with _open_as_zipfile(temp, "a") as zf:
+        for file_name, file_buffer in self._associated_files.items():
+          zf.writestr(file_name, file_buffer)
+
+      # (3) Copy the temp file back to the model file.
+      temp.seek(0)
+      with _open_file(self._model_file, "wb") as f:
+        shutil.copyfileobj(temp, f)
+
+  def _populate_metadata_buffer(self):
+    """Populates the metadata buffer (in bytearray) into the model file.
+
+    Inserts metadata_buf into the metadata field of schema.Model. If the
+    MetadataPopulator object is created using the method
+    with_model_file(model_file), the model file will be updated.
+
+    The existing metadata buffer (if any) will be overridden by the new
+    metadata buffer.
+    """
+
+    with _open_file(self._model_file, "rb") as f:
+      model_buf = f.read()
+
+    model = _schema_fb.ModelT.InitFromObj(
+        _schema_fb.Model.GetRootAsModel(model_buf, 0))
+    buffer_field = _schema_fb.BufferT()
+    buffer_field.data = self._metadata_buf
+
+    is_populated = False
+    if not model.metadata:
+      model.metadata = []
+    else:
+      # Check if metadata has already been populated.
+      for meta in model.metadata:
+        if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
+          is_populated = True
+          model.buffers[meta.buffer] = buffer_field
+
+    if not is_populated:
+      if not model.buffers:
+        model.buffers = []
+      model.buffers.append(buffer_field)
+      # Creates a new metadata field.
+      metadata_field = _schema_fb.MetadataT()
+      metadata_field.name = self.METADATA_FIELD_NAME
+      metadata_field.buffer = len(model.buffers) - 1
+      model.metadata.append(metadata_field)
+
+    # Packs the model back into a flatbuffer binary file.
+    b = flatbuffers.Builder(0)
+    b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
+    model_buf = b.Output()
+
+    # Saves the updated model buffer to the model file.
+    # Gets files that have been packed to self._model_file.
+    packed_files = self.get_packed_associated_file_list()
+    if packed_files:
+      # Writes the updated model buffer and associated files into a new model
+      # file (in memory). Then overwrites the original model file.
+      with tempfile.SpooledTemporaryFile() as temp:
+        temp.write(model_buf)
+        self._copy_archived_files(self._model_file, packed_files, temp)
+        temp.seek(0)
+        with _open_file(self._model_file, "wb") as f:
+          shutil.copyfileobj(temp, f)
+    else:
+      with _open_file(self._model_file, "wb") as f:
+        f.write(model_buf)
+
+  def _use_basename_for_associated_files_in_metadata(self, metadata):
+    """Removes any associated file local directory (if it exists)."""
+    for file in self._get_recorded_associated_file_object_list(metadata):
+      file.name = os.path.basename(file.name)
+
+  def _validate_metadata(self, metadata_buf):
+    """Validates the metadata to be populated."""
+    _assert_metadata_buffer_identifier(metadata_buf)
+
+    # Verify the number of SubgraphMetadata is exactly one.
+    # TFLite currently only supports one subgraph.
+    model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
+        metadata_buf, 0)
+    if model_meta.SubgraphMetadataLength() != 1:
+      raise ValueError("The number of SubgraphMetadata should be exactly one, "
+                       "but got {0}.".format(
+                           model_meta.SubgraphMetadataLength()))
+
+    # Verify if the number of tensor metadata matches the number of tensors.
+    with _open_file(self._model_file, "rb") as f:
+      model_buf = f.read()
+    model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
+
+    num_input_tensors = model.Subgraphs(0).InputsLength()
+    num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
+    if num_input_tensors != num_input_meta:
+      raise ValueError(
+          "The number of input tensors ({0}) should match the number of "
+          "input tensor metadata ({1})".format(num_input_tensors,
+                                               num_input_meta))
+    num_output_tensors = model.Subgraphs(0).OutputsLength()
+    num_output_meta = model_meta.SubgraphMetadata(
+        0).OutputTensorMetadataLength()
+    if num_output_tensors != num_output_meta:
+      raise ValueError(
+          "The number of output tensors ({0}) should match the number of "
+          "output tensor metadata ({1})".format(num_output_tensors,
+                                                num_output_meta))
+
+
+class _MetadataPopulatorWithBuffer(MetadataPopulator):
+  """Subclass of MetadataPopulator that populates metadata to a model buffer.
+
+  This class is used to populate metadata into an in-memory model buffer. As we
+  use the Zip API to concatenate associated files after the tflite model file,
+  the populating operation is developed based on a model file. For an in-memory
+  model buffer, we create a tempfile to serve the populating operation. This
+  class is then used to generate this tempfile, and delete the file when the
+  MetadataPopulator object is deleted.
+  """
+
+  def __init__(self, model_buf):
+    """Constructor for _MetadataPopulatorWithBuffer.
+
+    Args:
+      model_buf: TensorFlow Lite model buffer in bytearray.
+
+    Raises:
+      ValueError: model_buf is empty.
+      ValueError: model_buf does not have the expected flatbuffer identifier.
+ """ + if not model_buf: + raise ValueError("model_buf cannot be empty.") + + with tempfile.NamedTemporaryFile() as temp: + model_file = temp.name + + with _open_file(model_file, "wb") as f: + f.write(model_buf) + + super().__init__(model_file) + + def __del__(self): + """Destructor of _MetadataPopulatorWithBuffer. + + Deletes the tempfile. + """ + if os.path.exists(self._model_file): + os.remove(self._model_file) + + +class MetadataDisplayer(object): + """Displays metadata and associated file info in human-readable format.""" + + def __init__(self, model_buffer, metadata_buffer, associated_file_list): + """Constructor for MetadataDisplayer. + + Args: + model_buffer: valid buffer of the model file. + metadata_buffer: valid buffer of the metadata file. + associated_file_list: list of associate files in the model file. + """ + _assert_model_buffer_identifier(model_buffer) + _assert_metadata_buffer_identifier(metadata_buffer) + self._model_buffer = model_buffer + self._metadata_buffer = metadata_buffer + self._associated_file_list = associated_file_list + + @classmethod + def with_model_file(cls, model_file): + """Creates a MetadataDisplayer object for the model file. + + Args: + model_file: valid path to a TensorFlow Lite model file. + + Returns: + MetadataDisplayer object. + + Raises: + IOError: File not found. + ValueError: The model does not have metadata. + """ + _assert_file_exist(model_file) + with _open_file(model_file, "rb") as f: + return cls.with_model_buffer(f.read()) + + @classmethod + def with_model_buffer(cls, model_buffer): + """Creates a MetadataDisplayer object for a file buffer. + + Args: + model_buffer: TensorFlow Lite model buffer in bytearray. + + Returns: + MetadataDisplayer object. + """ + if not model_buffer: + raise ValueError("model_buffer cannot be empty.") + metadata_buffer = get_metadata_buffer(model_buffer) + if not metadata_buffer: + raise ValueError("The model does not have metadata.") + associated_file_list = cls._parse_packed_associted_file_list(model_buffer) + return cls(model_buffer, metadata_buffer, associated_file_list) + + def get_associated_file_buffer(self, filename): + """Get the specified associated file content in bytearray. + + Args: + filename: name of the file to be extracted. + + Returns: + The file content in bytearray. + + Raises: + ValueError: if the file does not exist in the model. + """ + if filename not in self._associated_file_list: + raise ValueError( + "The file, {}, does not exist in the model.".format(filename)) + + with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf: + return zf.read(filename) + + def get_metadata_buffer(self): + """Get the metadata buffer in bytearray out from the model.""" + return copy.deepcopy(self._metadata_buffer) + + def get_metadata_json(self): + """Converts the metadata into a json string.""" + return convert_to_json(self._metadata_buffer) + + def get_packed_associated_file_list(self): + """Returns a list of associated files that are packed in the model. + + Returns: + A name list of associated files. + """ + return copy.deepcopy(self._associated_file_list) + + @staticmethod + def _parse_packed_associted_file_list(model_buf): + """Gets a list of associated files packed to the model file. + + Args: + model_buf: valid file buffer. + + Returns: + List of packed associated files. 
+ """ + + try: + with _open_as_zipfile(io.BytesIO(model_buf)) as zf: + return zf.namelist() + except zipfile.BadZipFile: + return [] + + +# Create an individual method for getting the metadata json file, so that it can +# be used as a standalone util. +def convert_to_json(metadata_buffer): + """Converts the metadata into a json string. + + Args: + metadata_buffer: valid metadata buffer in bytes. + + Returns: + Metadata in JSON format. + + Raises: + ValueError: error occured when parsing the metadata schema file. + """ + + opt = _pywrap_flatbuffers.IDLOptions() + opt.strict_json = True + parser = _pywrap_flatbuffers.Parser(opt) + with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f: + metadata_schema_content = f.read() + if not parser.parse(metadata_schema_content): + raise ValueError("Cannot parse metadata schema. Reason: " + parser.error) + return _pywrap_flatbuffers.generate_text(parser, metadata_buffer) + + +def _assert_file_exist(filename): + """Checks if a file exists.""" + if not _exists_file(filename): + raise IOError("File, '{0}', does not exist.".format(filename)) + + +def _assert_model_file_identifier(model_file): + """Checks if a model file has the expected TFLite schema identifier.""" + _assert_file_exist(model_file) + with _open_file(model_file, "rb") as f: + _assert_model_buffer_identifier(f.read()) + + +def _assert_model_buffer_identifier(model_buf): + if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0): + raise ValueError( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.") + + +def _assert_metadata_buffer_identifier(metadata_buf): + """Checks if a metadata buffer has the expected Metadata schema identifier.""" + if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier( + metadata_buf, 0): + raise ValueError( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.") + + +def get_metadata_buffer(model_buf): + """Returns the metadata in the model file as a buffer. + + Args: + model_buf: valid buffer of the model file. + + Returns: + Metadata buffer. Returns `None` if the model does not have metadata. + """ + tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0) + + # Gets metadata from the model file. + for i in range(tflite_model.MetadataLength()): + meta = tflite_model.Metadata(i) + if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME: + buffer_index = meta.Buffer() + metadata = tflite_model.Buffers(buffer_index) + return metadata.DataAsNumpy().tobytes() + + return None diff --git a/mediapipe/tasks/python/metadata/metadata_displayer_cli.py b/mediapipe/tasks/python/metadata/metadata_displayer_cli.py new file mode 100644 index 000000000..745da1f25 --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_displayer_cli.py @@ -0,0 +1,34 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""CLI tool for displaying metadata."""
+
+from absl import app
+from absl import flags
+
+from mediapipe.tasks.python.metadata import metadata
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.')
+flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.')
+
+
+def main(_):
+  displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path)
+  with open(FLAGS.export_json_path, 'w') as f:
+    f.write(displayer.get_metadata_json())
+
+
+if __name__ == '__main__':
+  app.run(main)
diff --git a/mediapipe/tasks/python/metadata/metadata_parser.py.template b/mediapipe/tasks/python/metadata/metadata_parser.py.template
new file mode 100644
index 000000000..b5a64dee6
--- /dev/null
+++ b/mediapipe/tasks/python/metadata/metadata_parser.py.template
@@ -0,0 +1,26 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Information about the metadata parser that this python library depends on."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class MetadataParser(object):
+  """Information about the metadata parser."""
+
+  # The version of the metadata parser.
+ VERSION = "{LATEST_METADATA_PARSER_VERSION}" diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index 7d5f2451b..d4ef3a35b 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -19,11 +19,9 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) py_library( - name = "test_util", + name = "test_utils", testonly = 1, - srcs = ["test_util.py"], + srcs = ["test_utils.py"], srcs_version = "PY3", - deps = [ - "//mediapipe/python:_framework_bindings", - ], + deps = ["//mediapipe/python:_framework_bindings"], ) diff --git a/mediapipe/tasks/python/test/metadata/BUILD b/mediapipe/tasks/python/test/metadata/BUILD new file mode 100644 index 000000000..2cdc7e63a --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/BUILD @@ -0,0 +1,29 @@ +package( + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], # Apache 2.0 +) + +py_test( + name = "metadata_test", + srcs = ["metadata_test.py"], + data = ["//mediapipe/tasks/testdata/metadata:data_files"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/metadata:schema_py", + "//mediapipe/tasks/python/metadata", + "//mediapipe/tasks/python/test:test_utils", + "@flatbuffers//:runtime_py", + ], +) + +py_test( + name = "metadata_parser_test", + srcs = ["metadata_parser_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = ["//mediapipe/tasks/python/metadata"], +) diff --git a/mediapipe/tasks/python/test/metadata/metadata_parser_test.py b/mediapipe/tasks/python/test/metadata/metadata_parser_test.py new file mode 100644 index 000000000..93b851082 --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_parser_test.py @@ -0,0 +1,37 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mediapipe.tasks.python.metadata.metadata_parser.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +from absl.testing import absltest + +from mediapipe.tasks.python.metadata import metadata_parser + + +class MetadataParserTest(absltest.TestCase): + + def testVersionWellFormedSemanticVersion(self): + # Validates that the version is well-formed (x.y.z). + self.assertTrue( + re.match('[0-9]+\\.[0-9]+\\.[0-9]+', + metadata_parser.MetadataParser.VERSION)) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/test/metadata/metadata_test.py b/mediapipe/tasks/python/test/metadata/metadata_test.py new file mode 100644 index 000000000..00dbe526a --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_test.py @@ -0,0 +1,857 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mediapipe.tasks.python.metadata.metadata.""" + +import enum +import os + +from absl.testing import absltest +from absl.testing import parameterized +import six + +import flatbuffers +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.metadata import schema_py_generated as _schema_fb +from mediapipe.tasks.python.metadata import metadata as _metadata +from mediapipe.tasks.python.test import test_utils + + +class Tokenizer(enum.Enum): + BERT_TOKENIZER = 0 + SENTENCE_PIECE = 1 + + +class TensorType(enum.Enum): + INPUT = 0 + OUTPUT = 1 + + +def _read_file(file_name, mode="rb"): + with open(file_name, mode) as f: + return f.read() + + +class MetadataTest(parameterized.TestCase): + + def setUp(self): + super(MetadataTest, self).setUp() + self._invalid_model_buf = None + self._invalid_file = "not_existed_file" + self._model_buf = self._create_model_buf() + self._model_file = self.create_tempfile().full_path + with open(self._model_file, "wb") as f: + f.write(self._model_buf) + self._metadata_file = self._create_metadata_file() + self._metadata_file_with_version = self._create_metadata_file_with_version( + self._metadata_file, "1.0.0") + self._file1 = self.create_tempfile("file1").full_path + self._file2 = self.create_tempfile("file2").full_path + self._file2_content = b"file2_content" + with open(self._file2, "wb") as f: + f.write(self._file2_content) + self._file3 = self.create_tempfile("file3").full_path + + def _create_model_buf(self): + # Create a model with two inputs and one output, which matches the metadata + # created by _create_metadata_file(). + metadata_field = _schema_fb.MetadataT() + subgraph = _schema_fb.SubGraphT() + subgraph.inputs = [0, 1] + subgraph.outputs = [2] + + metadata_field.name = "meta" + buffer_field = _schema_fb.BufferT() + model = _schema_fb.ModelT() + model.subgraphs = [subgraph] + # Creates the metadata and buffer fields for testing purposes. + model.metadata = [metadata_field, metadata_field] + model.buffers = [buffer_field, buffer_field, buffer_field] + model_builder = flatbuffers.Builder(0) + model_builder.Finish( + model.Pack(model_builder), + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + return model_builder.Output() + + def _create_metadata_file(self): + associated_file1 = _metadata_fb.AssociatedFileT() + associated_file1.name = b"file1" + associated_file2 = _metadata_fb.AssociatedFileT() + associated_file2.name = b"file2" + self.expected_recorded_files = [ + six.ensure_str(associated_file1.name), + six.ensure_str(associated_file2.name) + ] + + input_meta = _metadata_fb.TensorMetadataT() + output_meta = _metadata_fb.TensorMetadataT() + output_meta.associatedFiles = [associated_file2] + subgraph = _metadata_fb.SubGraphMetadataT() + # Create a model with two inputs and one output. 
+ subgraph.inputTensorMetadata = [input_meta, input_meta] + subgraph.outputTensorMetadata = [output_meta] + + model_meta = _metadata_fb.ModelMetadataT() + model_meta.name = "Mobilenet_quantized" + model_meta.associatedFiles = [associated_file1] + model_meta.subgraphMetadata = [subgraph] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file = self.create_tempfile().full_path + with open(metadata_file, "wb") as f: + f.write(b.Output()) + return metadata_file + + def _create_model_buffer_with_wrong_identifier(self): + wrong_identifier = b"widn" + model = _schema_fb.ModelT() + model_builder = flatbuffers.Builder(0) + model_builder.Finish(model.Pack(model_builder), wrong_identifier) + return model_builder.Output() + + def _create_metadata_buffer_with_wrong_identifier(self): + # Creates a metadata with wrong identifier + wrong_identifier = b"widn" + metadata = _metadata_fb.ModelMetadataT() + metadata_builder = flatbuffers.Builder(0) + metadata_builder.Finish(metadata.Pack(metadata_builder), wrong_identifier) + return metadata_builder.Output() + + def _populate_metadata_with_identifier(self, model_buf, metadata_buf, + identifier): + # For testing purposes only. MetadataPopulator cannot populate metadata with + # wrong identifiers. + model = _schema_fb.ModelT.InitFromObj( + _schema_fb.Model.GetRootAsModel(model_buf, 0)) + buffer_field = _schema_fb.BufferT() + buffer_field.data = metadata_buf + model.buffers = [buffer_field] + # Creates a new metadata field. + metadata_field = _schema_fb.MetadataT() + metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME + metadata_field.buffer = len(model.buffers) - 1 + model.metadata = [metadata_field] + b = flatbuffers.Builder(0) + b.Finish(model.Pack(b), identifier) + return b.Output() + + def _create_metadata_file_with_version(self, metadata_file, min_version): + # Creates a new metadata file with the specified min_version for testing + # purposes. 
+ metadata_buf = bytearray(_read_file(metadata_file)) + + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) + metadata.minParserVersion = min_version + + b = flatbuffers.Builder(0) + b.Finish( + metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file_with_version = self.create_tempfile().full_path + with open(metadata_file_with_version, "wb") as f: + f.write(b.Output()) + return metadata_file_with_version + + +class MetadataPopulatorTest(MetadataTest): + + def _create_bert_tokenizer(self): + vocab_file_name = "bert_vocab" + vocab = _metadata_fb.AssociatedFileT() + vocab.name = vocab_file_name + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions + tokenizer.options = _metadata_fb.BertTokenizerOptionsT() + tokenizer.options.vocabFile = [vocab] + return tokenizer, [vocab_file_name] + + def _create_sentence_piece_tokenizer(self): + sp_model_name = "sp_model" + vocab_file_name = "sp_vocab" + sp_model = _metadata_fb.AssociatedFileT() + sp_model.name = sp_model_name + vocab = _metadata_fb.AssociatedFileT() + vocab.name = vocab_file_name + vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY + tokenizer = _metadata_fb.ProcessUnitT() + tokenizer.optionsType = ( + _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions) + tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT() + tokenizer.options.sentencePieceModel = [sp_model] + tokenizer.options.vocabFile = [vocab] + return tokenizer, [sp_model_name, vocab_file_name] + + def _create_tokenizer(self, tokenizer_type): + if tokenizer_type is Tokenizer.BERT_TOKENIZER: + return self._create_bert_tokenizer() + elif tokenizer_type is Tokenizer.SENTENCE_PIECE: + return self._create_sentence_piece_tokenizer() + else: + raise ValueError( + "The tokenizer type, {0}, is unsupported.".format(tokenizer_type)) + + def _create_tempfiles(self, file_names): + tempfiles = [] + for name in file_names: + tempfiles.append(self.create_tempfile(name).full_path) + return tempfiles + + def _create_model_meta_with_subgraph_meta(self, subgraph_meta): + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgraph_meta] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + return b.Output() + + def testToValidModelFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelFile(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataPopulator.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testToValidModelBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + self.assertIsInstance(populator, _metadata.MetadataPopulator) + + def testToInvalidModelBuffer(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf) + self.assertEqual("model_buf cannot be empty.", str(error.exception)) + + def testToModelBufferWithWrongIdentifier(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + with self.assertRaises(ValueError) as error: + _metadata.MetadataPopulator.with_model_buffer(model_buf) + 
self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + + def testSinglePopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(self._file1)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def testRepeatedPopulateAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_associated_files([self._file1, self._file2]) + # Loads file2 multiple times. + populator.load_associated_files([self._file2]) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertLen(packed_files, 2) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + # Check if the model buffer read from file is the same as that read from + # get_model_buffer(). + model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(IOError) as error: + populator.load_associated_files([self._invalid_file]) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulatePackedAssociatedFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + populator.load_associated_files([self._file1]) + populator.populate() + with self.assertRaises(ValueError) as error: + populator.load_associated_files([self._file1]) + populator.populate() + self.assertEqual( + "File, '{0}', has already been packed.".format( + os.path.basename(self._file1)), str(error.exception)) + + def testLoadAssociatedFileBuffers(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + file_buffer = _read_file(self._file1) + populator.load_associated_file_buffers({self._file1: file_buffer}) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [os.path.basename(self._file1)] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + def testRepeatedLoadAssociatedFileBuffers(self): + file_buffer1 = _read_file(self._file1) + file_buffer2 = _read_file(self._file2) + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + + populator.load_associated_file_buffers({ + self._file1: file_buffer1, + self._file2: file_buffer2 + }) + # Loads file2 multiple times. + populator.load_associated_file_buffers({self._file2: file_buffer2}) + populator.populate() + + packed_files = populator.get_packed_associated_file_list() + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertEqual(set(packed_files), set(expected_packed_files)) + + # Check if the model buffer read from file is the same as that read from + # get_model_buffer(). 
+ model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testLoadPackedAssociatedFileBuffersFails(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + file_buffer = _read_file(self._file1) + populator.load_associated_file_buffers({self._file1: file_buffer}) + populator.populate() + + # Load file1 again should fail. + with self.assertRaises(ValueError) as error: + populator.load_associated_file_buffers({self._file1: file_buffer}) + populator.populate() + self.assertEqual( + "File, '{0}', has already been packed.".format( + os.path.basename(self._file1)), str(error.exception)) + + def testGetPackedAssociatedFileList(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + packed_files = populator.get_packed_associated_file_list() + self.assertEqual(packed_files, []) + + def testPopulateMetadataFileToEmptyModelFile(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + + model_buf_from_file = _read_file(self._model_file) + model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0) + # self._model_file already has two elements in the metadata field, so the + # populated TFLite metadata will be the third element. + metadata_field = model.Metadata(2) + self.assertEqual( + six.ensure_str(metadata_field.Name()), + six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME)) + + buffer_index = metadata_field.Buffer() + buffer_data = model.Buffers(buffer_index) + metadata_buf_np = buffer_data.DataAsNumpy() + metadata_buf = metadata_buf_np.tobytes() + expected_metadata_buf = bytearray( + _read_file(self._metadata_file_with_version)) + self.assertEqual(metadata_buf, expected_metadata_buf) + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + # Up to now, we've proved the correctness of the model buffer that read from + # file. Then we'll test if get_model_buffer() gives the same model buffer. + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateMetadataFileWithoutAssociatedFiles(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1]) + # Suppose to populate self._file2, because it is recorded in the metadta. 
+ with self.assertRaises(ValueError) as error: + populator.populate() + self.assertEqual(("File, '{0}', is recorded in the metadata, but has " + "not been loaded into the populator.").format( + os.path.basename(self._file2)), str(error.exception)) + + def testPopulateMetadataBufferWithWrongIdentifier(self): + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(metadata_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + + def _assert_golden_metadata(self, model_file): + model_buf_from_file = _read_file(model_file) + model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0) + # There are two elements in model.Metadata array before the population. + # Metadata should be packed to the third element in the array. + metadata_field = model.Metadata(2) + self.assertEqual( + six.ensure_str(metadata_field.Name()), + six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME)) + + buffer_index = metadata_field.Buffer() + buffer_data = model.Buffers(buffer_index) + metadata_buf_np = buffer_data.DataAsNumpy() + metadata_buf = metadata_buf_np.tobytes() + expected_metadata_buf = bytearray( + _read_file(self._metadata_file_with_version)) + self.assertEqual(metadata_buf, expected_metadata_buf) + + def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self): + # First, creates a dummy metadata different from self._metadata_file. It + # needs to have the same input/output tensor numbers as self._model_file. + # Populates it and the associated files into the model. + input_meta = _metadata_fb.TensorMetadataT() + output_meta = _metadata_fb.TensorMetadataT() + subgraph = _metadata_fb.SubGraphMetadataT() + # Create a model with two inputs and one output. + subgraph.inputTensorMetadata = [input_meta, input_meta] + subgraph.outputTensorMetadata = [output_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgraph] + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + # Populate the metadata. + populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator1.load_metadata_buffer(metadata_buf) + populator1.load_associated_files([self._file1, self._file2]) + populator1.populate() + + # Then, populate the metadata again. + populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator2.load_metadata_file(self._metadata_file) + populator2.populate() + + # Test if the metadata is populated correctly. + self._assert_golden_metadata(self._model_file) + + def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self): + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + + # Tests if the metadata is populated correctly. + self._assert_golden_metadata(self._model_file) + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + # Up to now, we've proved the correctness of the model buffer that read from + # file. Then we'll test if get_model_buffer() gives the same model buffer. 
+ model_buf_from_file = _read_file(self._model_file) + model_buf_from_getter = populator.get_model_buffer() + self.assertEqual(model_buf_from_file, model_buf_from_getter) + + def testPopulateInvalidMetadataFile(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(IOError) as error: + populator.load_metadata_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testPopulateInvalidMetadataBuffer(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer([]) + self.assertEqual("The metadata to be populated is empty.", + str(error.exception)) + + def testGetModelBufferBeforePopulatingData(self): + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + model_buf = populator.get_model_buffer() + expected_model_buf = self._model_buf + self.assertEqual(model_buf, expected_model_buf) + + def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self): + # Create a dummy metadata without Subgraph. + model_meta = _metadata_fb.ModelMetadataT() + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + "The number of SubgraphMetadata should be exactly one, but got 0.", + str(error.exception)) + + def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self): + # Create a dummy metadata with no input tensor metadata, while the expected + # number is 2. + output_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.outputTensorMetadata = [output_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of input tensors (2) should match the number of " + "input tensor metadata (0)"), str(error.exception)) + + def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self): + # Create a dummy metadata with no output tensor metadata, while the expected + # number is 1. 
+ input_meta = _metadata_fb.TensorMetadataT() + subgprah_meta = _metadata_fb.SubGraphMetadataT() + subgprah_meta.inputTensorMetadata = [input_meta, input_meta] + model_meta = _metadata_fb.ModelMetadataT() + model_meta.subgraphMetadata = [subgprah_meta] + builder = flatbuffers.Builder(0) + builder.Finish( + model_meta.Pack(builder), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + meta_buf = builder.Output() + + populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf) + with self.assertRaises(ValueError) as error: + populator.load_metadata_buffer(meta_buf) + self.assertEqual( + ("The number of output tensors (1) should match the number of " + "output tensor metadata (0)"), str(error.exception)) + + def testLoadMetadataAndAssociatedFilesShouldSucceeds(self): + # Create a src model with metadata and two associated files. + src_model_buf = self._create_model_buf() + populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf) + populator_src.load_metadata_file(self._metadata_file) + populator_src.load_associated_files([self._file1, self._file2]) + populator_src.populate() + + # Create a model to be populated with the metadata and files from + # src_model_buf. + dst_model_buf = self._create_model_buf() + populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf) + populator_dst.load_metadata_and_associated_files( + populator_src.get_model_buffer()) + populator_dst.populate() + + # Tests if the metadata and associated files are populated correctly. + dst_model_file = self.create_tempfile().full_path + with open(dst_model_file, "wb") as f: + f.write(populator_dst.get_model_buffer()) + self._assert_golden_metadata(dst_model_file) + + recorded_files = populator_dst.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(self.expected_recorded_files)) + + @parameterized.named_parameters( + { + "testcase_name": "InputTensorWithBert", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "OutputTensorWithBert", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "InputTensorWithSentencePiece", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }, { + "testcase_name": "OutputTensorWithSentencePiece", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }) + def testGetRecordedAssociatedFileListWithSubgraphTensor( + self, tensor_type, tokenizer_type): + # Creates a metadata with the tokenizer in the tensor process units. + tokenizer, expected_files = self._create_tokenizer(tokenizer_type) + + # Create the tensor with process units. + tensor = _metadata_fb.TensorMetadataT() + tensor.processUnits = [tokenizer] + + # Create the subgrah with the tensor. + subgraph = _metadata_fb.SubGraphMetadataT() + dummy_tensor_meta = _metadata_fb.TensorMetadataT() + subgraph.outputTensorMetadata = [dummy_tensor_meta] + if tensor_type is TensorType.INPUT: + subgraph.inputTensorMetadata = [tensor, dummy_tensor_meta] + subgraph.outputTensorMetadata = [dummy_tensor_meta] + elif tensor_type is TensorType.OUTPUT: + subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta] + subgraph.outputTensorMetadata = [tensor] + else: + raise ValueError( + "The tensor type, {0}, is unsupported.".format(tensor_type)) + + # Create a model metadata with the subgraph metadata + meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Creates the tempfiles. 
+ tempfiles = self._create_tempfiles(expected_files) + + # Creates the MetadataPopulator object. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(meta_buffer) + populator.load_associated_files(tempfiles) + populator.populate() + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(expected_files)) + + @parameterized.named_parameters( + { + "testcase_name": "InputTensorWithBert", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "OutputTensorWithBert", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.BERT_TOKENIZER + }, { + "testcase_name": "InputTensorWithSentencePiece", + "tensor_type": TensorType.INPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }, { + "testcase_name": "OutputTensorWithSentencePiece", + "tensor_type": TensorType.OUTPUT, + "tokenizer_type": Tokenizer.SENTENCE_PIECE + }) + def testGetRecordedAssociatedFileListWithSubgraphProcessUnits( + self, tensor_type, tokenizer_type): + # Creates a metadata with the tokenizer in the subgraph process units. + tokenizer, expected_files = self._create_tokenizer(tokenizer_type) + + # Create the subgraph with process units. + subgraph = _metadata_fb.SubGraphMetadataT() + if tensor_type is TensorType.INPUT: + subgraph.inputProcessUnits = [tokenizer] + elif tensor_type is TensorType.OUTPUT: + subgraph.outputProcessUnits = [tokenizer] + else: + raise ValueError( + "The tensor type, {0}, is unsupported.".format(tensor_type)) + + # Creates the input and output tensor meta to match self._model_file. + dummy_tensor_meta = _metadata_fb.TensorMetadataT() + subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta] + subgraph.outputTensorMetadata = [dummy_tensor_meta] + + # Create a model metadata with the subgraph metadata + meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Creates the tempfiles. + tempfiles = self._create_tempfiles(expected_files) + + # Creates the MetadataPopulator object. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(meta_buffer) + populator.load_associated_files(tempfiles) + populator.populate() + + recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set(expected_files)) + + def testPopulatedFullPathAssociatedFileShouldSucceed(self): + # Create AssociatedFileT using the full path file name. + associated_file = _metadata_fb.AssociatedFileT() + associated_file.name = self._file1 + + # Create model metadata with the associated file. + subgraph = _metadata_fb.SubGraphMetadataT() + subgraph.associatedFiles = [associated_file] + # Creates the input and output tensor metadata to match self._model_file. + dummy_tensor = _metadata_fb.TensorMetadataT() + subgraph.inputTensorMetadata = [dummy_tensor, dummy_tensor] + subgraph.outputTensorMetadata = [dummy_tensor] + md_buffer = self._create_model_meta_with_subgraph_meta(subgraph) + + # Populate the metadata to a model. + populator = _metadata.MetadataPopulator.with_model_file(self._model_file) + populator.load_metadata_buffer(md_buffer) + populator.load_associated_files([self._file1]) + populator.populate() + + # The recorded file name in metadata should only contain file basename; file + # directory should not be included. 
+ recorded_files = populator.get_recorded_associated_file_list() + self.assertEqual(set(recorded_files), set([os.path.basename(self._file1)])) + + +class MetadataDisplayerTest(MetadataTest): + + def setUp(self): + super(MetadataDisplayerTest, self).setUp() + self._model_with_meta_file = ( + self._create_model_with_metadata_and_associated_files()) + + def _create_model_with_metadata_and_associated_files(self): + model_buf = self._create_model_buf() + model_file = self.create_tempfile().full_path + with open(model_file, "wb") as f: + f.write(model_buf) + + populator = _metadata.MetadataPopulator.with_model_file(model_file) + populator.load_metadata_file(self._metadata_file) + populator.load_associated_files([self._file1, self._file2]) + populator.populate() + return model_file + + def testLoadModelBufferMetadataBufferWithWrongIdentifierThrowsException(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_buf = self._create_metadata_buffer_with_wrong_identifier() + model_buf = self._populate_metadata_with_identifier( + model_buf, metadata_buf, + _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The metadata buffer does not have the expected identifier, and may not" + " be a valid TFLite Metadata.", str(error.exception)) + + def testLoadModelBufferModelBufferWithWrongIdentifierThrowsException(self): + model_buf = self._create_model_buffer_with_wrong_identifier() + metadata_file = self._create_metadata_file() + wrong_identifier = b"widn" + metadata_buf = bytearray(_read_file(metadata_file)) + model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf, + wrong_identifier) + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(model_buf) + self.assertEqual( + "The model provided does not have the expected identifier, and " + "may not be a valid TFLite model.", str(error.exception)) + + def testLoadModelFileInvalidModelFileThrowsException(self): + with self.assertRaises(IOError) as error: + _metadata.MetadataDisplayer.with_model_file(self._invalid_file) + self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file), + str(error.exception)) + + def testLoadModelFileModelWithoutMetadataThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_file(self._model_file) + self.assertEqual("The model does not have metadata.", str(error.exception)) + + def testLoadModelFileModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + self.assertIsInstance(displayer, _metadata.MetadataDisplayer) + + def testLoadModelBufferInvalidModelBufferThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(_read_file(self._file1)) + self.assertEqual("model_buffer cannot be empty.", str(error.exception)) + + def testLoadModelBufferModelWithOutMetadataThrowsException(self): + with self.assertRaises(ValueError) as error: + _metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf()) + self.assertEqual("The model does not have metadata.", str(error.exception)) + + def testLoadModelBufferModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_buffer( + _read_file(self._model_with_meta_file)) + self.assertIsInstance(displayer, _metadata.MetadataDisplayer) + + def 
testGetAssociatedFileBufferShouldSucceed(self): + # _model_with_meta_file contains file1 and file2. + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + + actual_content = displayer.get_associated_file_buffer("file2") + self.assertEqual(actual_content, self._file2_content) + + def testGetAssociatedFileBufferFailsWithNonExistentFile(self): + # _model_with_meta_file contains file1 and file2. + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + + non_existent_file = "non_existent_file" + with self.assertRaises(ValueError) as error: + displayer.get_associated_file_buffer(non_existent_file) + self.assertEqual( + "The file, {}, does not exist in the model.".format(non_existent_file), + str(error.exception)) + + def testGetMetadataBufferShouldSucceed(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + actual_buffer = displayer.get_metadata_buffer() + actual_json = _metadata.convert_to_json(actual_buffer) + + # Verifies the generated json file. + golden_json_file_path = test_utils.get_test_data_path("golden_json.json") + with open(golden_json_file_path, "r") as f: + expected = f.read() + self.assertEqual(actual_json, expected) + + def testGetMetadataJsonModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + actual = displayer.get_metadata_json() + + # Verifies the generated json file. + golden_json_file_path = test_utils.get_test_data_path("golden_json.json") + expected = _read_file(golden_json_file_path, "r") + self.assertEqual(actual, expected) + + def testGetPackedAssociatedFileListModelWithMetadata(self): + displayer = _metadata.MetadataDisplayer.with_model_file( + self._model_with_meta_file) + packed_files = displayer.get_packed_associated_file_list() + + expected_packed_files = [ + os.path.basename(self._file1), + os.path.basename(self._file2) + ] + self.assertLen( + packed_files, 2, + "The following two associated files packed to the model: {0}; {1}" + .format(expected_packed_files[0], expected_packed_files[1])) + self.assertEqual(set(packed_files), set(expected_packed_files)) + + +class MetadataUtilTest(MetadataTest): + + def test_convert_to_json_should_succeed(self): + metadata_buf = _read_file(self._metadata_file_with_version) + metadata_json = _metadata.convert_to_json(metadata_buf) + + # Verifies the generated json file. + golden_json_file_path = test_utils.get_test_data_path("golden_json.json") + expected = _read_file(golden_json_file_path, "r") + self.assertEqual(metadata_json, expected) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/test_util.py b/mediapipe/tasks/python/test/test_util.py index cf1dfec2e..531a18f7a 100644 --- a/mediapipe/tasks/python/test/test_util.py +++ b/mediapipe/tasks/python/test/test_util.py @@ -16,7 +16,6 @@ import os from absl import flags -import cv2 from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import image_frame as image_frame_module @@ -44,12 +43,3 @@ def get_test_data_path(file_or_dirname: str) -> str: if f.endswith(file_or_dirname): return os.path.join(directory, f) raise ValueError("No %s in test directory" % file_or_dirname) - - -# TODO: Implement image util module to read image data from file. 
-def read_test_image(image_file: str) -> _Image: - """Reads a MediaPipe Image from the image file.""" - image_data = cv2.imread(image_file) - if image_data.shape[2] != _RGB_CHANNELS: - raise ValueError("Input image must contain three channel rgb data.") - return _Image(_ImageFormat.SRGB, cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)) diff --git a/mediapipe/tasks/python/test/test_utils.py b/mediapipe/tasks/python/test/test_utils.py new file mode 100644 index 000000000..531a18f7a --- /dev/null +++ b/mediapipe/tasks/python/test/test_utils.py @@ -0,0 +1,45 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test util for MediaPipe Tasks.""" + +import os + +from absl import flags + +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import image_frame as image_frame_module + +FLAGS = flags.FLAGS +_Image = image_module.Image +_ImageFormat = image_frame_module.ImageFormat +_RGB_CHANNELS = 3 + + +def test_srcdir(): + """Returns the path where to look for test data files.""" + if "test_srcdir" in flags.FLAGS: + return flags.FLAGS["test_srcdir"].value + elif "TEST_SRCDIR" in os.environ: + return os.environ["TEST_SRCDIR"] + else: + raise RuntimeError("Missing TEST_SRCDIR environment.") + + +def get_test_data_path(file_or_dirname: str) -> str: + """Returns full test data path.""" + for (directory, subdirs, files) in os.walk(test_srcdir()): + for f in subdirs + files: + if f.endswith(file_or_dirname): + return os.path.join(directory, f) + raise ValueError("No %s in test directory" % file_or_dirname) diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 6b6b9e3e2..df2e72f98 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -31,7 +31,7 @@ py_test( "//mediapipe/tasks/python/components/containers:category", "//mediapipe/tasks/python/components/containers:detections", "//mediapipe/tasks/python/core:base_options", - "//mediapipe/tasks/python/test:test_util", + "//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/vision:object_detector", "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index a83031342..95b6bf867 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -25,7 +25,7 @@ from mediapipe.tasks.python.components.containers import bounding_box as boundin from mediapipe.tasks.python.components.containers import category as category_module from mediapipe.tasks.python.components.containers import detections as detections_module from mediapipe.tasks.python.core import base_options as base_options_module -from mediapipe.tasks.python.test import test_util +from mediapipe.tasks.python.test import test_utils from 
mediapipe.tasks.python.vision import object_detector from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module @@ -44,7 +44,7 @@ _IMAGE_FILE = 'cats_and_dogs.jpg' _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[ _Detection( bounding_box=_BoundingBox( - origin_x=608, origin_y=164, width=381, height=432), + origin_x=608, origin_y=161, width=381, height=439), categories=[ _Category( index=None, @@ -64,7 +64,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[ ]), _Detection( bounding_box=_BoundingBox( - origin_x=257, origin_y=394, width=173, height=202), + origin_x=256, origin_y=395, width=173, height=202), categories=[ _Category( index=None, @@ -74,7 +74,7 @@ _EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[ ]), _Detection( bounding_box=_BoundingBox( - origin_x=362, origin_y=195, width=325, height=412), + origin_x=362, origin_y=191, width=325, height=419), categories=[ _Category( index=None, @@ -98,9 +98,9 @@ class ObjectDetectorTest(parameterized.TestCase): def setUp(self): super().setUp() - self.test_image = test_util.read_test_image( - test_util.get_test_data_path(_IMAGE_FILE)) - self.model_path = test_util.get_test_data_path(_MODEL_FILE) + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path(_IMAGE_FILE)) + self.model_path = test_utils.get_test_data_path(_MODEL_FILE) def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. @@ -153,9 +153,9 @@ class ObjectDetectorTest(parameterized.TestCase): detector = _ObjectDetector.create_from_options(options) # Performs object detection on the input. - image_result = detector.detect(self.test_image) + detection_result = detector.detect(self.test_image) # Comparing results. - self.assertEqual(image_result, expected_detection_result) + self.assertEqual(detection_result, expected_detection_result) # Closes the detector explicitly when the detector is not used in # a context. detector.close() @@ -179,9 +179,9 @@ class ObjectDetectorTest(parameterized.TestCase): base_options=base_options, max_results=max_results) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. - image_result = detector.detect(self.test_image) + detection_result = detector.detect(self.test_image) # Comparing results. - self.assertEqual(image_result, expected_detection_result) + self.assertEqual(detection_result, expected_detection_result) def test_score_threshold_option(self): options = _ObjectDetectorOptions( @@ -189,8 +189,8 @@ class ObjectDetectorTest(parameterized.TestCase): score_threshold=_SCORE_THRESHOLD) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. - image_result = detector.detect(self.test_image) - detections = image_result.detections + detection_result = detector.detect(self.test_image) + detections = detection_result.detections for detection in detections: score = detection.categories[0].score @@ -204,8 +204,8 @@ class ObjectDetectorTest(parameterized.TestCase): max_results=_MAX_RESULTS) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. 
- image_result = detector.detect(self.test_image) - detections = image_result.detections + detection_result = detector.detect(self.test_image) + detections = detection_result.detections self.assertLessEqual( len(detections), _MAX_RESULTS, 'Too many results returned.') @@ -216,8 +216,8 @@ class ObjectDetectorTest(parameterized.TestCase): category_allowlist=_ALLOW_LIST) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. - image_result = detector.detect(self.test_image) - detections = image_result.detections + detection_result = detector.detect(self.test_image) + detections = detection_result.detections for detection in detections: label = detection.categories[0].category_name @@ -230,8 +230,8 @@ class ObjectDetectorTest(parameterized.TestCase): category_denylist=_DENY_LIST) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. - image_result = detector.detect(self.test_image) - detections = image_result.detections + detection_result = detector.detect(self.test_image) + detections = detection_result.detections for detection in detections: label = detection.categories[0].category_name @@ -257,8 +257,8 @@ class ObjectDetectorTest(parameterized.TestCase): score_threshold=1) with _ObjectDetector.create_from_options(options) as detector: # Performs object detection on the input. - image_result = detector.detect(self.test_image) - self.assertEmpty(image_result.detections) + detection_result = detector.detect(self.test_image) + self.assertEmpty(detection_result.detections) def test_missing_result_callback(self): options = _ObjectDetectorOptions( diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 1cf94a38f..8bda87ae2 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -28,9 +28,13 @@ mediapipe_files(srcs = [ "mobile_ica_8bit-without-model-metadata.tflite", "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", + "mobilenet_v2_1.0_224_quant.tflite", ]) -exports_files(["external_file"]) +exports_files([ + "external_file", + "golden_json.json", +]) filegroup( name = "model_files", @@ -40,10 +44,14 @@ filegroup( "mobile_ica_8bit-without-model-metadata.tflite", "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", + "mobilenet_v2_1.0_224_quant.tflite", ], ) filegroup( name = "data_files", - srcs = ["external_file"], + srcs = [ + "external_file", + "golden_json.json", + ], ) diff --git a/mediapipe/tasks/testdata/metadata/golden_json.json b/mediapipe/tasks/testdata/metadata/golden_json.json new file mode 100644 index 000000000..601a5976c --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/golden_json.json @@ -0,0 +1,28 @@ +{ + "name": "Mobilenet_quantized", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + }, + { + } + ], + "output_tensor_metadata": [ + { + "associated_files": [ + { + "name": "file2" + } + ] + } + ] + } + ], + "associated_files": [ + { + "name": "file1" + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 41eb44c21..5eda42601 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -85,6 +85,10 @@ filegroup( "selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_144_256_3_expected_mask.jpg", ], + visibility = [ + "//mediapipe/python:__subpackages__", 
+ "//mediapipe/tasks:internal", + ], ) # TODO Create individual filegroup for models required for each Tasks. diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 19fbbc14d..671f47505 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -552,6 +552,16 @@ void AnnotationRenderer::DrawText(const RenderAnnotation& annotation) { origin.y += text_size.height / 2; } + if (text.outline_thickness() > 0.0) { + const int background_thickness = ClampThickness( + round((annotation.thickness() + 2.0 * text.outline_thickness()) * + scale_factor_)); + const cv::Scalar outline_color = + MediapipeColorToOpenCVColor(text.outline_color()); + cv::putText(mat_image_, text.display_text(), origin, font_face, font_scale, + outline_color, background_thickness, /*lineType=*/8, + /*bottomLeftOrigin=*/flip_text_vertically_); + } cv::putText(mat_image_, text.display_text(), origin, font_face, font_scale, color, thickness, /*lineType=*/8, /*bottomLeftOrigin=*/flip_text_vertically_); diff --git a/mediapipe/util/render_data.proto b/mediapipe/util/render_data.proto index 0ff6b3409..62cb750b0 100644 --- a/mediapipe/util/render_data.proto +++ b/mediapipe/util/render_data.proto @@ -168,6 +168,12 @@ message RenderAnnotation { // [left, baseline] represent [center_x, center_y]. optional bool center_horizontally = 7 [default = false]; optional bool center_vertically = 8 [default = false]; + + // Thickness of the text outline. + optional double outline_thickness = 11 [default = 0.0]; + + // Color of the text outline. + optional Color outline_color = 12; } // The RenderAnnotation can be one of the below formats. diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index e246bbd8d..cd291fc1e 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -166,6 +166,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmark_with_attention.tflite?generation=1661875751615925"], ) + http_file( + name = "com_google_mediapipe_golden_json_json", + sha256 = "55c0c88748d099aa379930504df62c6c8f1d8874ea52d2f8a925f352c4c7f09c", + urls = ["https://storage.googleapis.com/mediapipe-assets/golden_json.json?generation=1664340169675228"], + ) + http_file( name = "com_google_mediapipe_hair_segmentation_tflite", sha256 = "d2c940c4fd80edeaf38f5d7387d1b4235ee320ed120080df67c663e749e77633", @@ -316,6 +322,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite", + sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite", sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339", diff --git a/third_party/stblib.BUILD b/third_party/stblib.BUILD index 5169906cc..8a419d1f2 100644 --- a/third_party/stblib.BUILD +++ b/third_party/stblib.BUILD @@ -7,16 +7,19 @@ package( licenses(["notice"]) # MIT license -exports_files(["LICENSE"]) +COPTS = select({ + "@platforms//os:windows": [], + "//conditions:default": [ + "-Wno-unused-function", + "$(STACK_FRAME_UNLIMITED)", + ], +}) cc_library( name = "stb_image", srcs = ["stb_image.c"], hdrs = ["stb_image.h"], - copts = 
[ - "-Wno-unused-function", - "$(STACK_FRAME_UNLIMITED)", - ], + copts = COPTS, includes = ["."], ) @@ -24,5 +27,6 @@ cc_library( name = "stb_image_write", srcs = ["stb_image_write.c"], hdrs = ["stb_image_write.h"], + copts = COPTS, includes = ["."], )