diff --git a/WORKSPACE b/WORKSPACE index 146916c5c..5a47cf6b7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -172,6 +172,10 @@ http_archive( urls = [ "https://github.com/google/sentencepiece/archive/1.0.0.zip", ], + patches = [ + "//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff", + ], + patch_args = ["-p1"], repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"}, ) diff --git a/docs/BUILD b/docs/BUILD new file mode 100644 index 000000000..cb8794dab --- /dev/null +++ b/docs/BUILD @@ -0,0 +1,14 @@ +# Placeholder for internal Python strict binary compatibility macro. + +py_binary( + name = "build_py_api_docs", + srcs = ["build_py_api_docs.py"], + deps = [ + "//mediapipe", + "//third_party/py/absl:app", + "//third_party/py/absl/flags", + "//third_party/py/tensorflow_docs", + "//third_party/py/tensorflow_docs/api_generator:generate_lib", + "//third_party/py/tensorflow_docs/api_generator:public_api", + ], +) diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py new file mode 100644 index 000000000..9911d0736 --- /dev/null +++ b/docs/build_py_api_docs.py @@ -0,0 +1,85 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""MediaPipe reference docs generation script. + +This script generates API reference docs for the `mediapipe` PIP package. + +$> pip install -U git+https://github.com/tensorflow/docs mediapipe +$> python build_py_api_docs.py +""" + +import os + +from absl import app +from absl import flags + +from tensorflow_docs.api_generator import generate_lib +from tensorflow_docs.api_generator import public_api + +try: + # mediapipe has not been set up to work with bazel yet, so catch & report. + import mediapipe # pytype: disable=import-error +except ImportError as e: + raise ImportError('Please `pip install mediapipe`.') from e + + +PROJECT_SHORT_NAME = 'mp' +PROJECT_FULL_NAME = 'MediaPipe' + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + default='/tmp/generated_docs', + help='Where to write the resulting docs.') + +_URL_PREFIX = flags.DEFINE_string( + 'code_url_prefix', + 'https://github.com/google/mediapipe/tree/master/mediapipe', + 'The url prefix for links to code.') + +_SEARCH_HINTS = flags.DEFINE_bool( + 'search_hints', True, + 'Include metadata search hints in the generated files') + +_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python', + 'Path prefix in the _toc.yaml') + + +def gen_api_docs(): + """Generates API docs for the mediapipe package.""" + + doc_generator = generate_lib.DocGenerator( + root_title=PROJECT_FULL_NAME, + py_modules=[(PROJECT_SHORT_NAME, mediapipe)], + base_dir=os.path.dirname(mediapipe.__file__), + code_url_prefix=_URL_PREFIX.value, + search_hints=_SEARCH_HINTS.value, + site_path=_SITE_PATH.value, + # This callback ensures that docs are only generated for objects that + # are explicitly imported in your __init__.py files. 
There are other + # options but this is a good starting point. + callbacks=[public_api.explicit_package_contents_filter], + ) + + doc_generator.build(_OUTPUT_DIR.value) + + print('Docs output to:', _OUTPUT_DIR.value) + + +def main(_): + gen_api_docs() + + +if __name__ == '__main__': + app.run(main) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index ae8a0cbf0..99b5b3e91 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -253,6 +253,26 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "regex_preprocessor_calculator_test", + srcs = ["regex_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:text_classifier_models"], + linkopts = ["-ldl"], + deps = [ + ":regex_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:sink", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "text_to_tensor_calculator", srcs = ["text_to_tensor_calculator.cc"], @@ -307,6 +327,27 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "universal_sentence_encoder_preprocessor_calculator_test", + srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"], + deps = [ + ":universal_sentence_encoder_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:options_map", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index 1f3768ee0..bd8eb3eed 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -26,6 +26,8 @@ #include "mediapipe/gpu/gl_calculator_helper.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe + namespace mediapipe { namespace api2 { @@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process( CalculatorContext* cc, const std::vector& input_tensors, std::vector& output_tensors) { return gpu_helper_.RunInGlContext( - [this, &input_tensors, &output_tensors]() -> absl::Status { + [this, cc, &input_tensors, &output_tensors]() -> absl::Status { // Explicitly copy input. for (int i = 0; i < input_tensors.size(); ++i) { glBindBuffer(GL_COPY_READ_BUFFER, @@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process( } // Run inference. 
- RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + { + MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc); + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + } output_tensors.reserve(output_size_); for (int i = 0; i < output_size_; ++i) { diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index 7e11ee072..52359f7f5 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -32,6 +32,8 @@ #include "mediapipe/util/android/file/base/helpers.h" #endif // MEDIAPIPE_ANDROID +#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe + namespace mediapipe { namespace api2 { @@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl const mediapipe::InferenceCalculatorOptions::Delegate& delegate); absl::StatusOr> Process( - const std::vector& input_tensors); + CalculatorContext* cc, const std::vector& input_tensors); absl::Status Close(); @@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init( absl::StatusOr> InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( - const std::vector& input_tensors) { + CalculatorContext* cc, const std::vector& input_tensors) { std::vector output_tensors; MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors, &output_tensors]() -> absl::Status { + [this, cc, &input_tensors, &output_tensors]() -> absl::Status { for (int i = 0; i < input_tensors.size(); ++i) { MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( input_tensors[i].GetOpenGlBufferReadView().name(), i)); @@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( output_tensors.back().GetOpenGlBufferWriteView().name(), i)); } // Run inference. - return tflite_gpu_runner_->Invoke(); + { + MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc); + return tflite_gpu_runner_->Invoke(); + } })); return output_tensors; @@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) { auto output_tensors = absl::make_unique>(); ASSIGN_OR_RETURN(*output_tensors, - gpu_inference_runner_->Process(input_tensors)); + gpu_inference_runner_->Process(cc, input_tensors)); kOutTensors(cc).Send(std::move(output_tensors)); return absl::OkStatus(); diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 82905d2f5..7dce211c8 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -289,8 +289,15 @@ class NodeBase { template T& GetOptions() { + return GetOptions(T::ext); + } + + // Use this API when the proto extension does not follow the "ext" naming + // convention. + template + auto& GetOptions(const E& extension) { options_used_ = true; - return *options_.MutableExtension(T::ext); + return *options_.MutableExtension(extension); } protected: @@ -386,8 +393,15 @@ class PacketGenerator { template T& GetOptions() { + return GetOptions(T::ext); + } + + // Use this API when the proto extension does not follow the "ext" naming + // convention. 
+ template + auto& GetOptions(const E& extension) { options_used_ = true; - return *options_.MutableExtension(T::ext); + return *options_.MutableExtension(extension); } template diff --git a/mediapipe/framework/calculator_base.h b/mediapipe/framework/calculator_base.h index f9f0d7a8a..19f37f9de 100644 --- a/mediapipe/framework/calculator_base.h +++ b/mediapipe/framework/calculator_base.h @@ -185,7 +185,7 @@ class CalculatorBaseFactory { // Functions for checking that the calculator has the required GetContract. template constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) { - typedef absl::Status (*GetContractType)(CalculatorContract * cc); + typedef absl::Status (*GetContractType)(CalculatorContract* cc); return std::is_same::value; } template diff --git a/mediapipe/framework/calculator_profile.proto b/mediapipe/framework/calculator_profile.proto index 06ec678a9..1512da6af 100644 --- a/mediapipe/framework/calculator_profile.proto +++ b/mediapipe/framework/calculator_profile.proto @@ -133,7 +133,12 @@ message GraphTrace { TPU_TASK = 13; GPU_CALIBRATION = 14; PACKET_QUEUED = 15; + GPU_TASK_INVOKE = 16; + TPU_TASK_INVOKE = 17; } + // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags, + // //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list, + // ) // The timing for one packet set being processed at one caclulator node. message CalculatorTrace { diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index b967b27fb..c3241d911 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -293,7 +293,6 @@ mediapipe_proto_library( name = "rect_proto", srcs = ["rect.proto"], visibility = ["//visibility:public"], - deps = ["//mediapipe/framework/formats:location_data_proto"], ) mediapipe_register_type( diff --git a/mediapipe/framework/profiler/trace_buffer.h b/mediapipe/framework/profiler/trace_buffer.h index 069f09610..60352c705 100644 --- a/mediapipe/framework/profiler/trace_buffer.h +++ b/mediapipe/framework/profiler/trace_buffer.h @@ -109,6 +109,11 @@ struct TraceEvent { static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK; static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION; static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED; + static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE; + static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE; + // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags, + // //depot/mediapipe/framework/calculator_profile.proto:event_type, + // ) }; // Packet trace log buffer. diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 13710a688..78b196b08 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; static void EglThreadExitCallback(void* key_value) { +#if defined(__ANDROID__) + eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE, + EGL_NO_CONTEXT); +#else // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display // parameter for eglMakeCurrent. This behavior is not portable to all EGL // implementations, and should be considered as an undocumented vendor // extension. // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml + // + // NOTE: crashes on some Android devices (occurs with libGLES_meow.so). 
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); +#endif eglReleaseThread(); } diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 4af9dae78..e3a878f91 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -17,8 +17,8 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor; -import com.google.mediapipe.framework.image.Image; -import com.google.mediapipe.framework.image.ImageProperties; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.framework.image.MPImageProperties; import java.nio.ByteBuffer; // TODO: use Preconditions in this file. @@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator { } /** - * Creates an Image packet from an {@link Image}. + * Creates a MediaPipe Image packet from a {@link MPImage}. * *

The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. */ - public Packet createImage(Image image) { + public Packet createImage(MPImage image) { // TODO: Choose the best storage from multiple containers. - ImageProperties properties = image.getContainedImageProperties().get(0); - if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { + MPImageProperties properties = image.getContainedImageProperties().get(0); + if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) { ByteBuffer buffer = ByteBufferExtractor.extract(image); int numChannels = 0; switch (properties.getImageFormat()) { - case Image.IMAGE_FORMAT_RGBA: + case MPImage.IMAGE_FORMAT_RGBA: numChannels = 4; break; - case Image.IMAGE_FORMAT_RGB: + case MPImage.IMAGE_FORMAT_RGB: numChannels = 3; break; - case Image.IMAGE_FORMAT_ALPHA: + case MPImage.IMAGE_FORMAT_ALPHA: numChannels = 1; break; default: // fall out @@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator { int height = image.getHeight(); return createImage(buffer, width, height, numChannels); } - if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { + if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) { Bitmap bitmap = BitmapExtractor.extract(image); if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java index 4c6cebd4d..d6f50bf30 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java @@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image; import android.graphics.Bitmap; /** - * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. + * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}. * - *

Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise + *

Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise * {@link IllegalArgumentException} will be thrown. */ public final class BitmapExtractor { /** - * Extracts a {@link android.graphics.Bitmap} from an {@link Image}. + * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}. * * @param image the image to extract {@link android.graphics.Bitmap} from. - * @return the {@link android.graphics.Bitmap} stored in {@link Image} + * @return the {@link android.graphics.Bitmap} stored in {@link MPImage} * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. */ - public static Bitmap extract(Image image) { - ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); + public static Bitmap extract(MPImage image) { + MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP); if (imageContainer != null) { return ((BitmapImageContainer) imageContainer).getBitmap(); } else { // TODO: Support ByteBuffer -> Bitmap conversion. throw new IllegalArgumentException( - "Extracting Bitmap from an Image created by objects other than Bitmap is not" + "Extracting Bitmap from a MPImage created by objects other than Bitmap is not" + " supported"); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java index ea2ca6b1f..988cdf542 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java @@ -22,7 +22,7 @@ import android.provider.MediaStore; import java.io.IOException; /** - * Builds {@link Image} from {@link android.graphics.Bitmap}. + * Builds {@link MPImage} from {@link android.graphics.Bitmap}. * *

You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content @@ -49,7 +49,7 @@ public class BitmapImageBuilder { } /** - * Creates the builder to build {@link Image} from a file. + * Creates the builder to build {@link MPImage} from a file. * * @param context the application context. * @param uri the path to the resource file. @@ -58,15 +58,15 @@ public class BitmapImageBuilder { this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ BitmapImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image( + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage( new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java index 0457e1e9b..6fbcac214 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java @@ -16,19 +16,19 @@ limitations under the License. package com.google.mediapipe.framework.image; import android.graphics.Bitmap; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; -class BitmapImageContainer implements ImageContainer { +class BitmapImageContainer implements MPImageContainer { private final Bitmap bitmap; - private final ImageProperties properties; + private final MPImageProperties properties; public BitmapImageContainer(Bitmap bitmap) { this.bitmap = bitmap; this.properties = - ImageProperties.builder() + MPImageProperties.builder() .setImageFormat(convertFormatCode(bitmap.getConfig())) - .setStorageType(Image.STORAGE_TYPE_BITMAP) + .setStorageType(MPImage.STORAGE_TYPE_BITMAP) .build(); } @@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer { } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } @@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer { bitmap.recycle(); } - @ImageFormat + @MPImageFormat static int convertFormatCode(Bitmap.Config config) { switch (config) { case ALPHA_8: - return Image.IMAGE_FORMAT_ALPHA; + return MPImage.IMAGE_FORMAT_ALPHA; case ARGB_8888: - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; default: - return Image.IMAGE_FORMAT_UNKNOWN; + return MPImage.IMAGE_FORMAT_UNKNOWN; } } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java index a0e8c3dff..748a10667 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java @@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config; import android.os.Build.VERSION; import android.os.Build.VERSION_CODES; import com.google.auto.value.AutoValue; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import 
com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Locale; /** - * Utility for extracting {@link ByteBuffer} from {@link Image}. + * Utility for extracting {@link ByteBuffer} from {@link MPImage}. * - *

Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise - * {@link IllegalArgumentException} will be thrown. + *

Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER}, + * otherwise {@link IllegalArgumentException} will be thrown. */ public class ByteBufferExtractor { /** - * Extracts a {@link ByteBuffer} from an {@link Image}. + * Extracts a {@link ByteBuffer} from a {@link MPImage}. * *

The returned {@link ByteBuffer} is a read-only view, with the first available {@link - * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. + * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}. * - * @see Image#getContainedImageProperties() + * @see MPImage#getContainedImageProperties() * @return A read-only {@link ByteBuffer}. * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. */ @SuppressLint("SwitchIntDef") - public static ByteBuffer extract(Image image) { - ImageContainer container = image.getContainer(); + public static ByteBuffer extract(MPImage image) { + MPImageContainer container = image.getContainer(); switch (container.getImageProperties().getStorageType()) { - case Image.STORAGE_TYPE_BYTEBUFFER: + case MPImage.STORAGE_TYPE_BYTEBUFFER: ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); default: throw new IllegalArgumentException( - "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" + "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not" + " supported"); } } /** - * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. + * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}. * *

Format conversion spec: * @@ -70,26 +70,26 @@ public class ByteBufferExtractor { * * @param image the image to extract buffer from. * @param targetFormat the image format of the result bytebuffer. - * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @return the readonly {@link ByteBuffer} stored in {@link MPImage} * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. */ - static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { - ImageContainer container; - ImageProperties byteBufferProperties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) { + MPImageContainer container; + MPImageProperties byteBufferProperties = + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER) .setImageFormat(targetFormat) .build(); if ((container = image.getContainer(byteBufferProperties)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; - @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); + @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) .asReadOnlyBuffer(); - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; ByteBuffer byteBuffer = extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) @@ -98,85 +98,89 @@ public class ByteBufferExtractor { return byteBuffer; } else { throw new IllegalArgumentException( - "Extracting ByteBuffer from an Image created by objects other than Bitmap or" + "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or" + " Bytebuffer is not supported"); } } - /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ + /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */ @AutoValue abstract static class Result { - /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ + /** + * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}. + */ public abstract ByteBuffer buffer(); - /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ - @ImageFormat + /** + * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}. + */ + @MPImageFormat public abstract int format(); - static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { + static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) { return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); } } /** - * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. + * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}. * *

It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. * - * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @return the readonly {@link ByteBuffer} stored in {@link MPImage} * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with * given {@code imageFormat} */ - static Result extractInRecommendedFormat(Image image) { - ImageContainer container; - if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + static Result extractInRecommendedFormat(MPImage image) { + MPImageContainer container; + if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); - @ImageFormat int format = adviseImageFormat(bitmap); + @MPImageFormat int format = adviseImageFormat(bitmap); Result result = Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); boolean unused = image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); return result; - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return Result.create( byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), byteBufferImageContainer.getImageFormat()); } else { throw new IllegalArgumentException( - "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" + "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer" + " is not supported"); } } - @ImageFormat + @MPImageFormat private static int adviseImageFormat(Bitmap bitmap) { if (bitmap.getConfig() == Config.ARGB_8888) { - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; } else { throw new IllegalArgumentException( String.format( - "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" + "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not" + " supported", bitmap.getConfig())); } } private static ByteBuffer extractByteBufferFromBitmap( - Bitmap bitmap, @ImageFormat int imageFormat) { + Bitmap bitmap, @MPImageFormat int imageFormat) { if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { throw new IllegalArgumentException( - "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" + "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not" + " supported"); } if (bitmap.getConfig() == Config.ARGB_8888) { - if (imageFormat == Image.IMAGE_FORMAT_RGBA) { + if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) { ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); bitmap.copyPixelsToBuffer(buffer); buffer.rewind(); return buffer; - } else if (imageFormat == Image.IMAGE_FORMAT_RGB) { + } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) { // TODO: Try Use RGBA buffer to create RGB buffer which might be faster. 
int w = bitmap.getWidth(); int h = bitmap.getHeight(); @@ -196,14 +200,14 @@ public class ByteBufferExtractor { } throw new IllegalArgumentException( String.format( - "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" + "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format" + " %d is not supported", bitmap.getConfig(), imageFormat)); } private static ByteBuffer convertByteBuffer( - ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { - if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { + ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) { + if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) { ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); // Extend the buffer when the target is longer than the source. Use two cursors and sweep the // array reversely to convert in-place. @@ -221,7 +225,8 @@ public class ByteBufferExtractor { target.put(array, 0, target.capacity()); target.rewind(); return target; - } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { + } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA + && targetFormat == MPImage.IMAGE_FORMAT_RGB) { ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the // array to convert in-place. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java index 07871da38..a650e4c33 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java @@ -15,11 +15,11 @@ limitations under the License. package com.google.mediapipe.framework.image; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; /** - * Builds a {@link Image} from a {@link ByteBuffer}. + * Builds a {@link MPImage} from a {@link ByteBuffer}. * *

You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. @@ -32,7 +32,7 @@ public class ByteBufferImageBuilder { private final ByteBuffer buffer; private final int width; private final int height; - @ImageFormat private final int imageFormat; + @MPImageFormat private final int imageFormat; // Optional fields. private long timestamp; @@ -49,7 +49,7 @@ public class ByteBufferImageBuilder { * @param imageFormat how the data encode the image. */ public ByteBufferImageBuilder( - ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { + ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) { this.buffer = byteBuffer; this.width = width; this.height = height; @@ -58,14 +58,14 @@ public class ByteBufferImageBuilder { this.timestamp = 0; } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ ByteBufferImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java index 1c24c1dfd..82dbe32ca 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java @@ -15,21 +15,19 @@ limitations under the License. package com.google.mediapipe.framework.image; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; -class ByteBufferImageContainer implements ImageContainer { +class ByteBufferImageContainer implements MPImageContainer { private final ByteBuffer buffer; - private final ImageProperties properties; + private final MPImageProperties properties; - public ByteBufferImageContainer( - ByteBuffer buffer, - @ImageFormat int imageFormat) { + public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) { this.buffer = buffer; this.properties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER) .setImageFormat(imageFormat) .build(); } @@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer { } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } - /** - * Returns the image format. - */ - @ImageFormat + /** Returns the image format. 
*/ + @MPImageFormat public int getImageFormat() { return properties.getImageFormat(); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/Image.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java similarity index 76% rename from mediapipe/java/com/google/mediapipe/framework/image/Image.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImage.java index 49e63bcc0..e17cc4d30 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/Image.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java @@ -29,10 +29,10 @@ import java.util.Map.Entry; /** * The wrapper class for image objects. * - *

{@link Image} is designed to be an immutable image container, which could be shared + *

{@link MPImage} is designed to be an immutable image container, which could be shared * cross-platforms. * - *

To construct an {@link Image}, use the provided builders: + *

To construct a {@link MPImage}, use the provided builders: * *

 *
 *   • {@link ByteBufferImageBuilder}
@@ -40,7 +40,7 @@ import java.util.Map.Entry;
 *   • {@link MediaImageBuilder}
 *
 *
- *

{@link Image} uses reference counting to maintain internal storage. When it is created the + *

{@link MPImage} uses reference counting to maintain internal storage. When it is created the * reference count is 1. Developer can call {@link #close()} to reduce reference count to release * internal storage earlier, otherwise Java garbage collection will release the storage eventually. * @@ -53,7 +53,7 @@ import java.util.Map.Entry; *

  • {@link MediaImageExtractor} * */ -public class Image implements Closeable { +public class MPImage implements Closeable { /** Specifies the image format of an image. */ @IntDef({ @@ -69,7 +69,7 @@ public class Image implements Closeable { IMAGE_FORMAT_JPEG, }) @Retention(RetentionPolicy.SOURCE) - public @interface ImageFormat {} + public @interface MPImageFormat {} public static final int IMAGE_FORMAT_UNKNOWN = 0; public static final int IMAGE_FORMAT_RGBA = 1; @@ -98,14 +98,14 @@ public class Image implements Closeable { public static final int STORAGE_TYPE_IMAGE_PROXY = 4; /** - * Returns a list of supported image properties for this {@link Image}. + * Returns a list of supported image properties for this {@link MPImage}. * - *

    Currently {@link Image} only support single storage type so the size of return list will + *

    Currently {@link MPImage} only support single storage type so the size of return list will * always be 1. * - * @see ImageProperties + * @see MPImageProperties */ - public List getContainedImageProperties() { + public List getContainedImageProperties() { return Collections.singletonList(getContainer().getImageProperties()); } @@ -124,7 +124,7 @@ public class Image implements Closeable { return height; } - /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ + /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */ private synchronized void acquire() { referenceCount += 1; } @@ -132,7 +132,7 @@ public class Image implements Closeable { /** * Removes a reference that was previously acquired or init. * - *

    When {@link Image} is created, it has 1 reference count. + *

    When {@link MPImage} is created, it has 1 reference count. * *

    When the reference count becomes 0, it will release the resource under the hood. */ @@ -141,24 +141,24 @@ public class Image implements Closeable { public synchronized void close() { referenceCount -= 1; if (referenceCount == 0) { - for (ImageContainer imageContainer : containerMap.values()) { + for (MPImageContainer imageContainer : containerMap.values()) { imageContainer.close(); } } } - /** Advanced API access for {@link Image}. */ + /** Advanced API access for {@link MPImage}. */ static final class Internal { /** - * Acquires a reference on this {@link Image}. This will increase the reference count by 1. + * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. * *

    This method is more useful for image consumer to acquire a reference so image resource * will not be closed accidentally. As image creator, normal developer doesn't need to call this * method. * - *

    The reference count is 1 when {@link Image} is created. Developer can call {@link - * #close()} to indicate it doesn't need this {@link Image} anymore. + *

    The reference count is 1 when {@link MPImage} is created. Developer can call {@link + * #close()} to indicate it doesn't need this {@link MPImage} anymore. * * @see #close() */ @@ -166,10 +166,10 @@ public class Image implements Closeable { image.acquire(); } - private final Image image; + private final MPImage image; - // Only Image creates the internal helper. - private Internal(Image image) { + // Only MPImage creates the internal helper. + private Internal(MPImage image) { this.image = image; } } @@ -179,15 +179,15 @@ public class Image implements Closeable { return new Internal(this); } - private final Map containerMap; + private final Map containerMap; private final long timestamp; private final int width; private final int height; private int referenceCount; - /** Constructs an {@link Image} with a built container. */ - Image(ImageContainer container, long timestamp, int width, int height) { + /** Constructs a {@link MPImage} with a built container. */ + MPImage(MPImageContainer container, long timestamp, int width, int height) { this.containerMap = new HashMap<>(); containerMap.put(container.getImageProperties(), container); this.timestamp = timestamp; @@ -201,10 +201,10 @@ public class Image implements Closeable { * * @return the current container. */ - ImageContainer getContainer() { + MPImageContainer getContainer() { // According to the design, in the future we will support multiple containers in one image. // Currently just return the original container. - // TODO: Cache multiple containers in Image. + // TODO: Cache multiple containers in MPImage. return containerMap.values().iterator().next(); } @@ -214,8 +214,8 @@ public class Image implements Closeable { *

    If there are multiple containers with required {@code storageType}, returns the first one. */ @Nullable - ImageContainer getContainer(@StorageType int storageType) { - for (Entry entry : containerMap.entrySet()) { + MPImageContainer getContainer(@StorageType int storageType) { + for (Entry entry : containerMap.entrySet()) { if (entry.getKey().getStorageType() == storageType) { return entry.getValue(); } @@ -225,13 +225,13 @@ public class Image implements Closeable { /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ @Nullable - ImageContainer getContainer(ImageProperties imageProperties) { + MPImageContainer getContainer(MPImageProperties imageProperties) { return containerMap.get(imageProperties); } /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ - boolean addContainer(ImageContainer container) { - ImageProperties imageProperties = container.getImageProperties(); + boolean addContainer(MPImageContainer container) { + MPImageProperties imageProperties = container.getImageProperties(); if (containerMap.containsKey(imageProperties)) { return false; } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java similarity index 87% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java index 18eed68c6..f9f343e93 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java @@ -14,14 +14,14 @@ limitations under the License. ==============================================================================*/ package com.google.mediapipe.framework.image; -/** Lightweight abstraction for an object that can receive {@link Image} */ -public interface ImageConsumer { +/** Lightweight abstraction for an object that can receive {@link MPImage} */ +public interface MPImageConsumer { /** - * Called when an {@link Image} is available. + * Called when a {@link MPImage} is available. * *

    The argument is only guaranteed to be available until this method returns. if you need to * extend its life time, acquire it, then release it when done. */ - void onNewImage(Image image); + void onNewMPImage(MPImage image); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java similarity index 93% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java index 727ec0893..674073b5b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java @@ -16,9 +16,9 @@ limitations under the License. package com.google.mediapipe.framework.image; /** Manages internal image data storage. The interface is package-private. */ -interface ImageContainer { +interface MPImageContainer { /** Returns the properties of the contained image. */ - ImageProperties getImageProperties(); + MPImageProperties getImageProperties(); /** Close the image container and releases the image resource inside. */ void close(); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java similarity index 75% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java index 4f3641d6f..9783935d4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ package com.google.mediapipe.framework.image; -/** Lightweight abstraction for an object that produce {@link Image} */ -public interface ImageProducer { +/** Lightweight abstraction for an object that produce {@link MPImage} */ +public interface MPImageProducer { - /** Sets the consumer that receives the {@link Image}. */ - void setImageConsumer(ImageConsumer imageConsumer); + /** Sets the consumer that receives the {@link MPImage}. */ + void setMPImageConsumer(MPImageConsumer imageConsumer); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java similarity index 63% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java index e33b33e7f..6005ce77b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java @@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image; import com.google.auto.value.AutoValue; import com.google.auto.value.extension.memoized.Memoized; -import com.google.mediapipe.framework.image.Image.ImageFormat; -import com.google.mediapipe.framework.image.Image.StorageType; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; +import com.google.mediapipe.framework.image.MPImage.StorageType; /** Groups a set of properties to describe how an image is stored. 
*/ @AutoValue -public abstract class ImageProperties { +public abstract class MPImageProperties { /** * Gets the pixel format of the image. * - * @see Image.ImageFormat + * @see MPImage.MPImageFormat */ - @ImageFormat + @MPImageFormat public abstract int getImageFormat(); /** * Gets the storage type of the image. * - * @see Image.StorageType + * @see MPImage.StorageType */ @StorageType public abstract int getStorageType(); @@ -45,36 +45,36 @@ public abstract class ImageProperties { public abstract int hashCode(); /** - * Creates a builder of {@link ImageProperties}. + * Creates a builder of {@link MPImageProperties}. * - * @see ImageProperties.Builder + * @see MPImageProperties.Builder */ static Builder builder() { - return new AutoValue_ImageProperties.Builder(); + return new AutoValue_MPImageProperties.Builder(); } - /** Builds a {@link ImageProperties}. */ + /** Builds a {@link MPImageProperties}. */ @AutoValue.Builder abstract static class Builder { /** - * Sets the {@link Image.ImageFormat}. + * Sets the {@link MPImage.MPImageFormat}. * - * @see ImageProperties#getImageFormat + * @see MPImageProperties#getImageFormat */ - abstract Builder setImageFormat(@ImageFormat int value); + abstract Builder setImageFormat(@MPImageFormat int value); /** - * Sets the {@link Image.StorageType}. + * Sets the {@link MPImage.StorageType}. * - * @see ImageProperties#getStorageType + * @see MPImageProperties#getStorageType */ abstract Builder setStorageType(@StorageType int value); - /** Builds the {@link ImageProperties}. */ - abstract ImageProperties build(); + /** Builds the {@link MPImageProperties}. */ + abstract MPImageProperties build(); } // Hide the constructor. - ImageProperties() {} + MPImageProperties() {} } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java index e351a87fd..9e719715d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java @@ -15,11 +15,12 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; /** - * Builds {@link Image} from {@link android.media.Image}. + * Builds {@link MPImage} from {@link android.media.Image}. * *

    Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify * content in it. @@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi; public class MediaImageBuilder { // Mandatory fields. - private final android.media.Image mediaImage; + private final Image mediaImage; // Optional fields. private long timestamp; @@ -40,20 +41,20 @@ public class MediaImageBuilder { * * @param mediaImage image data object. */ - public MediaImageBuilder(android.media.Image mediaImage) { + public MediaImageBuilder(Image mediaImage) { this.mediaImage = mediaImage; this.timestamp = 0; } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ MediaImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image( + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage( new MediaImageContainer(mediaImage), timestamp, mediaImage.getWidth(), diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java index 144b64def..864c76df2 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java @@ -15,33 +15,34 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build; import android.os.Build.VERSION; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; @RequiresApi(VERSION_CODES.KITKAT) -class MediaImageContainer implements ImageContainer { +class MediaImageContainer implements MPImageContainer { - private final android.media.Image mediaImage; - private final ImageProperties properties; + private final Image mediaImage; + private final MPImageProperties properties; - public MediaImageContainer(android.media.Image mediaImage) { + public MediaImageContainer(Image mediaImage) { this.mediaImage = mediaImage; this.properties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE) .setImageFormat(convertFormatCode(mediaImage.getFormat())) .build(); } - public android.media.Image getImage() { + public Image getImage() { return mediaImage; } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } @@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer { mediaImage.close(); } - @ImageFormat + @MPImageFormat static int convertFormatCode(int graphicsFormat) { // We only cover the format mentioned in // https://developer.android.com/reference/android/media/Image#getFormat() if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { - return Image.IMAGE_FORMAT_RGB; + return MPImage.IMAGE_FORMAT_RGB; } } switch (graphicsFormat) { case android.graphics.ImageFormat.JPEG: - return Image.IMAGE_FORMAT_JPEG; + return MPImage.IMAGE_FORMAT_JPEG; case 
android.graphics.ImageFormat.YUV_420_888: - return Image.IMAGE_FORMAT_YUV_420_888; + return MPImage.IMAGE_FORMAT_YUV_420_888; default: - return Image.IMAGE_FORMAT_UNKNOWN; + return MPImage.IMAGE_FORMAT_UNKNOWN; } } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java index 718cb471f..76bb5a5ec 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java @@ -15,13 +15,14 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; /** - * Utility for extracting {@link android.media.Image} from {@link Image}. + * Utility for extracting {@link android.media.Image} from {@link MPImage}. * - *

    Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, + *

    Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE}, * otherwise {@link IllegalArgumentException} will be thrown. */ @RequiresApi(VERSION_CODES.KITKAT) @@ -30,20 +31,20 @@ public class MediaImageExtractor { private MediaImageExtractor() {} /** - * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for - * {@link Image} that built from {@link MediaImageBuilder}. + * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for + * {@link MPImage} that built from {@link MediaImageBuilder}. * * @param image the image to extract {@link android.media.Image} from. - * @return {@link android.media.Image} that stored in {@link Image}. + * @return {@link android.media.Image} that stored in {@link MPImage}. * @throws IllegalArgumentException if the extraction failed. */ - public static android.media.Image extract(Image image) { - ImageContainer container; - if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { + public static Image extract(MPImage image) { + MPImageContainer container; + if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) { return ((MediaImageContainer) container).getImage(); } throw new IllegalArgumentException( - "Extract Media Image from an Image created by objects other than Media Image" + "Extract Media Image from a MPImage created by objects other than Media Image" + " is not supported"); } } diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 9b01e2f0b..645e8b722 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -1,4 +1,4 @@ -# Copyright 2019-2020 The MediaPipe Authors. +# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -209,9 +209,9 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []): def mediapipe_build_aar_with_jni(name, android_library): """Builds MediaPipe AAR with jni. - Args: - name: The bazel target name. - android_library: the android library that contains jni. + Args: + name: The bazel target name. + android_library: the android library that contains jni. 
""" # Generates dummy AndroidManifest.xml for dummy apk usage @@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""): src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java", )) - proto_src_list.append(mediapipe_java_proto_src_extractor( - target = "//mediapipe/framework/formats:landmark_java_proto_lite", - src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", - )) - proto_src_list.append(mediapipe_java_proto_src_extractor( target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite", src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", )) proto_src_list.append(mediapipe_java_proto_src_extractor( - target = "//mediapipe/framework/formats:location_data_java_proto_lite", - src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", + target = "//mediapipe/framework/formats:classification_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", )) proto_src_list.append(mediapipe_java_proto_src_extractor( @@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""): )) proto_src_list.append(mediapipe_java_proto_src_extractor( - target = "//mediapipe/framework/formats:classification_java_proto_lite", - src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", + target = "//mediapipe/framework/formats:landmark_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:location_data_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:rect_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/RectProto.java", )) return proto_src_list diff --git a/mediapipe/model_maker/python/core/BUILD b/mediapipe/model_maker/python/core/BUILD index 9f205bb11..636a1a720 100644 --- a/mediapipe/model_maker/python/core/BUILD +++ b/mediapipe/model_maker/python/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package( default_visibility = ["//mediapipe:__subpackages__"], diff --git a/mediapipe/model_maker/python/core/data/BUILD b/mediapipe/model_maker/python/core/data/BUILD index c4c659d56..70a62e8f7 100644 --- a/mediapipe/model_maker/python/core/data/BUILD +++ b/mediapipe/model_maker/python/core/data/BUILD @@ -13,6 +13,7 @@ # limitations under the License. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. 
licenses(["notice"]) @@ -23,15 +24,12 @@ package( py_library( name = "data_util", srcs = ["data_util.py"], - srcs_version = "PY3", ) py_test( name = "data_util_test", srcs = ["data_util_test.py"], data = ["//mediapipe/model_maker/python/core/data/testdata"], - python_version = "PY3", - srcs_version = "PY3", deps = [":data_util"], ) @@ -44,8 +42,6 @@ py_library( py_test( name = "dataset_test", srcs = ["dataset_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":dataset", "//mediapipe/model_maker/python/core/utils:test_util", @@ -55,14 +51,11 @@ py_test( py_library( name = "classification_dataset", srcs = ["classification_dataset.py"], - srcs_version = "PY3", deps = [":dataset"], ) py_test( name = "classification_dataset_test", srcs = ["classification_dataset_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [":classification_dataset"], ) diff --git a/mediapipe/model_maker/python/core/tasks/BUILD b/mediapipe/model_maker/python/core/tasks/BUILD index b3588f0be..124de621a 100644 --- a/mediapipe/model_maker/python/core/tasks/BUILD +++ b/mediapipe/model_maker/python/core/tasks/BUILD @@ -13,6 +13,7 @@ # limitations under the License. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. package( default_visibility = ["//mediapipe:__subpackages__"], @@ -23,7 +24,6 @@ licenses(["notice"]) py_library( name = "custom_model", srcs = ["custom_model.py"], - srcs_version = "PY3", deps = [ "//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/utils:model_util", @@ -34,8 +34,6 @@ py_library( py_test( name = "custom_model_test", srcs = ["custom_model_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":custom_model", "//mediapipe/model_maker/python/core/utils:test_util", @@ -45,7 +43,6 @@ py_test( py_library( name = "classifier", srcs = ["classifier.py"], - srcs_version = "PY3", deps = [ ":custom_model", "//mediapipe/model_maker/python/core/data:dataset", @@ -55,8 +52,6 @@ py_library( py_test( name = "classifier_test", srcs = ["classifier_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":classifier", "//mediapipe/model_maker/python/core/utils:test_util", diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 2538ec8fa..a2ec52044 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -13,6 +13,7 @@ # limitations under the License. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. 
licenses(["notice"]) @@ -24,7 +25,6 @@ py_library( name = "test_util", testonly = 1, srcs = ["test_util.py"], - srcs_version = "PY3", deps = [ ":model_util", "//mediapipe/model_maker/python/core/data:dataset", @@ -34,7 +34,6 @@ py_library( py_library( name = "model_util", srcs = ["model_util.py"], - srcs_version = "PY3", deps = [ ":quantization", "//mediapipe/model_maker/python/core/data:dataset", @@ -44,8 +43,6 @@ py_library( py_test( name = "model_util_test", srcs = ["model_util_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":model_util", ":quantization", @@ -62,8 +59,6 @@ py_library( py_test( name = "loss_functions_test", srcs = ["loss_functions_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [":loss_functions"], ) @@ -77,8 +72,6 @@ py_library( py_test( name = "quantization_test", srcs = ["quantization_test.py"], - python_version = "PY3", - srcs_version = "PY3", deps = [ ":quantization", ":test_util", diff --git a/mediapipe/model_maker/python/vision/core/BUILD b/mediapipe/model_maker/python/vision/core/BUILD index 2658841ae..0b15a0276 100644 --- a/mediapipe/model_maker/python/vision/core/BUILD +++ b/mediapipe/model_maker/python/vision/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict test compatibility macro. licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index 5b4ec2bd1..a2268059f 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python library rule. # Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python library rule. 
licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py index 704d71a5a..265c36a6e 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py +++ b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py @@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams, return model.fit( x=train_ds, epochs=hparams.train_epochs, - steps_per_epoch=hparams.steps_per_epoch, validation_data=validation_ds, callbacks=callbacks) diff --git a/mediapipe/modules/face_geometry/libs/effect_renderer.cc b/mediapipe/modules/face_geometry/libs/effect_renderer.cc index 27a54e011..73f473084 100644 --- a/mediapipe/modules/face_geometry/libs/effect_renderer.cc +++ b/mediapipe/modules/face_geometry/libs/effect_renderer.cc @@ -161,7 +161,7 @@ class Texture { ~Texture() { if (is_owned_) { - glDeleteProgram(handle_); + glDeleteTextures(1, &handle_); } } diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 07ad97810..e830e3c58 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -87,6 +87,7 @@ cc_library( cc_library( name = "builtin_task_graphs", deps = [ + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", ], diff --git a/mediapipe/python/packet_getter.py b/mediapipe/python/packet_getter.py index 4d93e713b..cf6e7574a 100644 --- a/mediapipe/python/packet_getter.py +++ b/mediapipe/python/packet_getter.py @@ -14,7 +14,7 @@ """The public facing packet getter APIs.""" -from typing import List, Type +from typing import List from google.protobuf import message from google.protobuf import symbol_database @@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame get_matrix = _packet_getter.get_matrix -def get_proto(packet: mp_packet.Packet) -> Type[message.Message]: +def get_proto(packet: mp_packet.Packet) -> message.Message: """Get the content of a MediaPipe proto Packet as a proto message. Args: diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index e4905546a..344fafb4e 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -46,8 +46,10 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto index a154e5f4e..2ba760e99 100644 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ b/mediapipe/tasks/cc/components/containers/proto/category.proto @@ -17,7 +17,7 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; -option java_package = "com.google.mediapipe.tasks.components.container.proto"; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "CategoryProto"; // A single classification result. 
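The `train_image_classifier_lib.py` hunk above drops the explicit `steps_per_epoch=hparams.steps_per_epoch` argument from `model.fit`. The sketch below is a minimal, standalone illustration of the Keras behavior this relies on; the toy dataset and model are assumptions for illustration, not Model Maker code. When `steps_per_epoch` is omitted and the input is a finite `tf.data.Dataset`, Keras infers the number of steps per epoch from the dataset's cardinality.

```python
import tensorflow as tf

# Toy stand-ins for the Model Maker dataset and model (assumptions, not the
# real hparams/train_ds used by train_image_classifier_lib.py).
features = tf.random.uniform((32, 4))
labels = tf.random.uniform((32,), maxval=2, dtype=tf.int32)
train_ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# With steps_per_epoch omitted, Keras runs one pass over the finite dataset
# per epoch (here 4 batches) instead of a fixed, hparams-provided step count.
model.fit(x=train_ds, epochs=1, validation_data=train_ds)
```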
diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index 0f5086b95..712607fa6 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -19,7 +19,7 @@ package mediapipe.tasks.components.containers.proto; import "mediapipe/tasks/cc/components/containers/proto/category.proto"; -option java_package = "com.google.mediapipe.tasks.components.container.proto"; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; // List of predicted categories with an optional timestamp. diff --git a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto index d57b08b53..39811e6c0 100644 --- a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto +++ b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto @@ -17,6 +17,9 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; +option java_outer_classname = "EmbeddingsProto"; + // Defines a dense floating-point embedding. message FloatEmbedding { repeated float values = 1 [packed = true]; diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/image_preprocessing.cc index 046a97e4d..7940080e1 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/image_preprocessing.cc @@ -30,9 +30,11 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -128,12 +130,21 @@ absl::Status ConfigureImageToTensorCalculator( options->mutable_output_tensor_float_range()->set_max((255.0f - mean) / std); } + // TODO: need to support different GPU origin on differnt + // platforms or applications. + options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT); return absl::OkStatus(); } } // namespace +bool DetermineImagePreprocessingGpuBackend( + const core::proto::Acceleration& acceleration) { + return acceleration.has_gpu(); +} + absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, + bool use_gpu, ImagePreprocessingOptions* options) { ASSIGN_OR_RETURN(auto image_tensor_specs, BuildImageTensorSpecs(model_resources)); @@ -141,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, image_tensor_specs, options->mutable_image_to_tensor_options())); // The GPU backend isn't able to process int data. If the input tensor is // quantized, forces the image preprocessing graph to use CPU backend. 
- if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) { + if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) { + options->set_backend(ImagePreprocessingOptions::GPU_BACKEND); + } else { options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); } return absl::OkStatus(); diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/image_preprocessing.h index a5b767f3a..6963b6556 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.h +++ b/mediapipe/tasks/cc/components/image_preprocessing.h @@ -19,20 +19,26 @@ limitations under the License. #include "absl/status/status.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" namespace mediapipe { namespace tasks { namespace components { -// Configures an ImagePreprocessing subgraph using the provided model resources. +// Configures an ImagePreprocessing subgraph using the provided model resources +// When use_gpu is true, use GPU as backend to convert image to tensor. // - Accepts CPU input images and outputs CPU tensors. // // Example usage: // // auto& preprocessing = // graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); +// core::proto::Acceleration acceleration; +// acceleration.mutable_xnnpack(); +// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration); // MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( // model_resources, +// use_gpu, // &preprocessing.GetOptions())); // // The resulting ImagePreprocessing subgraph has the following I/O: @@ -56,9 +62,14 @@ namespace components { // The image that has the pixel data stored on the target storage (CPU vs // GPU). absl::Status ConfigureImagePreprocessing( - const core::ModelResources& model_resources, + const core::ModelResources& model_resources, bool use_gpu, ImagePreprocessingOptions* options); +// Determine if the image preprocessing subgraph should use GPU as the backend +// according to the given acceleration setting. +bool DetermineImagePreprocessingGpuBackend( + const core::proto::Acceleration& acceleration); + } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 47334b673..66434483b 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -156,21 +156,24 @@ absl::StatusOr ModelTaskGraph::GetConfig( } absl::StatusOr ModelTaskGraph::CreateModelResources( - SubgraphContext* sc, std::unique_ptr external_file) { + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); if (!model_resources_cache_service.IsAvailable()) { - ASSIGN_OR_RETURN(local_model_resources_, + ASSIGN_OR_RETURN(auto local_model_resource, ModelResources::Create("", std::move(external_file))); LOG(WARNING) << "A local ModelResources object is created. 
Please consider using " "ModelResourcesCacheService to cache the created ModelResources " "object in the CalculatorGraph."; - return local_model_resources_.get(); + local_model_resources_.push_back(std::move(local_model_resource)); + return local_model_resources_.back().get(); } ASSIGN_OR_RETURN( auto op_resolver_packet, model_resources_cache_service.GetObject().GetGraphOpResolverPacket()); - const std::string tag = CreateModelResourcesTag(sc->OriginalNode()); + const std::string tag = + absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix); ASSIGN_OR_RETURN(auto model_resources, ModelResources::Create(tag, std::move(external_file), op_resolver_packet)); @@ -182,7 +185,8 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( absl::StatusOr ModelTaskGraph::CreateModelAssetBundleResources( - SubgraphContext* sc, std::unique_ptr external_file) { + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); bool has_file_pointer_meta = external_file->has_file_pointer_meta(); // if external file is set by file pointer, no need to add the model asset @@ -190,7 +194,7 @@ ModelTaskGraph::CreateModelAssetBundleResources( // not owned by this model asset bundle resources. if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) { ASSIGN_OR_RETURN( - local_model_asset_bundle_resources_, + auto local_model_asset_bundle_resource, ModelAssetBundleResources::Create("", std::move(external_file))); if (!has_file_pointer_meta) { LOG(WARNING) @@ -198,10 +202,12 @@ ModelTaskGraph::CreateModelAssetBundleResources( "ModelResourcesCacheService to cache the created ModelResources " "object in the CalculatorGraph."; } - return local_model_asset_bundle_resources_.get(); + local_model_asset_bundle_resources_.push_back( + std::move(local_model_asset_bundle_resource)); + return local_model_asset_bundle_resources_.back().get(); } - const std::string tag = - CreateModelAssetBundleResourcesTag(sc->OriginalNode()); + const std::string tag = absl::StrCat( + CreateModelAssetBundleResourcesTag(sc->OriginalNode()), tag_suffix); ASSIGN_OR_RETURN( auto model_bundle_resources, ModelAssetBundleResources::Create(tag, std::move(external_file))); diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 5ee70e8f3..50dcc903b 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -75,9 +76,14 @@ class ModelTaskGraph : public Subgraph { // construction stage. Note that the external file contents will be moved // into the model resources object on creation. The returned model resources // pointer will provide graph authors with the access to the metadata - // extractor and the tflite model. + // extractor and the tflite model. When the model resources graph service is + // available, a tag is generated internally asscoiated with the created model + // resource. If more than one model resources are created in a graph, the + // model resources graph service add the tag_suffix to support multiple + // resources. 
absl::StatusOr CreateModelResources( - SubgraphContext* sc, std::unique_ptr external_file); + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix = ""); // If the model resources graph service is available, creates a model asset // bundle resources object from the subgraph context, and caches the created @@ -103,10 +109,15 @@ class ModelTaskGraph : public Subgraph { // that can only be used in the graph construction stage. Note that the // external file contents will be moved into the model asset bundle resources // object on creation. The returned model asset bundle resources pointer will - // provide graph authors with the access to extracted model files. + // provide graph authors with the access to extracted model files. When the + // model resources graph service is available, a tag is generated internally + // asscoiated with the created model asset bundle resource. If more than one + // model asset bundle resources are created in a graph, the model resources + // graph service add the tag_suffix to support multiple resources. absl::StatusOr CreateModelAssetBundleResources( - SubgraphContext* sc, std::unique_ptr external_file); + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix = ""); // Inserts a mediapipe task inference subgraph into the provided // GraphBuilder. The returned node provides the following interfaces to the @@ -124,9 +135,9 @@ class ModelTaskGraph : public Subgraph { api2::builder::Graph& graph) const; private: - std::unique_ptr local_model_resources_; + std::vector> local_model_resources_; - std::unique_ptr + std::vector> local_model_asset_bundle_resources_; }; diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index a85538631..336b1bb45 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -63,6 +63,29 @@ cc_library( ], ) +cc_test( + name = "text_classifier_test", + srcs = ["text_classifier_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_classifier", + ":text_classifier_test_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) + cc_library( name = "text_classifier_test_utils", srcs = ["text_classifier_test_utils.cc"], diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 5b33f6606..62837be8c 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; -using ::testing::proto::Approximately; -using ::testing::proto::IgnoringRepeatedFieldOrdering; -using ::testing::proto::Partially; constexpr float kEpsilon = 0.001; constexpr int kMaxSeqLen = 128; @@ -110,127 +107,6 @@ 
TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { MP_ASSERT_OK(TextClassifier::Create(std::move(options))); } -TEST_F(TextClassifierTest, TextClassifierWithBert) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult negative_result, - classifier->Classify("unflinchingly bleak and desperate")); - ASSERT_THAT(negative_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.956 } - categories { category_name: "positive" score: 0.044 } - } - } - )pb"), - kEpsilon)))); - - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult positive_result, - classifier->Classify("it's a charming and often affecting journey")); - ASSERT_THAT(positive_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.0 } - categories { category_name: "positive" score: 1.0 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - -TEST_F(TextClassifierTest, TextClassifierWithIntInputs) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result, - classifier->Classify("What a waste of my time.")); - ASSERT_THAT(negative_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "Negative" score: 0.813 } - categories { category_name: "Positive" score: 0.187 } - } - } - )pb"), - kEpsilon)))); - - MP_ASSERT_OK_AND_ASSIGN( - ClassificationResult positive_result, - classifier->Classify("This is the best movie I’ve seen in recent years. 
" - "Strongly recommend it!")); - ASSERT_THAT(positive_result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "Negative" score: 0.487 } - categories { category_name: "Positive" score: 0.513 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - -TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath); - options->base_options.op_resolver = CreateCustomResolver(); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, - classifier->Classify("hello")); - ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb( - classifications { - entries { - categories { index: 1 score: 1 } - categories { index: 0 score: 1 } - categories { index: 2 score: 0 } - } - } - )pb")))); -} - -TEST_F(TextClassifierTest, BertLongPositive) { - std::stringstream ss_for_positive_review; - ss_for_positive_review - << "it's a charming and often affecting journey and this is a long"; - for (int i = 0; i < kMaxSeqLen; ++i) { - ss_for_positive_review << " long"; - } - ss_for_positive_review << " movie review"; - auto options = std::make_unique(); - options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); - MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, - TextClassifier::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, - classifier->Classify(ss_for_positive_review.str())); - ASSERT_THAT(result, - Partially(IgnoringRepeatedFieldOrdering(Approximately( - EqualsProto(R"pb( - classifications { - entries { - categories { category_name: "negative" score: 0.014 } - categories { category_name: "positive" score: 0.986 } - } - } - )pb"), - kEpsilon)))); - MP_ASSERT_OK(classifier->Close()); -} - } // namespace } // namespace text_classifier } // namespace text diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 048c7021d..5ce08b2d7 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -73,7 +73,18 @@ cc_library( ], ) -# TODO: This test fails in OSS +cc_test( + name = "sentencepiece_tokenizer_test", + srcs = ["sentencepiece_tokenizer_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:albert_model", + ], + deps = [ + ":sentencepiece_tokenizer", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/core:utils", + ], +) cc_library( name = "tokenizer_utils", @@ -97,7 +108,32 @@ cc_library( ], ) -# TODO: This test fails in OSS +cc_test( + name = "tokenizer_utils_test", + srcs = ["tokenizer_utils_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:albert_model", + "//mediapipe/tasks/testdata/text:mobile_bert_model", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + linkopts = ["-ldl"], + deps = [ + ":bert_tokenizer", + ":regex_tokenizer", + ":sentencepiece_tokenizer", + ":tokenizer_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + 
], +) cc_library( name = "regex_tokenizer", diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index 12d789901..e8e197a1d 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -21,12 +21,23 @@ cc_library( hdrs = ["running_mode.h"], ) +cc_library( + name = "image_processing_options", + hdrs = ["image_processing_options.h"], + deps = [ + "//mediapipe/tasks/cc/components/containers:rect", + ], +) + cc_library( name = "base_vision_task_api", hdrs = ["base_vision_task_api.h"], deps = [ + ":image_processing_options", ":running_mode", "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:rect", "//mediapipe/tasks/cc/core:base_task_api", "//mediapipe/tasks/cc/core:task_runner", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index 4586cbbdd..c3c0a0261 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -16,15 +16,20 @@ limitations under the License. #ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_ #define MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_ +#include #include +#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -87,6 +92,60 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { return runner_->Send(std::move(inputs)); } + // Convert from ImageProcessingOptions to NormalizedRect, performing sanity + // checks on-the-fly. If the input ImageProcessingOptions is not present, + // returns a default NormalizedRect covering the whole image with rotation set + // to 0. If 'roi_allowed' is false, an error will be returned if the input + // ImageProcessingOptions has its 'region_or_interest' field set. + static absl::StatusOr ConvertToNormalizedRect( + std::optional options, bool roi_allowed = true) { + mediapipe::NormalizedRect normalized_rect; + normalized_rect.set_rotation(0); + normalized_rect.set_x_center(0.5); + normalized_rect.set_y_center(0.5); + normalized_rect.set_width(1.0); + normalized_rect.set_height(1.0); + if (!options.has_value()) { + return normalized_rect; + } + + if (options->rotation_degrees % 90 != 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Expected rotation to be a multiple of 90°.", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + // Convert to radians counter-clockwise. 
+ normalized_rect.set_rotation(-options->rotation_degrees * M_PI / 180.0); + + if (options->region_of_interest.has_value()) { + if (!roi_allowed) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "This task doesn't support region-of-interest.", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + auto& roi = *options->region_of_interest; + if (roi.left >= roi.right || roi.top >= roi.bottom) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Expected Rect with left < right and top < bottom.", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Expected Rect values to be in [0,1].", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + normalized_rect.set_x_center((roi.left + roi.right) / 2.0); + normalized_rect.set_y_center((roi.top + roi.bottom) / 2.0); + normalized_rect.set_width(roi.right - roi.left); + normalized_rect.set_height(roi.bottom - roi.top); + } + return normalized_rect; + } + private: RunningMode running_mode_; }; diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h new file mode 100644 index 000000000..7e764c1fe --- /dev/null +++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_ + +#include + +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace core { + +// Options for image processing. +// +// If both region-or-interest and rotation are specified, the crop around the +// region-of-interest is extracted first, the the specified rotation is applied +// to the crop. +struct ImageProcessingOptions { + // The optional region-of-interest to crop from the image. If not specified, + // the full image is used. + // + // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom. + std::optional region_of_interest = std::nullopt; + + // The rotation to apply to the image (or cropped region-of-interest), in + // degrees clockwise. + // + // The rotation must be a multiple (positive or negative) of 90°. 
+ int rotation_degrees = 0; +}; + +} // namespace core +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_ diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index e5b1f0479..6296017d4 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -62,13 +62,19 @@ cc_library( "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", @@ -93,10 +99,14 @@ cc_library( "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", @@ -137,8 +147,10 @@ cc_library( "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", diff --git 
a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index 08f7f45d0..8c2c2e593 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -93,3 +93,46 @@ cc_test( "@com_google_absl//absl/strings", ], ) + +mediapipe_proto_library( + name = "combined_prediction_calculator_proto", + srcs = ["combined_prediction_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "combined_prediction_calculator", + srcs = ["combined_prediction_calculator.cc"], + deps = [ + ":combined_prediction_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) + +cc_test( + name = "combined_prediction_calculator_test", + srcs = ["combined_prediction_calculator_test.cc"], + deps = [ + ":combined_prediction_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc new file mode 100644 index 000000000..c7147ea6e --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.cc @@ -0,0 +1,187 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h" + +namespace mediapipe { +namespace api2 { +namespace { + +constexpr char kPredictionTag[] = "PREDICTION"; + +Classification GetMaxScoringClassification( + const ClassificationList& classifications) { + Classification max_classification; + max_classification.set_score(0); + for (const auto& input : classifications.classification()) { + if (max_classification.score() < input.score()) { + max_classification = input; + } + } + return max_classification; +} + +float GetScoreThreshold( + const std::string& input_label, + const absl::btree_map& classwise_thresholds, + const std::string& background_label, const float default_threshold) { + float threshold = default_threshold; + auto it = classwise_thresholds.find(input_label); + if (it != classwise_thresholds.end()) { + threshold = it->second; + } + return threshold; +} + +std::unique_ptr GetWinningPrediction( + const ClassificationList& classification_list, + const absl::btree_map& classwise_thresholds, + const std::string& background_label, const float default_threshold) { + auto prediction_list = std::make_unique(); + if (classification_list.classification().empty()) { + return prediction_list; + } + Classification& prediction = *prediction_list->add_classification(); + auto argmax_prediction = GetMaxScoringClassification(classification_list); + float argmax_prediction_thresh = + GetScoreThreshold(argmax_prediction.label(), classwise_thresholds, + background_label, default_threshold); + if (argmax_prediction.score() >= argmax_prediction_thresh) { + prediction.set_label(argmax_prediction.label()); + prediction.set_score(argmax_prediction.score()); + } else { + for (const auto& input : classification_list.classification()) { + if (input.label() == background_label) { + prediction.set_label(input.label()); + prediction.set_score(input.score()); + break; + } + } + } + return prediction_list; +} + +} // namespace + +// This calculator accepts multiple ClassificationList input streams. Each +// ClassificationList should contain classifications with labels and +// corresponding softmax scores. The calculator computes the best prediction for +// each ClassificationList input stream via argmax and thresholding. Thresholds +// for all classes can be specified in the +// `CombinedPredictionCalculatorOptions`, along with a default global +// threshold. +// Please note that for this calculator to work as designed, the class names +// other than the background class in the ClassificationList objects must be +// different, but the background class name has to be the same. This background +// label name can be set via `background_label` in +// `CombinedPredictionCalculatorOptions`. +// The ClassificationList in the PREDICTION output stream contains the label of +// the winning class and corresponding softmax score. 
If none of the +// ClassificationList objects has a non-background winning class, the output +// contains the background class and score of the background class in the first +// ClassificationList. If multiple ClassificationList objects have a +// non-background winning class, the output contains the winning prediction from +// the ClassificationList with the highest priority. Priority is in decreasing +// order of input streams to the graph node using this calculator. +// Input: +// At least one stream with ClassificationList. +// Output: +// PREDICTION - A ClassificationList with the winning label as the only item. +// +// Usage example: +// node { +// calculator: "CombinedPredictionCalculator" +// input_stream: "classification_list_0" +// input_stream: "classification_list_1" +// output_stream: "PREDICTION:prediction" +// options { +// [mediapipe.CombinedPredictionCalculatorOptions.ext] { +// class { +// label: "A" +// score_threshold: 0.7 +// } +// default_global_threshold: 0.1 +// background_label: "B" +// } +// } +// } + +class CombinedPredictionCalculator : public Node { + public: + static constexpr Input::Multiple kClassificationListIn{ + ""}; + static constexpr Output kPredictionOut{"PREDICTION"}; + MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut); + + absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + for (const auto& input : options_.class_()) { + classwise_thresholds_[input.label()] = input.score_threshold(); + } + classwise_thresholds_[options_.background_label()] = 0; + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + // After loop, if have winning prediction return. Otherwise empty packet. + std::unique_ptr first_winning_prediction = nullptr; + auto collection = kClassificationListIn(cc); + for (int idx = 0; idx < collection.Count(); ++idx) { + const auto& packet = collection[idx]; + if (packet.IsEmpty()) { + continue; + } + auto prediction = GetWinningPrediction( + packet.Get(), classwise_thresholds_, options_.background_label(), + options_.default_global_threshold()); + if (prediction->classification(0).label() != + options_.background_label()) { + kPredictionOut(cc).Send(std::move(prediction)); + return absl::OkStatus(); + } + if (first_winning_prediction == nullptr) { + first_winning_prediction = std::move(prediction); + } + } + if (first_winning_prediction != nullptr) { + kPredictionOut(cc).Send(std::move(first_winning_prediction)); + } + return absl::OkStatus(); + } + + private: + CombinedPredictionCalculatorOptions options_; + absl::btree_map classwise_thresholds_; +}; + +MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto new file mode 100644 index 000000000..730e7dd78 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.proto @@ -0,0 +1,41 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message CombinedPredictionCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional CombinedPredictionCalculatorOptions ext = 483738635; + } + + message Class { + optional string label = 1; + optional float score_threshold = 2; + } + + // List of classes with score thresholds. + repeated Class class = 1; + + // Default score threshold applied to a label. + optional float default_global_threshold = 2 [default = 0]; + + // Name of the background class whose input scores will be ignored while + // thresholding. + optional string background_label = 3; +} diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc new file mode 100644 index 000000000..ecf49795b --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator_test.cc @@ -0,0 +1,315 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +constexpr char kPredictionTag[] = "PREDICTION"; + +std::unique_ptr BuildNodeRunnerWithOptions( + float drama_thresh, float llama_thresh, float bazinga_thresh, + float joy_thresh, float peace_thresh) { + constexpr absl::string_view kCalculatorProto = R"pb( + calculator: "CombinedPredictionCalculator" + input_stream: "custom_softmax_scores" + input_stream: "canned_softmax_scores" + output_stream: "PREDICTION:prediction" + options { + [mediapipe.CombinedPredictionCalculatorOptions.ext] { + class { label: "CustomDrama" score_threshold: $0 } + class { label: "CustomLlama" score_threshold: $1 } + class { label: "CannedBazinga" score_threshold: $2 } + class { label: "CannedJoy" score_threshold: $3 } + class { label: "CannedPeace" score_threshold: $4 } + background_label: "Negative" + } + } + )pb"; + auto runner = std::make_unique( + absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh, + bazinga_thresh, joy_thresh, peace_thresh)); + return runner; +} + +std::unique_ptr BuildCustomScoreInput( + const float negative_score, const float drama_score, + const float llama_score) { + auto custom_scores = std::make_unique(); + auto custom_negative = custom_scores->add_classification(); + custom_negative->set_label("Negative"); + custom_negative->set_score(negative_score); + auto drama = custom_scores->add_classification(); + drama->set_label("CustomDrama"); + drama->set_score(drama_score); + auto llama = custom_scores->add_classification(); + llama->set_label("CustomLlama"); + llama->set_score(llama_score); + return custom_scores; +} + +std::unique_ptr BuildCannedScoreInput( + const float negative_score, const float bazinga_score, + const float joy_score, const float peace_score) { + auto canned_scores = std::make_unique(); + auto canned_negative = canned_scores->add_classification(); + canned_negative->set_label("Negative"); + canned_negative->set_score(negative_score); + auto bazinga = canned_scores->add_classification(); + bazinga->set_label("CannedBazinga"); + bazinga->set_score(bazinga_score); + auto joy = canned_scores->add_classification(); + joy->set_label("CannedJoy"); + joy->set_score(joy_score); + auto peace = canned_scores->add_classification(); + peace->set_label("CannedPeace"); + peace->set_score(peace_score); + return canned_scores; +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomEmpty_CannedEmpty_ResultIsEmpty) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0, + /*joy_thresh=*/0.0, /*peace_thresh=*/0.0); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty()); +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomEmpty_CannedNotEmpty_ResultIsCanned) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9, + /*joy_thresh=*/0.5, /*peace_thresh=*/0.8); + auto canned_scores = BuildCannedScoreInput( + /*negative_score=*/0.1, + /*bazinga_score=*/0.1, /*joy_score=*/0.6, 
/*peace_score=*/0.2); + runner->MutableInputs()->Index(1).packets.push_back( + Adopt(canned_scores.release()).At(Timestamp(1))); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), "CannedJoy"); + EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4); +} + +TEST(CombinedPredictionCalculatorPacketTest, + CustomNotEmpty_CannedEmpty_ResultIsCustom) { + auto runner = BuildNodeRunnerWithOptions( + /*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0, + /*joy_thresh=*/0.0, /*peace_thresh=*/0.0); + auto custom_scores = + BuildCustomScoreInput(/*negative_score=*/0.1, + /*drama_score=*/0.2, /*llama_score=*/0.7); + runner->MutableInputs()->Index(0).packets.push_back( + Adopt(custom_scores.release()).At(Timestamp(1))); + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), "CustomLlama"); + EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4); +} + +struct CombinedPredictionCalculatorTestCase { + std::string test_name; + float custom_negative_score; + float drama_score; + float llama_score; + float drama_thresh; + float llama_thresh; + float canned_negative_score; + float bazinga_score; + float joy_score; + float peace_score; + float bazinga_thresh; + float joy_thresh; + float peace_thresh; + std::string max_scoring_label; + float max_score; +}; + +using CombinedPredictionCalculatorTest = + testing::TestWithParam; + +TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) { + const CombinedPredictionCalculatorTestCase& test_case = GetParam(); + + auto runner = BuildNodeRunnerWithOptions( + test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh, + test_case.joy_thresh, test_case.peace_thresh); + + auto custom_scores = + BuildCustomScoreInput(test_case.custom_negative_score, + test_case.drama_score, test_case.llama_score); + + runner->MutableInputs()->Index(0).packets.push_back( + Adopt(custom_scores.release()).At(Timestamp(1))); + + auto canned_scores = BuildCannedScoreInput( + test_case.canned_negative_score, test_case.bazinga_score, + test_case.joy_score, test_case.peace_score); + runner->MutableInputs()->Index(1).packets.push_back( + Adopt(canned_scores.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner->Run()) << "Calculator execution failed."; + + auto output_prediction_packets = + runner->Outputs().Tag(kPredictionTag).packets; + ASSERT_EQ(output_prediction_packets.size(), 1); + Classification output_prediction = + output_prediction_packets[0].Get().classification(0); + + EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label); + EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4); +} + +INSTANTIATE_TEST_CASE_P( + CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest, + testing::ValuesIn({ + { + .test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh", + .custom_negative_score = 0.1, + .drama_score = 0.5, + .llama_score = 0.3, + .drama_thresh = 0.25, + .llama_thresh = 0.7, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, 
+ .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "CustomDrama", + .max_score = 0.5, + }, + { + .test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.1, + .bazinga_score = 0.4, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.0, + .joy_thresh = 0.0, + .peace_thresh = 0.0, + .max_scoring_label = "CannedBazinga", + .max_score = 0.4, + }, + { + .test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh", + .custom_negative_score = 0.5, + .drama_score = 0.1, + .llama_score = 0.4, + .drama_thresh = 0.1, + .llama_thresh = 0.05, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.5, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh", + .custom_negative_score = 0.8, + .drama_score = 0.1, + .llama_score = 0.1, + .drama_thresh = 0.25, + .llama_thresh = 0.7, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.8, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2", + .custom_negative_score = 0.1, + .drama_score = 0.2, + .llama_score = 0.7, + .drama_thresh = 1.1, + .llama_thresh = 1.1, + .canned_negative_score = 0.1, + .bazinga_score = 0.3, + .joy_score = 0.3, + .peace_score = 0.3, + .bazinga_thresh = 0.7, + .joy_thresh = 0.7, + .peace_thresh = 0.7, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + { + .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3", + .custom_negative_score = 0.1, + .drama_score = 0.3, + .llama_score = 0.6, + .drama_thresh = 0.4, + .llama_thresh = 0.8, + .canned_negative_score = 0.3, + .bazinga_score = 0.2, + .joy_score = 0.3, + .peace_score = 0.2, + .bazinga_thresh = 0.5, + .joy_thresh = 0.5, + .peace_thresh = 0.5, + .max_scoring_label = "Negative", + .max_score = 0.1, + }, + }), + [](const testing::TestParamInfo< + CombinedPredictionCalculatorTest::ParamType>& info) { + return info.param.test_name; + }); + +} // namespace + +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 333edb6fb..d4ab16ac8 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -39,7 +39,9 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" @@ -76,31 +78,6 @@ constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks"; constexpr int kMicroSecondsPerMilliSecond = 1000; -// Returns a NormalizedRect filling the whole image. If input is present, its -// rotation is set in the returned NormalizedRect and a check is performed to -// make sure no region-of-interest was provided. Otherwise, rotation is set to -// 0. -absl::StatusOr FillNormalizedRect( - std::optional normalized_rect) { - NormalizedRect result; - if (normalized_rect.has_value()) { - result = *normalized_rect; - } - bool has_coordinates = result.has_x_center() || result.has_y_center() || - result.has_width() || result.has_height(); - if (has_coordinates) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "GestureRecognizer does not support region-of-interest.", - MediaPipeTasksStatus::kInvalidArgumentError); - } - result.set_x_center(0.5); - result.set_y_center(0.5); - result.set_width(1); - result.set_height(1); - return result; -} - // Creates a MediaPipe graph config that contains a subgraph node of // "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running // in the live stream mode, a "FlowLimiterCalculator" will be added to limit the @@ -136,57 +113,38 @@ CalculatorGraphConfig CreateGraphConfig( std::unique_ptr ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); - bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE; - - // TODO remove these workarounds for base options of subgraphs. // Configure hand detector options. - auto base_options_proto_for_hand_detector = - std::make_unique( - tasks::core::ConvertBaseOptionsToProto( - &(options->base_options_for_hand_detector))); - base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode); auto* hand_detector_graph_options = options_proto->mutable_hand_landmarker_graph_options() ->mutable_hand_detector_graph_options(); - hand_detector_graph_options->mutable_base_options()->Swap( - base_options_proto_for_hand_detector.get()); hand_detector_graph_options->set_num_hands(options->num_hands); hand_detector_graph_options->set_min_detection_confidence( options->min_hand_detection_confidence); // Configure hand landmark detector options. 
- auto base_options_proto_for_hand_landmarker = - std::make_unique( - tasks::core::ConvertBaseOptionsToProto( - &(options->base_options_for_hand_landmarker))); - base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode); - auto* hand_landmarks_detector_graph_options = - options_proto->mutable_hand_landmarker_graph_options() - ->mutable_hand_landmarks_detector_graph_options(); - hand_landmarks_detector_graph_options->mutable_base_options()->Swap( - base_options_proto_for_hand_landmarker.get()); - hand_landmarks_detector_graph_options->set_min_detection_confidence( - options->min_hand_presence_confidence); - auto* hand_landmarker_graph_options = options_proto->mutable_hand_landmarker_graph_options(); hand_landmarker_graph_options->set_min_tracking_confidence( options->min_tracking_confidence); + auto* hand_landmarks_detector_graph_options = + hand_landmarker_graph_options + ->mutable_hand_landmarks_detector_graph_options(); + hand_landmarks_detector_graph_options->set_min_detection_confidence( + options->min_hand_presence_confidence); // Configure hand gesture recognizer options. - auto base_options_proto_for_gesture_recognizer = - std::make_unique( - tasks::core::ConvertBaseOptionsToProto( - &(options->base_options_for_gesture_recognizer))); - base_options_proto_for_gesture_recognizer->set_use_stream_mode( - use_stream_mode); auto* hand_gesture_recognizer_graph_options = options_proto->mutable_hand_gesture_recognizer_graph_options(); - hand_gesture_recognizer_graph_options->mutable_base_options()->Swap( - base_options_proto_for_gesture_recognizer.get()); if (options->min_gesture_confidence >= 0) { - hand_gesture_recognizer_graph_options->mutable_classifier_options() + hand_gesture_recognizer_graph_options + ->mutable_canned_gesture_classifier_graph_options() + ->mutable_classifier_options() ->set_score_threshold(options->min_gesture_confidence); } return options_proto; @@ -248,15 +206,16 @@ absl::StatusOr> GestureRecognizer::Create( absl::StatusOr GestureRecognizer::Recognize( mediapipe::Image image, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -283,15 +242,16 @@ absl::StatusOr GestureRecognizer::Recognize( absl::StatusOr GestureRecognizer::RecognizeForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -321,15 +281,16 @@ absl::StatusOr GestureRecognizer::RecognizeForVideo( absl::Status GestureRecognizer::RecognizeAsync( mediapipe::Image image, int64 timestamp_ms, - std::optional 
image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h index 750a99797..3e281b26e 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -23,10 +23,10 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -39,12 +39,6 @@ struct GestureRecognizerOptions { // model file with metadata, accelerator options, op resolver, etc. tasks::core::BaseOptions base_options; - // TODO: remove these. Temporary solutions before bundle asset is - // ready. - tasks::core::BaseOptions base_options_for_hand_landmarker; - tasks::core::BaseOptions base_options_for_hand_detector; - tasks::core::BaseOptions base_options_for_gesture_recognizer; - // The running mode of the task. Default to the image mode. // GestureRecognizer has three running modes: // 1) The image mode for recognizing hand gestures on single image inputs. @@ -129,36 +123,36 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // Only use this method when the GestureRecognizer is created with the image // running mode. // - // image - mediapipe::Image - // Image to perform hand gesture recognition on. - // imageProcessingOptions - std::optional - // If provided, can be used to specify the rotation to apply to the image - // before performing classification, by setting its 'rotation' field in - // radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note that - // specifying a region-of-interest using the 'x_center', 'y_center', 'width' - // and 'height' fields is NOT supported and will result in an invalid - // argument error being returned. + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. // // The image can be of any size with format RGB or RGBA. // TODO: Describes how the input image will be preprocessed // after the yuv support is implemented. - // TODO: use an ImageProcessingOptions struct instead of - // NormalizedRect. 
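For illustration, a minimal caller-side sketch of the API after this change, assuming the declarations above are in scope; the bundle path and the input image are placeholders rather than values taken from this patch. The caller now sets a single BaseOptions for the whole task, and may only pass a rotation, since the region-of-interest field is rejected (roi_allowed is false in the conversions above):

  auto options = std::make_unique<GestureRecognizerOptions>();
  // One model asset bundle for the whole task; sub-task models are resolved
  // from the bundle by the graphs, not configured per sub-task by the caller.
  options->base_options.model_asset_path = "/path/to/gesture_recognizer.task";  // placeholder
  options->num_hands = 2;
  auto recognizer = GestureRecognizer::Create(std::move(options)).value();

  // Rotation-only image processing options; populating 'region_of_interest'
  // would return an invalid argument error.
  ImageProcessingOptions processing_options;
  processing_options.rotation_degrees = -90;  // 90° anti-clockwise
  auto result = recognizer->Recognize(image, processing_options);  // 'image' is a mediapipe::Image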
absl::StatusOr Recognize( Image image, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Performs gesture recognition on the provided video frame. // Only use this method when the GestureRecognizer is created with the video // running mode. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. absl::StatusOr RecognizeForVideo(Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Sends live image data to perform gesture recognition, and the results will @@ -171,6 +165,12 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // sent to the gesture recognizer. The input timestamps must be monotonically // increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // The "result_callback" provides // - A vector of GestureRecognitionResult, each is the recognized results // for a input frame. @@ -180,7 +180,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. absl::Status RecognizeAsync(Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Shuts down the GestureRecognizer when all works are done. diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index e02eadde8..7ab4847dd 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -25,9 +25,13 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" @@ -46,6 +50,8 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; using ::mediapipe::tasks::vision::gesture_recognizer::proto:: GestureRecognizerGraphOptions; using ::mediapipe::tasks::vision::gesture_recognizer::proto:: @@ -61,6 +67,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; +constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task"; +constexpr char kHandGestureRecognizerBundleAssetName[] = + "hand_gesture_recognizer.task"; struct GestureRecognizerOutputs { Source> gesture; @@ -70,6 +79,53 @@ struct GestureRecognizerOutputs { Source image; }; +// Sets the base options in the sub tasks. +absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + GestureRecognizerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto hand_landmarker_file, + resources.GetModelFile(kHandLandmarkerBundleAssetName)); + auto* hand_landmarker_graph_options = + options->mutable_hand_landmarker_graph_options(); + SetExternalFile(hand_landmarker_file, + hand_landmarker_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_landmarker_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + + ASSIGN_OR_RETURN( + const auto hand_gesture_recognizer_file, + resources.GetModelFile(kHandGestureRecognizerBundleAssetName)); + auto* hand_gesture_recognizer_graph_options = + options->mutable_hand_gesture_recognizer_graph_options(); + SetExternalFile(hand_gesture_recognizer_file, + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + if (!hand_gesture_recognizer_graph_options->base_options() + .acceleration() + .has_xnnpack() && + !hand_gesture_recognizer_graph_options->base_options() + .acceleration() + .has_tflite()) { + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_acceleration() + ->mutable_xnnpack(); + LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. 
Sets " + << "HandGestureRecognizerGraph acceleartion to Xnnpack."; + } + hand_gesture_recognizer_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + return absl::OkStatus(); +} + } // namespace // A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs @@ -136,6 +192,21 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources(sc)); + // When the model resources cache service is available, filling in + // the file pointer meta in the subtasks' base options. Otherwise, + // providing the file contents instead. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, BuildGestureRecognizerGraph( *sc->MutableOptions(), diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 4bbe94974..7b7746956 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -30,11 +30,17 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" @@ -51,6 +57,8 @@ using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::processors:: ConfigureTensorsToClassificationCalculator; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; using ::mediapipe::tasks::vision::gesture_recognizer::proto:: HandGestureRecognizerGraphOptions; @@ -70,6 +78,14 @@ constexpr char kVectorTag[] = "VECTOR"; constexpr char kIndexTag[] = "INDEX"; constexpr char kIterableTag[] = "ITERABLE"; constexpr char kBatchEndTag[] = "BATCH_END"; +constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite"; +constexpr char kCannedGestureClassifierTFLiteName[] = + "canned_gesture_classifier.tflite"; + +struct SubTaskModelResources { + const core::ModelResources* 
gesture_embedder_model_resource; + const core::ModelResources* canned_gesture_classifier_model_resource; +}; Source> ConvertMatrixToTensor(Source matrix, Graph& graph) { @@ -78,6 +94,41 @@ Source> ConvertMatrixToTensor(Source matrix, return node[Output>{"TENSORS"}]; } +// Sets the base options in the sub tasks. +absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + HandGestureRecognizerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto gesture_embedder_file, + resources.GetModelFile(kGestureEmbedderTFLiteName)); + auto* gesture_embedder_graph_options = + options->mutable_gesture_embedder_graph_options(); + SetExternalFile(gesture_embedder_file, + gesture_embedder_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + gesture_embedder_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + + ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file, + resources.GetModelFile(kCannedGestureClassifierTFLiteName)); + auto* canned_gesture_classifier_graph_options = + options->mutable_canned_gesture_classifier_graph_options(); + SetExternalFile( + canned_gesture_classifier_file, + canned_gesture_classifier_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + canned_gesture_classifier_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + canned_gesture_classifier_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + return absl::OkStatus(); +} + } // namespace // A @@ -128,27 +179,70 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN( - const auto* model_resources, - CreateModelResources(sc)); + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources( + sc)); + // When the model resources cache service is available, filling in + // the file pointer meta in the subtasks' base options. Otherwise, + // providing the file contents instead. 
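Reading the asset names used here together with those in the gesture recognizer graph above and the hand landmarker graph further below, the expected nesting of model asset bundles is roughly the following; the top-level bundle name is an assumption, and the two hand landmarker model file names are only referenced through constants in this patch:

  gesture_recognizer.task               (top-level bundle, name assumed)
    hand_landmarker.task                 (kHandLandmarkerBundleAssetName)
      hand detector model                (kHandDetectorTFLiteName)
      hand landmarks detector model      (kHandLandmarksDetectorTFLiteName)
    hand_gesture_recognizer.task         (kHandGestureRecognizerBundleAssetName)
      gesture_embedder.tflite            (kGestureEmbedderTFLiteName)
      canned_gesture_classifier.tflite   (kCannedGestureClassifierTFLiteName)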
+ MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } + ASSIGN_OR_RETURN(const auto sub_task_model_resources, + CreateSubTaskModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN( - auto hand_gestures, - BuildGestureRecognizerGraph( - sc->Options(), *model_resources, - graph[Input(kHandednessTag)], - graph[Input(kLandmarksTag)], - graph[Input(kWorldLandmarksTag)], - graph[Input>(kImageSizeTag)], - graph[Input(kNormRectTag)], graph)); + ASSIGN_OR_RETURN(auto hand_gestures, + BuildGestureRecognizerGraph( + sc->Options(), + sub_task_model_resources, + graph[Input(kHandednessTag)], + graph[Input(kLandmarksTag)], + graph[Input(kWorldLandmarksTag)], + graph[Input>(kImageSizeTag)], + graph[Input(kNormRectTag)], graph)); hand_gestures >> graph[Output(kHandGesturesTag)]; return graph.GetConfig(); } private: + absl::StatusOr CreateSubTaskModelResources( + SubgraphContext* sc) { + auto* options = sc->MutableOptions(); + SubTaskModelResources sub_task_model_resources; + auto& gesture_embedder_model_asset = + *options->mutable_gesture_embedder_graph_options() + ->mutable_base_options() + ->mutable_model_asset(); + ASSIGN_OR_RETURN( + sub_task_model_resources.gesture_embedder_model_resource, + CreateModelResources(sc, + std::make_unique( + std::move(gesture_embedder_model_asset)), + "_gesture_embedder")); + auto& canned_gesture_classifier_model_asset = + *options->mutable_canned_gesture_classifier_graph_options() + ->mutable_base_options() + ->mutable_model_asset(); + ASSIGN_OR_RETURN( + sub_task_model_resources.canned_gesture_classifier_model_resource, + CreateModelResources( + sc, + std::make_unique( + std::move(canned_gesture_classifier_model_asset)), + "_canned_gesture_classifier")); + return sub_task_model_resources; + } + absl::StatusOr> BuildGestureRecognizerGraph( const HandGestureRecognizerGraphOptions& graph_options, - const core::ModelResources& model_resources, + const SubTaskModelResources& sub_task_model_resources, Source handedness, Source hand_landmarks, Source hand_world_landmarks, @@ -209,17 +303,33 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { auto concatenated_tensors = concatenate_tensor_vector.Out(""); // Inference for static hand gesture recognition. - // TODO add embedding step. 
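The remainder of this hunk replaces the single inference node with a two-stage pipeline; as a rough sketch of the resulting data flow, with stage names taken from the code that follows:

  concatenated landmark feature tensors
    -> gesture_embedder.tflite inference           -> embedding tensors
    -> canned_gesture_classifier.tflite inference  -> score tensors
    -> TensorsToClassificationCalculator           -> HAND_GESTURES classifications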
- auto& inference = AddInference( - model_resources, graph_options.base_options().acceleration(), graph); - concatenated_tensors >> inference.In(kTensorsTag); - auto inference_output_tensors = inference.Out(kTensorsTag); + auto& gesture_embedder_inference = + AddInference(*sub_task_model_resources.gesture_embedder_model_resource, + graph_options.gesture_embedder_graph_options() + .base_options() + .acceleration(), + graph); + concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag); + auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag); + + auto& canned_gesture_classifier_inference = AddInference( + *sub_task_model_resources.canned_gesture_classifier_model_resource, + graph_options.canned_gesture_classifier_graph_options() + .base_options() + .acceleration(), + graph); + embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag); + auto inference_output_tensors = + canned_gesture_classifier_inference.Out(kTensorsTag); auto& tensors_to_classification = graph.AddNode("TensorsToClassificationCalculator"); MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( - graph_options.classifier_options(), - *model_resources.GetMetadataExtractor(), 0, + graph_options.canned_gesture_classifier_graph_options() + .classifier_options(), + *sub_task_model_resources.canned_gesture_classifier_model_resource + ->GetMetadataExtractor(), + 0, &tensors_to_classification.GetOptions< mediapipe::TensorsToClassificationCalculatorOptions>())); inference_output_tensors >> tensors_to_classification.In(kTensorsTag); diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD index 3b73bf2b0..0db47da7a 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD @@ -49,7 +49,6 @@ mediapipe_proto_library( ":gesture_embedder_graph_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index a3281702a..7df2fed37 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -18,7 +18,6 @@ syntax = "proto2"; package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; @@ -37,15 +36,11 @@ message HandGestureRecognizerGraphOptions { // Options for GestureEmbedder. optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2; - // Options for GestureClassifier of default gestures. + // Options for GestureClassifier of canned gestures. optional GestureClassifierGraphOptions canned_gesture_classifier_graph_options = 3; // Options for GestureClassifier of custom gestures. 
optional GestureClassifierGraphOptions custom_gesture_classifier_graph_options = 4; - - // TODO: remove these. Temporary solutions before bundle asset is - // ready. - optional components.processors.proto.ClassifierOptions classifier_options = 5; } diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index e876d7d09..06bb2e549 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph { image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 3fbe38c1c..e610a412e 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -92,18 +92,30 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, bool is_copy) { ASSIGN_OR_RETURN(const auto hand_detector_file, resources.GetModelFile(kHandDetectorTFLiteName)); + auto* hand_detector_graph_options = + options->mutable_hand_detector_graph_options(); SetExternalFile(hand_detector_file, - options->mutable_hand_detector_graph_options() - ->mutable_base_options() + hand_detector_graph_options->mutable_base_options() ->mutable_model_asset(), is_copy); + hand_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_detector_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); + auto* hand_landmarks_detector_graph_options = + options->mutable_hand_landmarks_detector_graph_options(); SetExternalFile(hand_landmarks_detector_file, - options->mutable_hand_landmarks_detector_graph_options() - ->mutable_base_options() + hand_landmarks_detector_graph_options->mutable_base_options() ->mutable_model_asset(), is_copy); + hand_landmarks_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_landmarks_detector_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); return absl::OkStatus(); } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index 08beb1a1b..f275486f5 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -67,7 +67,7 @@ using ::testing::proto::Approximately; using ::testing::proto::Partially; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; -constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task"; +constexpr char kHandLandmarkerModelBundle[] = 
"hand_landmarker.task"; constexpr char kLeftHandsImage[] = "left_hands.jpg"; constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg"; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 23521790d..1f127deb8 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index dfa77cb96..3d655cd50 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -59,6 +59,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index f3dcdd07d..8a32758f4 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -59,26 +60,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; -// Returns a NormalizedRect covering the full image if input is not present. -// Otherwise, makes sure the x_center, y_center, width and height are set in -// case only a rotation was provided in the input. -NormalizedRect FillNormalizedRect( - std::optional normalized_rect) { - NormalizedRect result; - if (normalized_rect.has_value()) { - result = *normalized_rect; - } - bool has_coordinates = result.has_x_center() || result.has_y_center() || - result.has_width() || result.has_height(); - if (!has_coordinates) { - result.set_x_center(0.5); - result.set_y_center(0.5); - result.set_width(1); - result.set_height(1); - } - return result; -} - // Creates a MediaPipe graph config that contains a subgraph node of // type "ImageClassifierGraph". 
If the task is running in the live stream mode, // a "FlowLimiterCalculator" will be added to limit the number of frames in @@ -164,14 +145,16 @@ absl::StatusOr> ImageClassifier::Create( } absl::StatusOr ImageClassifier::Classify( - Image image, std::optional image_processing_options) { + Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -183,14 +166,15 @@ absl::StatusOr ImageClassifier::Classify( absl::StatusOr ImageClassifier::ClassifyForVideo( Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -206,14 +190,15 @@ absl::StatusOr ImageClassifier::ClassifyForVideo( absl::Status ImageClassifier::ClassifyAsync( Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 5dff06cc7..de69b7994 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -22,11 +22,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -109,12 +109,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // // The optional 'image_processing_options' parameter can be used to specify: // - the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). + // setting its 'rotation_degrees' field. 
// and/or // - the region-of-interest on which to perform classification, by setting its - // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is - // set, they will automatically be set to cover the full image. + // 'region_of_interest' field. If not specified, the full image is used. // If both are specified, the crop around the region-of-interest is extracted // first, then the specified rotation is applied to the crop. // @@ -126,19 +124,17 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // YUVToImageCalculator is integrated. absl::StatusOr Classify( mediapipe::Image image, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Performs image classification on the provided video frame. // // The optional 'image_processing_options' parameter can be used to specify: // - the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). + // setting its 'rotation_degrees' field. // and/or // - the region-of-interest on which to perform classification, by setting its - // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is - // set, they will automatically be set to cover the full image. + // 'region_of_interest' field. If not specified, the full image is used. // If both are specified, the crop around the region-of-interest is extracted // first, then the specified rotation is applied to the crop. // @@ -150,7 +146,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. absl::StatusOr ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Sends live image data to image classification, and the results will be @@ -158,12 +154,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // // The optional 'image_processing_options' parameter can be used to specify: // - the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). + // setting its 'rotation_degrees' field. // and/or // - the region-of-interest on which to perform classification, by setting its - // 'x_center', 'y_center', 'width' and 'height' fields. If none of these is - // set, they will automatically be set to cover the full image. + // 'region_of_interest' field. If not specified, the full image is used. // If both are specified, the crop around the region-of-interest is extracted // first, then the specified rotation is applied to the crop. // @@ -175,7 +169,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // sent to the object detector. The input timestamps must be monotonically // increasing. // - // The "result_callback" prvoides + // The "result_callback" provides: // - The classification results as a ClassificationResult object. // - The const reference to the corresponding input image that the image // classifier runs on. Note that the const reference to the image will no @@ -183,12 +177,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. 
absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); - // TODO: add Classify() variants taking a region of interest as - // additional argument. - // Shuts down the ImageClassifier when all works are done. absl::Status Close() { return runner_->Close(); } }; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 9a0078c5c..8a1b17ce9 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 55830e520..0c45122c0 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -35,6 +34,8 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -49,9 +50,11 @@ namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::Classifications; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -547,12 +550,9 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // Crop around the soccer ball. - NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.532); - image_processing_options.set_y_center(0.521); - image_processing_options.set_width(0.164); - image_processing_options.set_height(0.427); + // Region-of-interest around the soccer ball. 
+ Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); @@ -572,8 +572,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { ImageClassifier::Create(std::move(options))); // Specify a 90° anti-clockwise rotation. - NormalizedRect image_processing_options; - image_processing_options.set_rotation(M_PI / 2.0); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); @@ -616,13 +616,10 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // Crop around the chair, with 90° anti-clockwise rotation. - NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.2821); - image_processing_options.set_y_center(0.2406); - image_processing_options.set_width(0.5642); - image_processing_options.set_height(0.1286); - image_processing_options.set_rotation(M_PI / 2.0); + // Region-of-interest around the chair, with 90° anti-clockwise rotation. + Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/-90}; MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); @@ -633,7 +630,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { entries { categories { index: 560 - score: 0.6800408 + score: 0.6522213 category_name: "folding chair" } timestamp_ms: 0 @@ -643,6 +640,69 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { })pb")); } +// Testing all these once with ImageClassifier. +TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + // Invalid: left > right. + Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/0}; + auto results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect with left < right and top < bottom")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: top > bottom. 
+ roi = {/*left=*/0, /*top=*/0.9, /*right=*/1, /*bottom=*/0.1}; + image_processing_options = {roi, + /*rotation_degrees=*/0}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect with left < right and top < bottom")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: coordinates out of [0,1] range. + roi = {/*left=*/-0.1, /*top=*/0, /*right=*/1, /*bottom=*/1}; + image_processing_options = {roi, + /*rotation_degrees=*/0}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect values to be in [0,1]")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: rotation not a multiple of 90°. + image_processing_options = {/*region_of_interest=*/std::nullopt, + /*rotation_degrees=*/1}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected rotation to be a multiple of 90°")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { @@ -732,11 +792,9 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. - NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.532); - image_processing_options.set_y_center(0.521); - image_processing_options.set_width(0.164); - image_processing_options.set_height(0.427); + // Region-of-interest around the soccer ball. + Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN( @@ -877,11 +935,8 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); // Crop around the soccer ball. 
- NormalizedRect image_processing_options; - image_processing_options.set_x_center(0.532); - image_processing_options.set_y_center(0.521); - image_processing_options.set_width(0.164); - image_processing_options.set_height(0.427); + Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK( diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index e619b8d1b..0f63f87e4 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -58,6 +58,7 @@ cc_library( "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index 24fd2862c..1dc316305 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::vision::image_embedder::proto:: ImageEmbedderGraphOptions; -// Builds a NormalizedRect covering the entire image. -NormalizedRect BuildFullImageNormRect() { - NormalizedRect norm_rect; - norm_rect.set_x_center(0.5); - norm_rect.set_y_center(0.5); - norm_rect.set_width(1); - norm_rect.set_height(1); - return norm_rect; -} - // Creates a MediaPipe graph config that contains a single node of type // "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is // running in the live stream mode, a "FlowLimiterCalculator" will be added to @@ -148,15 +139,16 @@ absl::StatusOr> ImageEmbedder::Create( } absl::StatusOr ImageEmbedder::Embed( - Image image, std::optional roi) { + Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? 
roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -167,15 +159,16 @@ absl::StatusOr ImageEmbedder::Embed( } absl::StatusOr ImageEmbedder::EmbedForVideo( - Image image, int64 timestamp_ms, std::optional roi) { + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -188,16 +181,17 @@ absl::StatusOr ImageEmbedder::EmbedForVideo( return output_packets[kEmbeddingResultStreamName].Get(); } -absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms, - std::optional roi) { +absl::Status ImageEmbedder::EmbedAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h index 13f4702d1..3a2a1dbee 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h @@ -21,11 +21,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi { static absl::StatusOr> Create( std::unique_ptr options); - // Performs embedding extraction on the provided single image. Extraction - // is performed on the region of interest specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs embedding extraction on the provided single image. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. 
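+ // For example, a minimal usage sketch (an already-created 'embedder' and a decoded 'image' are assumed, and the ROI values are arbitrary): + //   Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.5, /*bottom=*/0.5}; + //   ImageProcessingOptions options{roi, /*rotation_degrees=*/-90}; + //   auto result = embedder->Embed(image, options);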
// // Only use this method when the ImageEmbedder is created with the image // running mode. @@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. absl::StatusOr Embed( mediapipe::Image image, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); - // Performs embedding extraction on the provided video frame. Extraction - // is performed on the region of interested specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs embedding extraction on the provided video frame. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageEmbedder is created with the video // running mode. @@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi { // must be monotonically increasing. absl::StatusOr EmbedForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); // Sends live image data to embedder, and the results will be available via - // the "result_callback" provided in the ImageEmbedderOptions. Embedding - // extraction is performed on the region of interested specified by the `roi` - // argument if provided, or on the entire image otherwise. + // the "result_callback" provided in the ImageEmbedderOptions. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageEmbedder is created with the live // stream running mode. @@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi { // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status EmbedAsync( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the ImageEmbedder when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index fff0f4366..f0f440986 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { // stream. 
auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index db1019b33..386b6c8eb 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" @@ -42,7 +41,9 @@ namespace image_embedder { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN( Image crop, DecodeImageFromFile( JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); - // Bounding box in "burger.jpg" corresponding to "burger_crop.jpg". - NormalizedRect roi; - roi.set_x_center(200.0 / 480); - roi.set_y_center(0.5); - roi.set_width(400.0 / 480); - roi.set_height(1.0f); + // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". + Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; // Extract both embeddings. - MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, - image_embedder->Embed(image, roi)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& image_result, + image_embedder->Embed(image, image_processing_options)); MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, image_embedder->Embed(crop)); @@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + // Load images: one is a rotated version of the other. + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + MP_ASSERT_OK_AND_ASSIGN(Image rotated, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + image_embedder->Embed(image)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& rotated_result, + image_embedder->Embed(rotated, image_processing_options)); + + // Check results. 
+ CheckMobileNetV3Result(image_result, false); + CheckMobileNetV3Result(rotated_result, false); + // CheckCosineSimilarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), + rotated_result.embeddings(0).entries(0))); + double expected_similarity = 0.572265; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + Image crop, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); + MP_ASSERT_OK_AND_ASSIGN(Image rotated, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + // Region-of-interest corresponding to burger_crop.jpg. + Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/-90}; + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + image_embedder->Embed(crop)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& rotated_result, + image_embedder->Embed(rotated, image_processing_options)); + + // Check results. + CheckMobileNetV3Result(crop_result, false); + CheckMobileNetV3Result(rotated_result, false); + // CheckCosineSimilarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0), + rotated_result.embeddings(0).entries(0))); + double expected_similarity = 0.62838; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 6bdbf41da..81cd43e34 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -24,10 +24,12 @@ cc_library( ":image_segmenter_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", @@ -48,6 +50,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 84ceea88a..209ee0df3 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -17,8 +17,10 @@ limitations under the License. 
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" @@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ImageSegmenterGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; @@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig( auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap(options.get()); graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> graph.Out(kGroupedSegmentationTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator( - graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag); + return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, + {kImageTag, kNormRectTag}, + kGroupedSegmentationTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); + graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); return graph.GetConfig(); } @@ -139,47 +146,68 @@ absl::StatusOr> ImageSegmenter::Create( } absl::StatusOr> ImageSegmenter::Segment( - mediapipe::Image image) { + mediapipe::Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, - ProcessImageData({{kImageInStreamName, - mediapipe::MakePacket(std::move(image))}})); + ProcessImageData( + {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); return output_packets[kSegmentationStreamName].Get>(); } absl::StatusOr> ImageSegmenter::SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms) { + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); return output_packets[kSegmentationStreamName].Get>(); } -absl::Status 
ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) { +absl::Status ImageSegmenter::SegmentAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index e2734c4e4..54269ec0e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -25,6 +25,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "tensorflow/lite/kernels/register.h" @@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // running mode. // // The image can be of any size with format RGB or RGBA. - // TODO: Describes how the input image will be preprocessed - // after the yuv support is implemented. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. // // If the output_type is CATEGORY_MASK, the returned vector of images is // per-category segmented image mask. // If the output_type is CONFIDENCE_MASK, the returned vector of images // contains only one confidence image mask. - absl::StatusOr> Segment(mediapipe::Image image); + absl::StatusOr> Segment( + mediapipe::Image image, + std::optional image_processing_options = + std::nullopt); // Performs image segmentation on the provided video frame. // Only use this method when the ImageSegmenter is created with the video @@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // If the output_type is CATEGORY_MASK, the returned vector of images is // per-category segmented image mask. // If the output_type is CONFIDENCE_MASK, the returned vector of images // contains only one confidence image mask. 
absl::StatusOr> SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms); + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options = + std::nullopt); // Sends live image data to perform image segmentation, and the results will // be available via the "result_callback" provided in the @@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // sent to the image segmenter. The input timestamps must be monotonically // increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // The "result_callback" prvoides // - A vector of segmented image masks. // If the output_type is CATEGORY_MASK, the returned vector of images is @@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // no longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms); + absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the ImageSegmenter when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 1678dd083..d3e522d92 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" @@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map; constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; @@ -159,6 +161,10 @@ absl::StatusOr GetOutputTensor( // Inputs: // IMAGE - Image // Image to perform segmentation on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection +// on. +// @Optional: rect covering the whole image is used if not specified. 
// // Outputs: // SEGMENTATION - mediapipe::Image @Multiple @@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto output_streams, - BuildSegmentationTask( - sc->Options(), *model_resources, - graph[Input(kImageTag)], graph)); + ASSIGN_OR_RETURN( + auto output_streams, + BuildSegmentationTask( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); auto& merge_images_to_vector = graph.AddNode("MergeImagesToVectorCalculator"); @@ -228,18 +236,21 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); + norm_rect_in >> preprocessing.In(kNormRectTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index ab23a725c..07235563b 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -29,8 +29,10 @@ limitations under the License. 
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -44,6 +46,8 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); EXPECT_EQ(confidence_masks.size(), 21); @@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "cat_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::SOFTMAX; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); + EXPECT_EQ(confidence_masks.size(), 21); + + cv::Mat expected_mask = + cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), + cv::IMREAD_GRAYSCALE); + cv::Mat expected_mask_float; + expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); + + // Cat category index 8. 
+ cv::Mat cat_mask = mediapipe::formats::MatView( + confidence_masks[8].GetImageFrameSharedPtr().get()); + EXPECT_THAT(cat_mask, + SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); +} + +TEST_F(ImageModeTest, FailsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::SOFTMAX; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; + + auto results = segmenter->Segment(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg")); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 186909509..8220d8b7f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -75,6 +75,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index 9149a3cbe..dd19237ff 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h" @@ -58,31 +59,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; -// Returns a NormalizedRect filling the whole image. If input is present, its -// rotation is set in the returned NormalizedRect and a check is performed to -// make sure no region-of-interest was provided. Otherwise, rotation is set to -// 0. 
-absl::StatusOr FillNormalizedRect( - std::optional normalized_rect) { - NormalizedRect result; - if (normalized_rect.has_value()) { - result = *normalized_rect; - } - bool has_coordinates = result.has_x_center() || result.has_y_center() || - result.has_width() || result.has_height(); - if (has_coordinates) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "ObjectDetector does not support region-of-interest.", - MediaPipeTasksStatus::kInvalidArgumentError); - } - result.set_x_center(0.5); - result.set_y_center(0.5); - result.set_width(1); - result.set_height(1); - return result; -} - // Creates a MediaPipe graph config that contains a subgraph node of // "mediapipe.tasks.vision.ObjectDetectorGraph". If the task is running in the // live stream mode, a "FlowLimiterCalculator" will be added to limit the @@ -170,15 +146,16 @@ absl::StatusOr> ObjectDetector::Create( absl::StatusOr> ObjectDetector::Detect( mediapipe::Image image, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -189,15 +166,16 @@ absl::StatusOr> ObjectDetector::Detect( absl::StatusOr> ObjectDetector::DetectForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -212,15 +190,16 @@ absl::StatusOr> ObjectDetector::DetectForVideo( absl::Status ObjectDetector::DetectAsync( Image image, int64 timestamp_ms, - std::optional image_processing_options) { + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(NormalizedRect norm_rect, - FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 2e5ed7b8d..44ce68ed9 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -27,9 +27,9 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -154,10 +154,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // after the yuv support is implemented. // // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). Note that specifying a region-of-interest using - // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // // For CPU images, the returned bounding boxes are expressed in the @@ -168,7 +167,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // images after enabling the gpu support in MediaPipe Tasks. absl::StatusOr> Detect( mediapipe::Image image, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Performs object detection on the provided video frame. @@ -180,10 +179,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. // // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). Note that specifying a region-of-interest using - // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // // For CPU images, the returned bounding boxes are expressed in the @@ -192,7 +190,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // underlying image data. absl::StatusOr> DetectForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional image_processing_options = + std::optional image_processing_options = std::nullopt); // Sends live image data to perform object detection, and the results will be @@ -206,10 +204,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // increasing. // // The optional 'image_processing_options' parameter can be used to specify - // the rotation to apply to the image before performing classification, by - // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° - // anti-clockwise rotation). Note that specifying a region-of-interest using - // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. 
Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported // and will result in an invalid argument error being returned. // // The "result_callback" provides @@ -223,7 +220,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms, - std::optional + std::optional image_processing_options = std::nullopt); // Shuts down the ObjectDetector when all works are done. diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index 07e912cfc..b149cea0f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // stream. auto& preprocessing = graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + bool use_gpu = components::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( - model_resources, + model_resources, use_gpu, &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 8db3fa767..1747685dd 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -31,11 +31,12 @@ limitations under the License. #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/location_data.pb.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/c/common.h" @@ -64,6 +65,8 @@ namespace vision { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -532,8 +535,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - NormalizedRect image_processing_options; - image_processing_options.set_rotation(M_PI / 2.0); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; MP_ASSERT_OK_AND_ASSIGN( auto results, object_detector->Detect(image, image_processing_options)); MP_ASSERT_OK(object_detector->Close()); @@ -557,16 +560,17 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); - NormalizedRect 
image_processing_options; - image_processing_options.set_x_center(0.5); - image_processing_options.set_y_center(0.5); - image_processing_options.set_width(1.0); - image_processing_options.set_height(1.0); + Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; auto results = object_detector->Detect(image, image_processing_options); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(results.status().message(), - HasSubstr("ObjectDetector does not support region-of-interest")); + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); } class VideoModeTest : public tflite_shims::testing::Test {}; diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD index acbdbd6eb..89c1edcb3 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD @@ -31,6 +31,7 @@ android_binary( multidex = "native", resource_files = ["//mediapipe/tasks/examples/android:resource_files"], deps = [ + "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java index 7f7ec1389..18c010a00 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java @@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector; import android.content.Intent; import android.graphics.Bitmap; -import android.graphics.Matrix; import android.media.MediaMetadataRetriever; import android.os.Bundle; import android.provider.MediaStore; @@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; import androidx.exifinterface.media.ExifInterface; // ContentResolver dependency +import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; @@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity { if (resultIntent != null) { if (result.getResultCode() == RESULT_OK) { Bitmap bitmap = null; + int rotation = 0; try { bitmap = downscaleBitmap( @@ -93,13 +95,16 @@ public class MainActivity extends 
AppCompatActivity { try { InputStream imageData = this.getContentResolver().openInputStream(resultIntent.getData()); - bitmap = rotateBitmap(bitmap, imageData); - } catch (IOException e) { + rotation = getImageRotation(imageData); + } catch (IOException | MediaPipeException e) { Log.e(TAG, "Bitmap rotation error:" + e); } if (bitmap != null) { - Image image = new BitmapImageBuilder(bitmap).build(); - ObjectDetectionResult detectionResult = objectDetector.detect(image); + MPImage image = new BitmapImageBuilder(bitmap).build(); + ObjectDetectionResult detectionResult = + objectDetector.detect( + image, + ImageProcessingOptions.builder().setRotationDegrees(rotation).build()); imageView.setData(image, detectionResult); runOnUiThread(() -> imageView.update()); } @@ -144,7 +149,8 @@ public class MainActivity extends AppCompatActivity { MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT)); long frameIntervalMs = duration / numFrames; for (int i = 0; i < numFrames; ++i) { - Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); + MPImage image = + new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); ObjectDetectionResult detectionResult = objectDetector.detectForVideo(image, frameIntervalMs * i); // Currently only annotates the detection result on the first video frame and @@ -209,28 +215,25 @@ public class MainActivity extends AppCompatActivity { return Bitmap.createScaledBitmap(originalBitmap, width, height, false); } - private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { + private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException { int orientation = new ExifInterface(imageData) .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); - if (orientation == ExifInterface.ORIENTATION_NORMAL) { - return inputBitmap; - } - Matrix matrix = new Matrix(); switch (orientation) { + case ExifInterface.ORIENTATION_NORMAL: + return 0; case ExifInterface.ORIENTATION_ROTATE_90: - matrix.postRotate(90); - break; + return 90; case ExifInterface.ORIENTATION_ROTATE_180: - matrix.postRotate(180); - break; + return 180; case ExifInterface.ORIENTATION_ROTATE_270: - matrix.postRotate(270); - break; + return 270; default: - matrix.postRotate(0); + // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip + // is supported. 
+ throw new MediaPipeException( + MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(), + "Flipped images are not supported yet."); } - return Bitmap.createBitmap( - inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true); } } diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java index 94a4a90dc..283e48857 100644 --- a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java @@ -22,7 +22,7 @@ import android.graphics.Matrix; import android.graphics.Paint; import androidx.appcompat.widget.AppCompatImageView; import com.google.mediapipe.framework.image.BitmapExtractor; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Detection; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; @@ -40,12 +40,12 @@ public class ObjectDetectionResultImageView extends AppCompatImageView { } /** - * Sets an {@link Image} and an {@link ObjectDetectionResult} to render. + * Sets a {@link MPImage} and an {@link ObjectDetectionResult} to render. * - * @param image an {@link Image} object for annotation. + * @param image a {@link MPImage} object for annotation. * @param result an {@link ObjectDetectionResult} object that contains the detection result. */ - public void setData(Image image, ObjectDetectionResult result) { + public void setData(MPImage image, ObjectDetectionResult result) { if (image == null || result == null) { return; } diff --git a/mediapipe/tasks/java/BUILD b/mediapipe/tasks/java/BUILD index 024510737..7e6283261 100644 --- a/mediapipe/tasks/java/BUILD +++ b/mediapipe/tasks/java/BUILD @@ -1 +1,15 @@ -# dummy file for tap test to find the pattern +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD new file mode 100644 index 000000000..7e6283261 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD @@ -0,0 +1,15 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD index b4ebfe8cc..cb9d67424 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/BUILD @@ -36,3 +36,15 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar") + +mediapipe_tasks_core_aar( + name = "tasks_core", + srcs = glob(["*.java"]) + [ + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src", + "//mediapipe/java/com/google/mediapipe/framework/image:java_src", + ], + manifest = "AndroidManifest.xml", +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl new file mode 100644 index 000000000..0260e3fab --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -0,0 +1,256 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Building MediaPipe Tasks AARs.""" + +load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build_aar_with_jni", "mediapipe_java_proto_src_extractor", "mediapipe_java_proto_srcs") +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + +_CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ + "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite", +] + +_VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", +] + +_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", +] + +def mediapipe_tasks_core_aar(name, srcs, manifest): + """Builds medaipipe tasks core AAR. + + Args: + name: The bazel target name. + srcs: MediaPipe Tasks' core layer source files. + manifest: The Android manifest. 
+ """ + + mediapipe_tasks_java_proto_srcs = [] + for target in _CORE_TASKS_JAVA_PROTO_LITE_TARGETS: + mediapipe_tasks_java_proto_srcs.append( + _mediapipe_tasks_java_proto_src_extractor(target = target), + ) + + for target in _VISION_TASKS_JAVA_PROTO_LITE_TARGETS: + mediapipe_tasks_java_proto_srcs.append( + _mediapipe_tasks_java_proto_src_extractor(target = target), + ) + + for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS: + mediapipe_tasks_java_proto_srcs.append( + _mediapipe_tasks_java_proto_src_extractor(target = target), + ) + + mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", + src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java", + )) + + mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", + src_out = "com/google/mediapipe/calculator/proto/InferenceCalculatorProto.java", + )) + + android_library( + name = name, + srcs = srcs + [ + "//mediapipe/java/com/google/mediapipe/framework:java_src", + ] + mediapipe_java_proto_srcs() + + mediapipe_tasks_java_proto_srcs, + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = manifest, + deps = [ + "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite", + "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite", + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework:calculator_profile_java_proto_lite", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework:mediapipe_options_java_proto_lite", + "//mediapipe/framework:packet_factory_java_proto_lite", + "//mediapipe/framework:packet_generator_java_proto_lite", + "//mediapipe/framework:status_handler_java_proto_lite", + "//mediapipe/framework:stream_handler_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/framework/formats:rect_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + "//third_party:androidx_annotation", + "//third_party:autovalue", + "@com_google_protobuf//:protobuf_javalite", + "@maven//:com_google_guava_guava", + "@maven//:com_google_flogger_flogger", + "@maven//:com_google_flogger_flogger_system_backend", + "@maven//:com_google_code_findbugs_jsr305", + ] + + _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS + + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS, + ) + +def mediapipe_tasks_vision_aar(name, srcs, native_library): + """Builds medaipipe tasks vision AAR. + + Args: + name: The bazel target name. + srcs: MediaPipe Vision Tasks' source files. + native_library: The native library that contains vision tasks' graph and calculators. 
+ """ + + native.genrule( + name = name + "tasks_manifest_generator", + outs = ["AndroidManifest.xml"], + cmd = """ +cat > $(OUTS) < + + + +EOF +""", + ) + + _mediapipe_tasks_aar( + name = name, + srcs = srcs, + manifest = "AndroidManifest.xml", + java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS, + native_library = native_library, + ) + +def mediapipe_tasks_text_aar(name, srcs, native_library): + """Builds medaipipe tasks text AAR. + + Args: + name: The bazel target name. + srcs: MediaPipe Text Tasks' source files. + native_library: The native library that contains text tasks' graph and calculators. + """ + + native.genrule( + name = name + "tasks_manifest_generator", + outs = ["AndroidManifest.xml"], + cmd = """ +cat > $(OUTS) < + + + +EOF +""", + ) + + _mediapipe_tasks_aar( + name = name, + srcs = srcs, + manifest = "AndroidManifest.xml", + java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS, + native_library = native_library, + ) + +def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library): + """Builds medaipipe tasks AAR.""" + + # When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command, + # the OpenCV so libraries will be excluded from the AAR package to + # save the package size. + native.config_setting( + name = "exclude_opencv_so_lib", + define_values = { + "EXCLUDE_OPENCV_SO_LIB": "1", + }, + visibility = ["//visibility:public"], + ) + + native.cc_library( + name = name + "_jni_opencv_cc_lib", + srcs = select({ + "//mediapipe:android_arm64": ["@android_opencv//:libopencv_java3_so_arm64-v8a"], + "//mediapipe:android_armeabi": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"], + "//mediapipe:android_arm": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"], + "//mediapipe:android_x86": ["@android_opencv//:libopencv_java3_so_x86"], + "//mediapipe:android_x86_64": ["@android_opencv//:libopencv_java3_so_x86_64"], + "//conditions:default": [], + }), + alwayslink = 1, + ) + + android_library( + name = name + "_android_lib", + srcs = srcs, + manifest = manifest, + proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"], + deps = java_proto_lite_targets + [native_library] + [ + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework:calculator_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/framework/formats:location_data_java_proto_lite", + "//mediapipe/framework/formats:rect_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ] + select({ + "//conditions:default": 
[":" + name + "_jni_opencv_cc_lib"], + "//mediapipe/framework/port:disable_opencv": [], + "exclude_opencv_so_lib": [], + }), + ) + + mediapipe_build_aar_with_jni(name, name + "_android_lib") + +def _mediapipe_tasks_java_proto_src_extractor(target): + proto_path = "com/google/" + target.split(":")[0].replace("cc/", "").replace("//", "").replace("_", "") + "/" + proto_name = target.split(":")[-1].replace("_java_proto_lite", "").replace("_", " ").title().replace(" ", "") + "Proto.java" + return mediapipe_java_proto_src_extractor( + target = target, + src_out = proto_path + proto_name, + ) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 1719707d8..fa2a547c2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -61,3 +61,11 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar") + +mediapipe_tasks_text_aar( + name = "tasks_text", + srcs = glob(["**/*.java"]), + native_library = ":libmediapipe_tasks_text_jni_lib", +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java index dd9b9a1b3..c1e2446cd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java @@ -15,11 +15,11 @@ package com.google.mediapipe.tasks.text.textclassifier; import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.container.proto.CategoryProto; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.ClassificationEntry; import com.google.mediapipe.tasks.components.containers.Classifications; +import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 76117d2e4..07a4fa48f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -22,7 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index dcf3b3542..ed65fbcac 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -28,6 +28,7 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", "@maven//:com_google_guava_guava", ], ) @@ -128,6 +129,7 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", @@ -140,3 +142,11 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") + +mediapipe_tasks_vision_aar( + name = "tasks_vision", + srcs = glob(["**/*.java"]), + native_library = ":libmediapipe_tasks_vision_jni_lib", +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 7ab8e75a1..0774b69a2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -19,12 +19,11 @@ import com.google.mediapipe.formats.proto.RectProto.NormalizedRect; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.ProtoUtil; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskRunner; import java.util.HashMap; import java.util.Map; -import java.util.Optional; /** The base class of MediaPipe vision tasks. */ public class BaseVisionTaskApi implements AutoCloseable { @@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable { private final TaskRunner runner; private final RunningMode runningMode; private final String imageStreamName; - private final Optional normRectStreamName; + private final String normRectStreamName; static { System.loadLibrary("mediapipe_tasks_vision_jni"); @@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable { } /** - * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input. + * Constructor to initialize a {@link BaseVisionTaskApi}. * * @param runner a {@link TaskRunner}. * @param runningMode a mediapipe vision task {@link RunningMode}. * @param imageStreamName the name of the input image stream. 
- */ - public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) { - this.runner = runner; - this.runningMode = runningMode; - this.imageStreamName = imageStreamName; - this.normRectStreamName = Optional.empty(); - } - - /** - * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as - * input. - * - * @param runner a {@link TaskRunner}. - * @param runningMode a mediapipe vision task {@link RunningMode}. - * @param imageStreamName the name of the input image stream. - * @param normRectStreamName the name of the input normalized rect image stream. + * @param normRectStreamName the name of the input normalized rect image stream used to provide + * (mandatory) rotation and (optional) region-of-interest. */ public BaseVisionTaskApi( TaskRunner runner, @@ -70,61 +55,31 @@ public class BaseVisionTaskApi implements AutoCloseable { this.runner = runner; this.runningMode = runningMode; this.imageStreamName = imageStreamName; - this.normRectStreamName = Optional.of(normRectStreamName); + this.normRectStreamName = normRectStreamName; } /** * A synchronous method to process single image inputs. The call blocks the current thread until a * failure status or a successful result is returned. * - * @param image a MediaPipe {@link Image} object for processing. - * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect - * input. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if the task is not in the image mode. */ - protected TaskResult processImageData(Image image) { + protected TaskResult processImageData( + MPImage image, ImageProcessingOptions imageProcessingOptions) { if (runningMode != RunningMode.IMAGE) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the image mode. Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - return runner.process(inputPackets); - } - - /** - * A synchronous method to process single image inputs. The call blocks the current thread until a - * failure status or a successful result is returned. - * - * @param image a MediaPipe {@link Image} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized - * rect. - */ - protected TaskResult processImageData(Image image, RectF roi) { - if (runningMode != RunningMode.IMAGE) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the image mode. 
Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); return runner.process(inputPackets); } @@ -132,56 +87,25 @@ public class BaseVisionTaskApi implements AutoCloseable { * A synchronous method to process continuous video frames. The call blocks the current thread * until a failure status or a successful result is returned. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect - * input. + * @throws MediaPipeException if the task is not in the video mode. */ - protected TaskResult processVideoData(Image image, long timestampMs) { + protected TaskResult processVideoData( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { if (runningMode != RunningMode.VIDEO) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the video mode. Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); - } - - /** - * A synchronous method to process continuous video frames. The call blocks the current thread - * until a failure status or a successful result is returned. - * - * @param image a MediaPipe {@link Image} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized - * rect. - */ - protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) { - if (runningMode != RunningMode.VIDEO) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the video mode. 
Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -189,56 +113,25 @@ public class BaseVisionTaskApi implements AutoCloseable { * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. * - * @param image a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect - * input. + * @throws MediaPipeException if the task is not in the stream mode. */ - protected void sendLiveStreamData(Image image, long timestampMs) { + protected void sendLiveStreamData( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { if (runningMode != RunningMode.LIVE_STREAM) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the live stream mode. Current running mode:" + runningMode.name()); } - if (normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task expects a normalized rect as input."); - } - Map inputPackets = new HashMap<>(); - inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); - runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); - } - - /** - * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be - * available in the user-defined result listener. - * - * @param image a MediaPipe {@link Image} object for processing. - * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates - * are expected to be specified as normalized values in [0,1]. - * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized - * rect. - */ - protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) { - if (runningMode != RunningMode.LIVE_STREAM) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task is not initialized with the live stream mode. 
Current running mode:" - + runningMode.name()); - } - if (!normRectStreamName.isPresent()) { - throw new MediaPipeException( - MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), - "Task doesn't expect a normalized rect as input."); - } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put( - normRectStreamName.get(), - runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + normRectStreamName, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } @@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable { runner.close(); } - /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */ - private static NormalizedRect convertToNormalizedRect(RectF rect) { + /** + * Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf + * message. + */ + private static NormalizedRect convertToNormalizedRect( + ImageProcessingOptions imageProcessingOptions) { + RectF regionOfInterest = + imageProcessingOptions.regionOfInterest().isPresent() + ? imageProcessingOptions.regionOfInterest().get() + : new RectF(0, 0, 1, 1); return NormalizedRect.newBuilder() - .setXCenter(rect.centerX()) - .setYCenter(rect.centerY()) - .setWidth(rect.width()) - .setHeight(rect.height()) + .setXCenter(regionOfInterest.centerX()) + .setYCenter(regionOfInterest.centerY()) + .setWidth(regionOfInterest.width()) + .setHeight(regionOfInterest.height()) + // Convert to radians anti-clockwise. + .setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f) .build(); } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java new file mode 100644 index 000000000..a34a9787d --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/ImageProcessingOptions.java @@ -0,0 +1,92 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.core; + +import android.graphics.RectF; +import com.google.auto.value.AutoValue; +import java.util.Optional; + +// TODO: add support for image flipping. +/** Options for image processing. */ +@AutoValue +public abstract class ImageProcessingOptions { + + /** + * Builder for {@link ImageProcessingOptions}. + * + *

    If both region-of-interest and rotation are specified, the crop around the + * region-of-interest is extracted first, then the specified rotation is applied to the crop. + */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets the optional region-of-interest to crop from the image. If not specified, the full image + * is used. + * + *

    Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must be + * < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link #build()} + * is called. + */ + public abstract Builder setRegionOfInterest(RectF value); + + /** + * Sets the rotation to apply to the image (or cropped region-of-interest), in degrees + * clockwise. Defaults to 0. + * + *

    The rotation must be a multiple (positive or negative) of 90°, otherwise an + * IllegalArgumentException will be thrown when {@link #build()} is called. + */ + public abstract Builder setRotationDegrees(int value); + + abstract ImageProcessingOptions autoBuild(); + + /** + * Validates and builds the {@link ImageProcessingOptions} instance. + * + * @throws IllegalArgumentException if some of the provided values do not meet their + * requirements. + */ + public final ImageProcessingOptions build() { + ImageProcessingOptions options = autoBuild(); + if (options.regionOfInterest().isPresent()) { + RectF roi = options.regionOfInterest().get(); + if (roi.left >= roi.right || roi.top >= roi.bottom) { + throw new IllegalArgumentException( + String.format( + "Expected left < right and top < bottom, found: %s.", roi.toShortString())); + } + if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) { + throw new IllegalArgumentException( + String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString())); + } + } + if (options.rotationDegrees() % 90 != 0) { + throw new IllegalArgumentException( + String.format( + "Expected rotation to be a multiple of 90°, found: %d.", + options.rotationDegrees())); + } + return options; + } + } + + public abstract Optional regionOfInterest(); + + public abstract int rotationDegrees(); + + public static Builder builder() { + return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index 660645d9c..8e5a30eab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -15,7 +15,6 @@ package com.google.mediapipe.tasks.vision.gesturerecognizer; import android.content.Context; -import android.graphics.RectF; import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; @@ -26,7 +25,7 @@ import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -37,7 +36,9 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto; import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.HandGestureRecognizerGraphOptionsProto; import com.google.mediapipe.tasks.vision.handdetector.proto.HandDetectorGraphOptionsProto; @@ -59,7 +60,7 @@ 
import java.util.Optional; * Model Maker. See . * *

      - *
    • Input image {@link Image} + *
    • Input image {@link MPImage} *
        *
      • The image that gesture recognition runs on. *
      @@ -151,9 +152,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi { public static GestureRecognizer createFromOptions( Context context, GestureRecognizerOptions recognizerOptions) { // TODO: Consolidate OutputHandler and TaskRunner. - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override public GestureRecognitionResult convertToTaskResult(List packets) { // If there is no hands detected in the image, just returns empty lists. @@ -178,7 +179,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { } @Override - public Image convertToTaskInput(List packets) { + public MPImage convertToTaskInput(List packets) { return new BitmapImageBuilder( AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) .build(); @@ -211,6 +212,25 @@ public final class GestureRecognizer extends BaseVisionTaskApi { super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); } + /** + * Performs gesture recognition on the provided single image with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc + * for input image format. + * + *

      {@link GestureRecognizer} supports the following color space types: + * + *

        + *
      • {@link Bitmap.Config.ARGB_8888} + *
      + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognize(MPImage image) { + return recognize(image, ImageProcessingOptions.builder().build()); + } + /** * Performs gesture recognition on the provided single image. Only use this method when the {@link * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc @@ -222,12 +242,41 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
    • {@link Bitmap.Config.ARGB_8888} *
    * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognize(Image inputImage) { - // TODO: add proper support for rotations. - return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF()); + public GestureRecognitionResult recognize( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (GestureRecognitionResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs gesture recognition on the provided video frame with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link GestureRecognizer} is + * created with {@link RunningMode.VIDEO}. + * + *

    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *

    {@link GestureRecognizer} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public GestureRecognitionResult recognizeForVideo(MPImage image, long timestampMs) { + return recognizeForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -243,14 +292,43 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) { - // TODO: add proper support for rotations. - return (GestureRecognitionResult) - processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + public GestureRecognitionResult recognizeForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (GestureRecognitionResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform gesture recognition with default image processing options, + * i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method when the + * {@link GestureRecognition} is created with {@link RunningMode.LIVE_STREAM}. + * + *

    It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the gesture recognizer. The input timestamps must be monotonically increasing. + * + *

    {@link GestureRecognizer} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void recognizeAsync(MPImage image, long timestampMs) { + recognizeAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -267,13 +345,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public void recognizeAsync(Image inputImage, long inputTimestampMs) { - // TODO: add proper support for rotations. - sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); + public void recognizeAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up an {@link GestureRecognizer}. */ @@ -300,13 +385,6 @@ public final class GestureRecognizer extends BaseVisionTaskApi { */ public abstract Builder setRunningMode(RunningMode value); - // TODO: remove these. Temporary solutions before bundle asset is ready. - public abstract Builder setBaseOptionsHandDetector(BaseOptions value); - - public abstract Builder setBaseOptionsHandLandmarker(BaseOptions value); - - public abstract Builder setBaseOptionsGestureRecognizer(BaseOptions value); - /** Sets the maximum number of hands can be detected by the GestureRecognizer. */ public abstract Builder setNumHands(Integer value); @@ -333,7 +411,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { * recognizer is in the live stream mode. */ public abstract Builder setResultListener( - ResultListener value); + ResultListener value); /** Sets an optional error listener. */ public abstract Builder setErrorListener(ErrorListener value); @@ -366,13 +444,6 @@ public final class GestureRecognizer extends BaseVisionTaskApi { abstract BaseOptions baseOptions(); - // TODO: remove these. Temporary solutions before bundle asset is ready. - abstract BaseOptions baseOptionsHandDetector(); - - abstract BaseOptions baseOptionsHandLandmarker(); - - abstract BaseOptions baseOptionsGestureRecognizer(); - abstract RunningMode runningMode(); abstract Optional numHands(); @@ -386,7 +457,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { // TODO update gesture confidence options after score merging calculator is ready. 
abstract Optional minGestureConfidence(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); @@ -405,22 +476,18 @@ public final class GestureRecognizer extends BaseVisionTaskApi { */ @Override public CalculatorOptions convertToCalculatorOptionsProto() { - BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptions())); GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.Builder taskOptionsBuilder = GestureRecognizerGraphOptionsProto.GestureRecognizerGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()); // Setup HandDetectorGraphOptions. HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder handDetectorGraphOptionsBuilder = - HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptionsHandDetector()))); + HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder(); numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands); minHandDetectionConfidence() .ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence); @@ -428,19 +495,12 @@ public final class GestureRecognizer extends BaseVisionTaskApi { // Setup HandLandmarkerGraphOptions. HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder handLandmarksDetectorGraphOptionsBuilder = - HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker()))); + HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder(); minHandPresenceConfidence() .ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder handLandmarkerGraphOptionsBuilder = - HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE)); + HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder(); minTrackingConfidence() .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence); handLandmarkerGraphOptionsBuilder @@ -450,16 +510,13 @@ public final class GestureRecognizer extends BaseVisionTaskApi { // Setup HandGestureRecognizerGraphOptions. 
HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.Builder handGestureRecognizerGraphOptionsBuilder = - HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder() - .setBaseOptions( - BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptionsGestureRecognizer()))); + HandGestureRecognizerGraphOptionsProto.HandGestureRecognizerGraphOptions.newBuilder(); ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = ClassifierOptionsProto.ClassifierOptions.newBuilder(); minGestureConfidence().ifPresent(classifierOptionsBuilder::setScoreThreshold); - handGestureRecognizerGraphOptionsBuilder.setClassifierOptions( - classifierOptionsBuilder.build()); + handGestureRecognizerGraphOptionsBuilder.setCannedGestureClassifierGraphOptions( + GestureClassifierGraphOptionsProto.GestureClassifierGraphOptions.newBuilder() + .setClassifierOptions(classifierOptionsBuilder.build())); taskOptionsBuilder .setHandLandmarkerGraphOptions(handLandmarkerGraphOptionsBuilder.build()) @@ -472,8 +529,14 @@ public final class GestureRecognizer extends BaseVisionTaskApi { } } - /** Creates a RectF covering the full image. */ - private static RectF buildFullImageRectF() { - return new RectF(0, 0, 1, 1); + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("GestureRecognizer doesn't support region-of-interest."); + } } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java index 09f854caa..d82a47b86 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassificationResult.java @@ -15,11 +15,11 @@ package com.google.mediapipe.tasks.vision.imageclassifier; import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.container.proto.CategoryProto; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.ClassificationEntry; import com.google.mediapipe.tasks.components.containers.Classifications; +import com.google.mediapipe.tasks.components.containers.proto.CategoryProto; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index e8e263b71..3863b6fe0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -15,7 +15,6 @@ package com.google.mediapipe.tasks.vision.imageclassifier; import android.content.Context; -import android.graphics.RectF; import 
android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; @@ -25,8 +24,8 @@ import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; -import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto; import java.io.File; @@ -164,9 +164,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation. */ public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) { - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override public ImageClassificationResult convertToTaskResult(List packets) { try { @@ -182,7 +182,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { } @Override - public Image convertToTaskInput(List packets) { + public MPImage convertToTaskInput(List packets) { return new BitmapImageBuilder( AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) .build(); @@ -215,6 +215,24 @@ public final class ImageClassifier extends BaseVisionTaskApi { super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); } + /** + * Performs classification on the provided single image with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied. Only use + * this method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * + *

    {@link ImageClassifier} supports the following color space types: + * + *

      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageClassificationResult classify(MPImage image) { + return classify(image, ImageProcessingOptions.builder().build()); + } + /** * Performs classification on the provided single image. Only use this method when the {@link * ImageClassifier} is created with {@link RunningMode.IMAGE}. @@ -225,16 +243,23 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(Image inputImage) { - return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF()); + public ImageClassificationResult classify( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageClassificationResult) processImageData(image, imageProcessingOptions); } /** - * Performs classification on the provided single image and region-of-interest. Only use this - * method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}. + * Performs classification on the provided video frame with default image processing options, i.e. + * using the whole image as region-of-interest and without any rotation applied. Only use this + * method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * + *

    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. * *

    {@link ImageClassifier} supports the following color space types: * @@ -242,13 +267,12 @@ public final class ImageClassifier extends BaseVisionTaskApi { *

  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classify(Image inputImage, RectF roi) { - return (ImageClassificationResult) processImageData(inputImage, roi); + public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) { + return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -264,21 +288,26 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo(Image inputImage, long inputTimestampMs) { - return (ImageClassificationResult) - processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); + public ImageClassificationResult classifyForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs); } /** - * Performs classification on the provided video frame with additional region-of-interest. Only - * use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}. + * Sends live image data to perform classification with default image processing options, i.e. + * using the whole image as region-of-interest and without any rotation applied, and the results + * will be available via the {@link ResultListener} provided in the {@link + * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with + * {@link RunningMode.LIVE_STREAM}. * - *

    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps - * must be monotonically increasing. + *

    It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. * *

    {@link ImageClassifier} supports the following color space types: * @@ -286,15 +315,12 @@ public final class ImageClassifier extends BaseVisionTaskApi { *

  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public ImageClassificationResult classifyForVideo( - Image inputImage, RectF roi, long inputTimestampMs) { - return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs); + public void classifyAsync(MPImage image, long timestampMs) { + classifyAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -311,37 +337,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). * @throws MediaPipeException if there is an internal error. */ - public void classifyAsync(Image inputImage, long inputTimestampMs) { - sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); - } - - /** - * Sends live image data and additional region-of-interest to perform classification, and the - * results will be available via the {@link ResultListener} provided in the {@link - * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with - * {@link RunningMode.LIVE_STREAM}. - * - *

    It's required to provide a timestamp (in milliseconds) to indicate when the input image is - * sent to the object detector. The input timestamps must be monotonically increasing. - * - *

    {@link ImageClassifier} supports the following color space types: - * - *

      - *
    • {@link Bitmap.Config.ARGB_8888} - *
    - * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param roi a {@link RectF} specifying the region of interest on which to perform - * classification. Coordinates are expected to be specified as normalized values in [0,1]. - * @param inputTimestampMs the input timestamp (in milliseconds). - * @throws MediaPipeException if there is an internal error. - */ - public void classifyAsync(Image inputImage, RectF roi, long inputTimestampMs) { - sendLiveStreamData(inputImage, roi, inputTimestampMs); + public void classifyAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up and {@link ImageClassifier}. */ @@ -379,7 +383,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { * the image classifier is in the live stream mode. */ public abstract Builder setResultListener( - ResultListener resultListener); + ResultListener resultListener); /** Sets an optional {@link ErrorListener}. */ public abstract Builder setErrorListener(ErrorListener errorListener); @@ -416,7 +420,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract Optional classifierOptions(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); @@ -447,9 +451,4 @@ public final class ImageClassifier extends BaseVisionTaskApi { .build(); } } - - /** Creates a RectF covering the full image. */ - private static RectF buildFullImageRectF() { - return new RectF(0, 0, 1, 1); - } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index bfce62791..3f944eaee 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -22,7 +22,7 @@ import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -32,6 +32,7 @@ import com.google.mediapipe.tasks.core.TaskOptions; import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto; import com.google.mediapipe.formats.proto.DetectionProto.Detection; @@ -96,8 +97,10 @@ import java.util.Optional; public final class ObjectDetector extends BaseVisionTaskApi { private static final String TAG = ObjectDetector.class.getSimpleName(); private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; private static final List INPUT_STREAMS = - Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + Collections.unmodifiableList( + 
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); private static final List OUTPUT_STREAMS = Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); private static final int DETECTIONS_OUT_STREAM_INDEX = 0; @@ -162,9 +165,9 @@ public final class ObjectDetector extends BaseVisionTaskApi { public static ObjectDetector createFromOptions( Context context, ObjectDetectorOptions detectorOptions) { // TODO: Consolidate OutputHandler and TaskRunner. - OutputHandler handler = new OutputHandler<>(); + OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( - new OutputHandler.OutputPacketConverter() { + new OutputHandler.OutputPacketConverter() { @Override public ObjectDetectionResult convertToTaskResult(List packets) { return ObjectDetectionResult.create( @@ -174,7 +177,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { } @Override - public Image convertToTaskInput(List packets) { + public MPImage convertToTaskInput(List packets) { return new BitmapImageBuilder( AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) .build(); @@ -204,7 +207,25 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @param runningMode a mediapipe vision task {@link RunningMode}. */ private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) { - super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs object detection on the provided single image with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is + * created with {@link RunningMode.IMAGE}. + * + *
    {@link ObjectDetector} supports the following color space types: + * + *
      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); } /** @@ -217,11 +238,41 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detect(Image inputImage) { - return (ObjectDetectionResult) processImageData(inputImage); + public ObjectDetectionResult detect( + MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (ObjectDetectionResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs object detection on the provided video frame with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is + * created with {@link RunningMode.VIDEO}. + * + *
    It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + *
    {@link ObjectDetector} supports the following color space types: + * + *
      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -237,12 +288,43 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) { - return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs); + public ObjectDetectionResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform object detection with default image processing options, i.e. + * without any rotation applied, and the results will be available via the {@link ResultListener} + * provided in the {@link ObjectDetectorOptions}. Only use this method when the {@link + * ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}. + * + *
    It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the object detector. The input timestamps must be monotonically increasing. + * + *
    {@link ObjectDetector} supports the following color space types: + * + *
      + *
    • {@link Bitmap.Config.ARGB_8888} + *
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); } /** @@ -259,12 +341,20 @@ public final class ObjectDetector extends BaseVisionTaskApi { *
  • {@link Bitmap.Config.ARGB_8888} * * - * @param inputImage a MediaPipe {@link Image} object for processing. - * @param inputTimestampMs the input timestamp (in milliseconds). + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. * @throws MediaPipeException if there is an internal error. */ - public void detectAsync(Image inputImage, long inputTimestampMs) { - sendLiveStreamData(inputImage, inputTimestampMs); + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); } /** Options for setting up an {@link ObjectDetector}. */ @@ -333,7 +423,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { * Sets the {@link ResultListener} to receive the detection results asynchronously when the * object detector is in the live stream mode. */ - public abstract Builder setResultListener(ResultListener value); + public abstract Builder setResultListener( + ResultListener value); /** Sets an optional {@link ErrorListener}}. */ public abstract Builder setErrorListener(ErrorListener value); @@ -378,7 +469,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { abstract List categoryDenylist(); - abstract Optional> resultListener(); + abstract Optional> resultListener(); abstract Optional errorListener(); @@ -414,4 +505,15 @@ public final class ObjectDetector extends BaseVisionTaskApi { .build(); } } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("ObjectDetector doesn't support region-of-interest."); + } + } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml new file mode 100644 index 000000000..aa2df6baf --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java new file mode 100644 index 000000000..078b62af1 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/core/ImageProcessingOptionsTest.java @@ -0,0 +1,70 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.core; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.graphics.RectF; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Test for {@link ImageProcessingOptions}/ */ +@RunWith(AndroidJUnit4.class) +public final class ImageProcessingOptionsTest { + + @Test + public void succeedsWithValidInputs() throws Exception { + ImageProcessingOptions options = + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.0f, 0.1f, 1.0f, 0.9f)) + .setRotationDegrees(270) + .build(); + } + + @Test + public void failsWithLeftHigherThanRight() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.9f, 0.0f, 0.1f, 1.0f)) + .build()); + assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom"); + } + + @Test + public void failsWithBottomHigherThanTop() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageProcessingOptions.builder() + .setRegionOfInterest(new RectF(0.0f, 0.9f, 1.0f, 0.1f)) + .build()); + assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom"); + } + + @Test + public void failsWithInvalidRotation() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ImageProcessingOptions.builder().setRotationDegrees(1).build()); + assertThat(exception).hasMessageThat().contains("Expected rotation to be a multiple of 90°"); + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index efec02b2a..eca5d35c2 100644 --- 
a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -19,17 +19,19 @@ import static org.junit.Assert.assertThrows; import android.content.res.AssetManager; import android.graphics.BitmapFactory; +import android.graphics.RectF; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.common.truth.Correspondence; import com.google.mediapipe.formats.proto.ClassificationProto; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Landmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions; import java.io.InputStream; @@ -43,17 +45,17 @@ import org.junit.runners.Suite.SuiteClasses; @RunWith(Suite.class) @SuiteClasses({GestureRecognizerTest.General.class, GestureRecognizerTest.RunningModeTest.class}) public class GestureRecognizerTest { - private static final String HAND_DETECTOR_MODEL_FILE = "palm_detection_full.tflite"; - private static final String HAND_LANDMARKER_MODEL_FILE = "hand_landmark_full.tflite"; - private static final String GESTURE_RECOGNIZER_MODEL_FILE = - "cg_classifier_screen3d_landmark_features_nn_2022_08_04_base_simple_model.tflite"; + private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task"; private static final String TWO_HANDS_IMAGE = "right_hands.jpg"; private static final String THUMB_UP_IMAGE = "thumb_up.jpg"; + private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg"; private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg"; private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb"; private static final String TAG = "Gesture Recognizer Test"; private static final String THUMB_UP_LABEL = "Thumb_Up"; private static final int THUMB_UP_INDEX = 5; + private static final String POINTING_UP_LABEL = "Pointing_Up"; + private static final int POINTING_UP_INDEX = 3; private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; private static final int IMAGE_WIDTH = 382; private static final int IMAGE_HEIGHT = 406; @@ -66,13 +68,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .build(); GestureRecognizer gestureRecognizer = 
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -88,13 +86,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .build(); GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -111,16 +105,12 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) // TODO update the confidence to be in range [0,1] after embedding model // and scoring calculator is integrated. - .setMinGestureConfidence(3.0f) + .setMinGestureConfidence(2.0f) .build(); GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -139,13 +129,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setNumHands(2) .build(); GestureRecognizer gestureRecognizer = @@ -154,6 +140,53 @@ public class GestureRecognizerTest { gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE)); assertThat(actualResult.handednesses()).hasSize(2); } + + @Test + public void recognize_successWithRotation() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) + .setNumHands(1) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + GestureRecognitionResult actualResult = + gestureRecognizer.recognize( + getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions); + assertThat(actualResult.gestures()).hasSize(1); + 
assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX); + assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL); + } + + @Test + public void recognize_failsWithRegionOfInterest() throws Exception { + GestureRecognizerOptions options = + GestureRecognizerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) + .setNumHands(1) + .build(); + GestureRecognizer gestureRecognizer = + GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + gestureRecognizer.recognize( + getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("GestureRecognizer doesn't support region-of-interest"); + } } @RunWith(AndroidJUnit4.class) @@ -168,19 +201,7 @@ public class GestureRecognizerTest { GestureRecognizerOptions.builder() .setBaseOptions( BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) - .build()) - .setBaseOptionsHandDetector( - BaseOptions.builder() - .setModelAssetPath(HAND_DETECTOR_MODEL_FILE) - .build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder() - .setModelAssetPath(HAND_LANDMARKER_MODEL_FILE) - .build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) .build()) .setRunningMode(mode) .setResultListener((gestureRecognitionResult, inputImage) -> {}) @@ -201,15 +222,7 @@ public class GestureRecognizerTest { GestureRecognizerOptions.builder() .setBaseOptions( BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) - .build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder() - .setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE) + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) .build()) .setRunningMode(RunningMode.LIVE_STREAM) .build()); @@ -223,13 +236,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.IMAGE) .build(); @@ -238,12 +247,16 @@ public class GestureRecognizerTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video 
mode"); exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeAsync( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -252,13 +265,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.VIDEO) .build(); @@ -272,7 +281,9 @@ public class GestureRecognizerTest { exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeAsync( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -281,13 +292,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener((gestureRecognitionResult, inputImage) -> {}) .build(); @@ -302,7 +309,9 @@ public class GestureRecognizerTest { exception = assertThrows( MediaPipeException.class, - () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0)); + () -> + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -311,13 +320,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.IMAGE) .build(); @@ -335,13 +340,9 @@ public class GestureRecognizerTest { GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - 
.setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.VIDEO) .build(); GestureRecognizer gestureRecognizer = @@ -350,26 +351,23 @@ public class GestureRecognizerTest { getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); for (int i = 0; i < 3; i++) { GestureRecognitionResult actualResult = - gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i); + gestureRecognizer.recognizeForVideo( + getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ i); assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); } } @Test public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception { - Image image = getImageFromAsset(THUMB_UP_IMAGE); + MPImage image = getImageFromAsset(THUMB_UP_IMAGE); GestureRecognitionResult expectedResult = getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (actualResult, inputImage) -> { @@ -380,9 +378,11 @@ public class GestureRecognizerTest { .build(); try (GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - gestureRecognizer.recognizeAsync(image, 1); + gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -391,19 +391,15 @@ public class GestureRecognizerTest { @Test public void recognize_successWithLiveSteamMode() throws Exception { - Image image = getImageFromAsset(THUMB_UP_IMAGE); + MPImage image = getImageFromAsset(THUMB_UP_IMAGE); GestureRecognitionResult expectedResult = getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX); GestureRecognizerOptions options = GestureRecognizerOptions.builder() .setBaseOptions( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) - .setBaseOptionsHandDetector( - BaseOptions.builder().setModelAssetPath(HAND_DETECTOR_MODEL_FILE).build()) - .setBaseOptionsHandLandmarker( - BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_MODEL_FILE).build()) - .setBaseOptionsGestureRecognizer( - BaseOptions.builder().setModelAssetPath(GESTURE_RECOGNIZER_MODEL_FILE).build()) + 
BaseOptions.builder() + .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE) + .build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (actualResult, inputImage) -> { @@ -415,12 +411,12 @@ public class GestureRecognizerTest { try (GestureRecognizer gestureRecognizer = GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; i++) { - gestureRecognizer.recognizeAsync(image, i); + gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ i); } } } - private static Image getImageFromAsset(String filePath) throws Exception { + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); @@ -487,7 +483,7 @@ public class GestureRecognizerTest { assertThat(expectedGesture.categoryName()).isEqualTo(expectedGesture.categoryName()); } - private static void assertImageSizeIsExpected(Image inputImage) { + private static void assertImageSizeIsExpected(MPImage inputImage) { assertThat(inputImage).isNotNull(); assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index e02e8ebe7..99ebd9777 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -24,11 +24,12 @@ import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions; import java.io.InputStream; @@ -47,7 +48,9 @@ public class ImageClassifierTest { private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite"; private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite"; private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg"; + private static final String MULTI_OBJECTS_ROTATED_IMAGE = "multi_objects_rotated.jpg"; @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { @@ -209,13 +212,60 @@ public class ImageClassifierTest { ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); // RectF around the soccer ball. 
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); ImageClassificationResult results = - imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi); + imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions); assertHasOneHeadAndOneTimestamp(results, 0); assertCategoriesAre( results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", ""))); } + + @Test + public void classify_succeedsWithRotation() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ImageClassificationResult results = + imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, + Arrays.asList( + Category.create(0.6390683f, 934, "cheeseburger", ""), + Category.create(0.0495407f, 963, "meat loaf", ""), + Category.create(0.0469720f, 925, "guacamole", ""))); + } + + @Test + public void classify_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageClassifierOptions options = + ImageClassifierOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .build(); + ImageClassifier imageClassifier = + ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); + // RectF around the chair. 
+ RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageClassificationResult results = + imageClassifier.classify( + getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions); + + assertHasOneHeadAndOneTimestamp(results, 0); + assertCategoriesAre( + results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", ""))); + } } @RunWith(AndroidJUnit4.class) @@ -269,12 +319,16 @@ public class ImageClassifierTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyForVideo( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyAsync( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -296,7 +350,9 @@ public class ImageClassifierTest { exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyAsync( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -320,7 +376,9 @@ public class ImageClassifierTest { exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0)); + () -> + imageClassifier.classifyForVideo( + getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -342,7 +400,7 @@ public class ImageClassifierTest { @Test public void classify_succeedsWithVideoMode() throws Exception { - Image image = getImageFromAsset(BURGER_IMAGE); + MPImage image = getImageFromAsset(BURGER_IMAGE); ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) @@ -352,7 +410,8 @@ public class ImageClassifierTest { ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassificationResult results = imageClassifier.classifyForVideo(image, i); + ImageClassificationResult results = + imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); assertHasOneHeadAndOneTimestamp(results, i); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -361,7 +420,7 @@ public class ImageClassifierTest { @Test public void classify_failsWithOutOfOrderInputTimestamps() throws Exception { - Image image = getImageFromAsset(BURGER_IMAGE); + MPImage image = getImageFromAsset(BURGER_IMAGE); ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) @@ -377,9 +436,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - 
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0)); + assertThrows( + MediaPipeException.class, + () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -388,7 +449,7 @@ public class ImageClassifierTest { @Test public void classify_succeedsWithLiveStreamMode() throws Exception { - Image image = getImageFromAsset(BURGER_IMAGE); + MPImage image = getImageFromAsset(BURGER_IMAGE); ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) @@ -405,13 +466,13 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, i); + imageClassifier.classifyAsync(image, /*timestampMs=*/ i); } } } } - private static Image getImageFromAsset(String filePath) throws Exception { + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); @@ -437,7 +498,7 @@ public class ImageClassifierTest { } } - private static void assertImageSizeIsExpected(Image inputImage) { + private static void assertImageSizeIsExpected(MPImage inputImage) { assertThat(inputImage).isNotNull(); assertThat(inputImage.getWidth()).isEqualTo(480); assertThat(inputImage.getHeight()).isEqualTo(325); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java index cdec57d76..2878c380d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectorTest.java @@ -24,11 +24,12 @@ import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; -import com.google.mediapipe.framework.image.Image; +import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Detection; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; import java.io.InputStream; @@ -45,10 +46,11 @@ import org.junit.runners.Suite.SuiteClasses; public class ObjectDetectorTest { private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg"; + private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg"; private 
static final int IMAGE_WIDTH = 1200; private static final int IMAGE_HEIGHT = 600; private static final float CAT_SCORE = 0.69f; - private static final RectF catBoundingBox = new RectF(611, 164, 986, 596); + private static final RectF CAT_BOUNDING_BOX = new RectF(611, 164, 986, 596); // TODO: Figure out why android_x86 and android_arm tests have slightly different // scores (0.6875 vs 0.69921875). private static final float SCORE_DIFF_TOLERANCE = 0.01f; @@ -67,7 +69,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -104,7 +106,7 @@ public class ObjectDetectorTest { ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); // The score threshold should block all other other objects, except cat. - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -175,7 +177,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -228,6 +230,46 @@ public class ObjectDetectorTest { .contains("`category_allowlist` and `category_denylist` are mutually exclusive options."); } + @Test + public void detect_succeedsWithRotation() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMaxResults(1) + .setCategoryAllowlist(Arrays.asList("cat")) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ObjectDetectionResult results = + objectDetector.detect( + getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions); + + assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + ObjectDetector objectDetector = + ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + objectDetector.detect( + getImageFromAsset(CAT_AND_DOG_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("ObjectDetector doesn't support region-of-interest"); + } + // TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation, // detect_succeedsWithNumThreads, 
detect_successWithNumThreadsFromBaseOptions, // detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero. @@ -282,12 +324,16 @@ public class ObjectDetectorTest { MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectAsync( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -309,7 +355,9 @@ public class ObjectDetectorTest { exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectAsync( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -333,7 +381,9 @@ public class ObjectDetectorTest { exception = assertThrows( MediaPipeException.class, - () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0)); + () -> + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -348,7 +398,7 @@ public class ObjectDetectorTest { ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } @Test @@ -363,30 +413,33 @@ public class ObjectDetectorTest { ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { ObjectDetectionResult results = - objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i); - assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE); + objectDetector.detectForVideo( + getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i); + assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE); } } @Test public void detect_failsWithOutOfOrderInputTimestamps() throws Exception { - Image image = getImageFromAsset(CAT_AND_DOG_IMAGE); + MPImage image = getImageFromAsset(CAT_AND_DOG_IMAGE); ObjectDetectorOptions options = ObjectDetectorOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) .build(); try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - objectDetector.detectAsync(image, 1); + objectDetector.detectAsync(image, /*timestampsMs=*/ 1); MediaPipeException exception = - assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0)); + assertThrows( + 
MediaPipeException.class, + () -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -395,14 +448,14 @@ public class ObjectDetectorTest { @Test public void detect_successWithLiveSteamMode() throws Exception { - Image image = getImageFromAsset(CAT_AND_DOG_IMAGE); + MPImage image = getImageFromAsset(CAT_AND_DOG_IMAGE); ObjectDetectorOptions options = ObjectDetectorOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (objectDetectionResult, inputImage) -> { - assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE); + assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE); assertImageSizeIsExpected(inputImage); }) .setMaxResults(1) @@ -410,13 +463,13 @@ public class ObjectDetectorTest { try (ObjectDetector objectDetector = ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; i++) { - objectDetector.detectAsync(image, i); + objectDetector.detectAsync(image, /*timestampsMs=*/ i); } } } } - private static Image getImageFromAsset(String filePath) throws Exception { + private static MPImage getImageFromAsset(String filePath) throws Exception { AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); InputStream istr = assetManager.open(filePath); return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); @@ -448,7 +501,7 @@ public class ObjectDetectorTest { assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom); } - private static void assertImageSizeIsExpected(Image inputImage) { + private static void assertImageSizeIsExpected(MPImage inputImage) { assertThat(inputImage).isNotNull(); assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 8dd9fcd60..f24230a9e 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -27,6 +27,15 @@ py_library( ], ) +py_library( + name = "rect", + srcs = ["rect.py"], + deps = [ + "//mediapipe/framework/formats:rect_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + py_library( name = "category", srcs = ["category.py"], @@ -47,3 +56,13 @@ py_library( "//mediapipe/tasks/python/core:optional_dependencies", ], ) + +py_library( + name = "classifications", + srcs = ["classifications.py"], + deps = [ + ":category", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/containers/classifications.py b/mediapipe/tasks/python/components/containers/classifications.py new file mode 100644 index 000000000..90ab22614 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/classifications.py @@ -0,0 +1,168 @@ +# Copyright 2022 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classifications data class.""" + +import dataclasses +from typing import Any, List, Optional + +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassificationEntryProto = classifications_pb2.ClassificationEntry +_ClassificationsProto = classifications_pb2.Classifications +_ClassificationResultProto = classifications_pb2.ClassificationResult + + +@dataclasses.dataclass +class ClassificationEntry: + """List of predicted classes (aka labels) for a given classifier head. + + Attributes: + categories: The array of predicted categories, usually sorted by descending + scores (e.g. from high to low probability). + timestamp_ms: The optional timestamp (in milliseconds) associated to the + classification entry. This is useful for time series use cases, e.g., + audio classification. + """ + + categories: List[category_module.Category] + timestamp_ms: Optional[int] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationEntryProto: + """Generates a ClassificationEntry protobuf object.""" + return _ClassificationEntryProto( + categories=[category.to_pb2() for category in self.categories], + timestamp_ms=self.timestamp_ms) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry': + """Creates a `ClassificationEntry` object from the given protobuf object.""" + return ClassificationEntry( + categories=[ + category_module.Category.create_from_pb2(category) + for category in pb2_obj.categories + ], + timestamp_ms=pb2_obj.timestamp_ms) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassificationEntry): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class Classifications: + """Represents the classifications for a given classifier head. + + Attributes: + entries: A list of `ClassificationEntry` objects. + head_index: The index of the classifier head these categories refer to. This + is useful for multi-head models. + head_name: The name of the classifier head, which is the corresponding + tensor metadata name. 
+ """ + + entries: List[ClassificationEntry] + head_index: int + head_name: str + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationsProto: + """Generates a Classifications protobuf object.""" + return _ClassificationsProto( + entries=[entry.to_pb2() for entry in self.entries], + head_index=self.head_index, + head_name=self.head_name) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': + """Creates a `Classifications` object from the given protobuf object.""" + return Classifications( + entries=[ + ClassificationEntry.create_from_pb2(entry) + for entry in pb2_obj.entries + ], + head_index=pb2_obj.head_index, + head_name=pb2_obj.head_name) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, Classifications): + return False + + return self.to_pb2().__eq__(other.to_pb2()) + + +@dataclasses.dataclass +class ClassificationResult: + """Contains one set of results per classifier head. + + Attributes: + classifications: A list of `Classifications` objects. + """ + + classifications: List[Classifications] + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassificationResultProto: + """Generates a ClassificationResult protobuf object.""" + return _ClassificationResultProto(classifications=[ + classification.to_pb2() for classification in self.classifications + ]) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult': + """Creates a `ClassificationResult` object from the given protobuf object. + """ + return ClassificationResult(classifications=[ + Classifications.create_from_pb2(classification) + for classification in pb2_obj.classifications + ]) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassificationResult): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/containers/rect.py b/mediapipe/tasks/python/components/containers/rect.py new file mode 100644 index 000000000..510561592 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/rect.py @@ -0,0 +1,140 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rect data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.framework.formats import rect_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_RectProto = rect_pb2.Rect +_NormalizedRectProto = rect_pb2.NormalizedRect + + +@dataclasses.dataclass +class Rect: + """A rectangle with rotation in image coordinates. 
+
+  Attributes:
+    x_center: The X coordinate of the rectangle center, in pixels.
+    y_center: The Y coordinate of the rectangle center, in pixels.
+    width: The width of the rectangle, in pixels.
+    height: The height of the rectangle, in pixels.
+    rotation: The rotation angle, clockwise, in radians.
+    rect_id: Optional unique id to help associate different rectangles with
+      each other.
+  """
+
+  x_center: int
+  y_center: int
+  width: int
+  height: int
+  rotation: Optional[float] = 0.0
+  rect_id: Optional[int] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _RectProto:
+    """Generates a Rect protobuf object."""
+    return _RectProto(
+        x_center=self.x_center,
+        y_center=self.y_center,
+        width=self.width,
+        height=self.height,
+    )
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _RectProto) -> 'Rect':
+    """Creates a `Rect` object from the given protobuf object."""
+    return Rect(
+        x_center=pb2_obj.x_center,
+        y_center=pb2_obj.y_center,
+        width=pb2_obj.width,
+        height=pb2_obj.height)
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, Rect):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
+
+
+@dataclasses.dataclass
+class NormalizedRect:
+  """A rectangle with rotation in normalized coordinates.
+
+  The values of box center location and size are within [0, 1].
+
+  Attributes:
+    x_center: The normalized X coordinate of the rectangle center.
+    y_center: The normalized Y coordinate of the rectangle center.
+    width: The width of the rectangle.
+    height: The height of the rectangle.
+    rotation: The rotation angle, clockwise, in radians.
+    rect_id: Optional unique id to help associate different rectangles with
+      each other.
+  """
+
+  x_center: float
+  y_center: float
+  width: float
+  height: float
+  rotation: Optional[float] = 0.0
+  rect_id: Optional[int] = None
+
+  @doc_controls.do_not_generate_docs
+  def to_pb2(self) -> _NormalizedRectProto:
+    """Generates a NormalizedRect protobuf object."""
+    return _NormalizedRectProto(
+        x_center=self.x_center,
+        y_center=self.y_center,
+        width=self.width,
+        height=self.height,
+        rotation=self.rotation,
+        rect_id=self.rect_id)
+
+  @classmethod
+  @doc_controls.do_not_generate_docs
+  def create_from_pb2(cls, pb2_obj: _NormalizedRectProto) -> 'NormalizedRect':
+    """Creates a `NormalizedRect` object from the given protobuf object."""
+    return NormalizedRect(
+        x_center=pb2_obj.x_center,
+        y_center=pb2_obj.y_center,
+        width=pb2_obj.width,
+        height=pb2_obj.height,
+        rotation=pb2_obj.rotation,
+        rect_id=pb2_obj.rect_id)
+
+  def __eq__(self, other: Any) -> bool:
+    """Checks if this object is equal to the given object.
+
+    Args:
+      other: The object to be compared with.
+
+    Returns:
+      True if the objects are equal.
+    """
+    if not isinstance(other, NormalizedRect):
+      return False
+
+    return self.to_pb2().__eq__(other.to_pb2())
diff --git a/mediapipe/tasks/python/components/processors/BUILD b/mediapipe/tasks/python/components/processors/BUILD
new file mode 100644
index 000000000..f87a579b0
--- /dev/null
+++ b/mediapipe/tasks/python/components/processors/BUILD
@@ -0,0 +1,30 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library compatibility macro. + +# Placeholder for internal Python strict library and test compatibility macro. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +py_library( + name = "classifier_options", + srcs = ["classifier_options.py"], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) diff --git a/mediapipe/tasks/python/components/processors/classifier_options.py b/mediapipe/tasks/python/components/processors/classifier_options.py new file mode 100644 index 000000000..2e77f93b5 --- /dev/null +++ b/mediapipe/tasks/python/components/processors/classifier_options.py @@ -0,0 +1,86 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classifier options data class.""" + +import dataclasses +from typing import Any, List, Optional + +from mediapipe.tasks.cc.components.processors.proto import classifier_options_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_ClassifierOptionsProto = classifier_options_pb2.ClassifierOptions + + +@dataclasses.dataclass +class ClassifierOptions: + """Options for classification processor. + + Attributes: + display_names_locale: The locale to use for display names specified through + the TFLite Model Metadata. + max_results: The maximum number of top-scored classification results to + return. + score_threshold: Overrides the ones provided in the model metadata. Results + below this value are rejected. + category_allowlist: Allowlist of category names. If non-empty, detection + results whose category name is not in this set will be filtered out. + Duplicate or unknown category names are ignored. Mutually exclusive with + `category_denylist`. + category_denylist: Denylist of category names. If non-empty, detection + results whose category name is in this set will be filtered out. Duplicate + or unknown category names are ignored. Mutually exclusive with + `category_allowlist`. 
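A hedged illustration of how these options are typically constructed and lowered to their proto form (the values are arbitrary examples, not defaults):

    from mediapipe.tasks.python.components.processors import classifier_options

    options = classifier_options.ClassifierOptions(
        max_results=3,
        score_threshold=0.5,
        category_allowlist=['cheeseburger', 'guacamole'])
    # Only one of category_allowlist / category_denylist may be set.
    options_proto = options.to_pb2()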
+ """ + + display_names_locale: Optional[str] = None + max_results: Optional[int] = None + score_threshold: Optional[float] = None + category_allowlist: Optional[List[str]] = None + category_denylist: Optional[List[str]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ClassifierOptionsProto: + """Generates a ClassifierOptions protobuf object.""" + return _ClassifierOptionsProto( + score_threshold=self.score_threshold, + category_allowlist=self.category_allowlist, + category_denylist=self.category_denylist, + display_names_locale=self.display_names_locale, + max_results=self.max_results) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2(cls, + pb2_obj: _ClassifierOptionsProto) -> 'ClassifierOptions': + """Creates a `ClassifierOptions` object from the given protobuf object.""" + return ClassifierOptions( + score_threshold=pb2_obj.score_threshold, + category_allowlist=[str(name) for name in pb2_obj.category_allowlist], + category_denylist=[str(name) for name in pb2_obj.category_denylist], + display_names_locale=pb2_obj.display_names_locale, + max_results=pb2_obj.max_results) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, ClassifierOptions): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index d5cdeecda..76e2f4f4a 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/metadata/metadata_writers/BUILD index 3e44218ac..d2b55d47d 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/metadata/metadata_writers/BUILD @@ -1,4 +1,4 @@ -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package( default_visibility = [ @@ -37,3 +37,9 @@ py_library( srcs = ["writer_utils.py"], deps = ["//mediapipe/tasks/metadata:schema_py"], ) + +py_library( + name = "image_classifier", + srcs = ["image_classifier.py"], + deps = [":metadata_writer"], +) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py b/mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py new file mode 100644 index 000000000..c516a342d --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/image_classifier.py @@ -0,0 +1,71 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Writes metadata and label file to the image classifier models.""" + +from typing import List, Optional + +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer + +_MODEL_NAME = "ImageClassifier" +_MODEL_DESCRIPTION = ("Identify the most prominent object in the image from a " + "known set of categories.") + + +class MetadataWriter(metadata_writer.MetadataWriterBase): + """MetadataWriter to write the metadata for image classifier.""" + + @classmethod + def create( + cls, + model_buffer: bytearray, + input_norm_mean: List[float], + input_norm_std: List[float], + labels: metadata_writer.Labels, + score_calibration: Optional[metadata_writer.ScoreCalibration] = None + ) -> "MetadataWriter": + """Creates MetadataWriter to write the metadata for image classifier. + + The parameters required in this method are mandatory when using MediaPipe + Tasks. + + Note that only the output TFLite is used for deployment. The output JSON + content is used to interpret the metadata content. + + Args: + model_buffer: A valid flatbuffer loaded from the TFLite model file. + input_norm_mean: the mean value used in the input tensor normalization + [1]. + input_norm_std: the std value used in the input tensor normalizarion [1]. + labels: an instance of Labels helper class used in the output + classification tensor [2]. + score_calibration: A container of the score calibration operation [3] in + the classification tensor. Optional if the model does not use score + calibration. + + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99 + [3]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L456 + + Returns: + An MetadataWrite object. + """ + writer = metadata_writer.MetadataWriter(model_buffer) + writer.add_genernal_info(_MODEL_NAME, _MODEL_DESCRIPTION) + writer.add_image_input(input_norm_mean, input_norm_std) + writer.add_classification_output(labels, score_calibration) + return cls(writer) diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py index e69efd015..5a2eaba07 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py @@ -15,19 +15,22 @@ """Generic metadata writer.""" import collections +import csv import dataclasses import os import tempfile from typing import List, Optional, Tuple import flatbuffers -from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb -from mediapipe.tasks.python.metadata import metadata as _metadata +from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb +from mediapipe.tasks.python.metadata import metadata from mediapipe.tasks.python.metadata.metadata_writers import metadata_info from mediapipe.tasks.python.metadata.metadata_writers import writer_utils _INPUT_IMAGE_NAME = 'image' _INPUT_IMAGE_DESCRIPTION = 'Input image to be processed.' 
+_OUTPUT_CLASSIFICATION_NAME = 'score' +_OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.' @dataclasses.dataclass @@ -140,26 +143,85 @@ class Labels(object): class ScoreCalibration: """Simple container holding score calibration related parameters.""" - # A shortcut to avoid client side code importing _metadata_fb - transformation_types = _metadata_fb.ScoreTransformationType + # A shortcut to avoid client side code importing metadata_fb + transformation_types = metadata_fb.ScoreTransformationType def __init__(self, - transformation_type: _metadata_fb.ScoreTransformationType, - parameters: List[CalibrationParameter], + transformation_type: metadata_fb.ScoreTransformationType, + parameters: List[Optional[CalibrationParameter]], default_score: int = 0): self.transformation_type = transformation_type self.parameters = parameters self.default_score = default_score + @classmethod + def create_from_file(cls, + transformation_type: metadata_fb.ScoreTransformationType, + file_path: str, + default_score: int = 0) -> 'ScoreCalibration': + """Creates ScoreCalibration from the file. + + Args: + transformation_type: type of the function used for transforming the + uncalibrated score before applying score calibration. + file_path: file_path of the score calibration file [1]. Contains + sigmoid-based score calibration parameters, formatted as CSV. Lines + contain for each index of an output tensor the scale, slope, offset and + (optional) min_score parameters to be used for sigmoid fitting (in this + order and in `strtof`-compatible [2] format). Scale should be a + non-negative value. A line may be left empty to default calibrated + scores for this index to default_score. In summary, each line should + thus contain 0, 3 or 4 comma-separated values. + default_score: the default calibrated score to apply if the uncalibrated + score is below min_score or if no parameters were specified for a given + index. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L133 + [2]: + https://en.cppreference.com/w/c/string/byte/strtof + + Returns: + A ScoreCalibration object. + Raises: + ValueError: if the score_calibration file is malformed. + """ + with open(file_path, 'r') as calibration_file: + csv_reader = csv.reader(calibration_file, delimiter=',') + parameters = [] + for row in csv_reader: + if not row: + parameters.append(None) + continue + + if len(row) != 3 and len(row) != 4: + raise ValueError( + f'Expected empty lines or 3 or 4 parameters per line in score' + f' calibration file, but got {len(row)}.') + + if float(row[0]) < 0: + raise ValueError( + f'Expected scale to be a non-negative value, but got ' + f'{float(row[0])}.') + + parameters.append( + CalibrationParameter( + scale=float(row[0]), + slope=float(row[1]), + offset=float(row[2]), + min_score=None if len(row) == 3 else float(row[3]))) + + return cls(transformation_type, parameters, default_score) + def _fill_default_tensor_names( - tensor_metadata: List[_metadata_fb.TensorMetadataT], + tensor_metadata_list: List[metadata_fb.TensorMetadataT], tensor_names_from_model: List[str]): """Fills the default tensor names.""" # If tensor name in metadata is empty, default to the tensor name saved in # the model. 
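Returning to `ScoreCalibration.create_from_file` above: to make the expected file layout concrete, a hedged sketch with hypothetical values, where each non-empty line is `scale,slope,offset[,min_score]` and an empty line falls back to `default_score`:

    # score_calibration.csv (hypothetical contents):
    #   0.9876,0.3662,0.5353,0.7148
    #   <empty line>
    #   0.9902,0.8562,0.8784
    # Within this module (or via metadata_writer.ScoreCalibration):
    calibration = ScoreCalibration.create_from_file(
        ScoreCalibration.transformation_types.LOG,
        'score_calibration.csv',
        default_score=0.2)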
- for metadata, name in zip(tensor_metadata, tensor_names_from_model): - metadata.name = metadata.name or name + for tensor_metadata, name in zip(tensor_metadata_list, + tensor_names_from_model): + tensor_metadata.name = tensor_metadata.name or name def _pair_tensor_metadata( @@ -212,7 +274,7 @@ def _create_metadata_buffer( input_metadata = [m.create_metadata() for m in input_md] else: num_input_tensors = writer_utils.get_subgraph(model_buffer).InputsLength() - input_metadata = [_metadata_fb.TensorMetadataT()] * num_input_tensors + input_metadata = [metadata_fb.TensorMetadataT()] * num_input_tensors _fill_default_tensor_names(input_metadata, writer_utils.get_input_tensor_names(model_buffer)) @@ -224,12 +286,12 @@ def _create_metadata_buffer( output_metadata = [m.create_metadata() for m in output_md] else: num_output_tensors = writer_utils.get_subgraph(model_buffer).OutputsLength() - output_metadata = [_metadata_fb.TensorMetadataT()] * num_output_tensors + output_metadata = [metadata_fb.TensorMetadataT()] * num_output_tensors _fill_default_tensor_names(output_metadata, writer_utils.get_output_tensor_names(model_buffer)) # Create the subgraph metadata. - subgraph_metadata = _metadata_fb.SubGraphMetadataT() + subgraph_metadata = metadata_fb.SubGraphMetadataT() subgraph_metadata.inputTensorMetadata = input_metadata subgraph_metadata.outputTensorMetadata = output_metadata @@ -243,7 +305,7 @@ def _create_metadata_buffer( b = flatbuffers.Builder(0) b.Finish( model_metadata.Pack(b), - _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) return b.Output() @@ -291,7 +353,7 @@ class MetadataWriter(object): name=model_name, description=model_description) return self - color_space_types = _metadata_fb.ColorSpaceType + color_space_types = metadata_fb.ColorSpaceType def add_feature_input(self, name: Optional[str] = None, @@ -305,7 +367,7 @@ class MetadataWriter(object): self, norm_mean: List[float], norm_std: List[float], - color_space_type: Optional[int] = _metadata_fb.ColorSpaceType.RGB, + color_space_type: Optional[int] = metadata_fb.ColorSpaceType.RGB, name: str = _INPUT_IMAGE_NAME, description: str = _INPUT_IMAGE_DESCRIPTION) -> 'MetadataWriter': """Adds an input image metadata for the image input. @@ -341,9 +403,6 @@ class MetadataWriter(object): self._input_mds.append(input_md) return self - _OUTPUT_CLASSIFICATION_NAME = 'score' - _OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively' - def add_classification_output( self, labels: Optional[Labels] = None, @@ -416,8 +475,7 @@ class MetadataWriter(object): A tuple of (model_with_metadata_in_bytes, metdata_json_content) """ # Populates metadata and associated files into TFLite model buffer. 
- populator = _metadata.MetadataPopulator.with_model_buffer( - self._model_buffer) + populator = metadata.MetadataPopulator.with_model_buffer(self._model_buffer) metadata_buffer = _create_metadata_buffer( model_buffer=self._model_buffer, general_md=self._general_md, @@ -429,7 +487,7 @@ class MetadataWriter(object): populator.populate() tflite_content = populator.get_model_buffer() - displayer = _metadata.MetadataDisplayer.with_model_buffer(tflite_content) + displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content) metadata_json_content = displayer.get_metadata_json() return tflite_content, metadata_json_content @@ -452,9 +510,7 @@ class MetadataWriter(object): """Stores calibration parameters in a csv file.""" filepath = os.path.join(self._temp_folder.name, filename) with open(filepath, 'w') as f: - for idx, item in enumerate(calibrations): - if idx != 0: - f.write('\n') + for item in calibrations: if item: if item.scale is None or item.slope is None or item.offset is None: raise ValueError('scale, slope and offset values can not be set to ' @@ -463,6 +519,30 @@ class MetadataWriter(object): f.write(f'{item.scale},{item.slope},{item.offset},{item.min_score}') else: f.write(f'{item.scale},{item.slope},{item.offset}') + f.write('\n') - self._associated_files.append(filepath) + self._associated_files.append(filepath) return filepath + + +class MetadataWriterBase: + """Base MetadataWriter class which contains the apis exposed to users. + + MetadataWriter for Tasks e.g. image classifier / object detector will inherit + this class for their own usage. + """ + + def __init__(self, writer: MetadataWriter) -> None: + self.writer = writer + + def populate(self) -> Tuple[bytearray, str]: + """Populates metadata into the TFLite file. + + Note that only the output tflite is used for deployment. The output JSON + content is used to interpret the metadata content. + + Returns: + A tuple of (model_with_metadata_in_bytes, metdata_json_content) + """ + return self.writer.populate() + diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index 8e5b91cf9..92c5f4038 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. 
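For orientation, a hedged sketch of how the metadata writer pieces above fit together for an image classifier model; the file names and normalization values are illustrative and mirror the tests added below:

    from mediapipe.tasks.python.metadata.metadata_writers import image_classifier
    from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

    with open('mobilenet_v2_1.0_224_without_metadata.tflite', 'rb') as f:
      model_buffer = f.read()
    writer = image_classifier.MetadataWriter.create(
        model_buffer, [127.5], [127.5],
        labels=metadata_writer.Labels().add_from_file('labels.txt'))
    # Returns the model with metadata embedded plus the JSON rendering of it.
    tflite_with_metadata, metadata_json = writer.populate()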
package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD index 8779d2fb6..a7bfd297d 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD @@ -28,9 +28,28 @@ py_test( py_test( name = "metadata_writer_test", srcs = ["metadata_writer_test.py"], - data = ["//mediapipe/tasks/testdata/metadata:model_files"], + data = [ + "//mediapipe/tasks/testdata/metadata:data_files", + "//mediapipe/tasks/testdata/metadata:model_files", + ], deps = [ "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", "//mediapipe/tasks/python/test:test_utils", ], ) + +py_test( + name = "image_classifier_test", + srcs = ["image_classifier_test.py"], + data = [ + "//mediapipe/tasks/testdata/metadata:data_files", + "//mediapipe/tasks/testdata/metadata:model_files", + ], + deps = [ + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/python/metadata", + "//mediapipe/tasks/python/metadata/metadata_writers:image_classifier", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", + "//mediapipe/tasks/python/test:test_utils", + ], +) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/image_classifier_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/image_classifier_test.py new file mode 100644 index 000000000..4bbd91667 --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/image_classifier_test.py @@ -0,0 +1,76 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for metadata_writer.image_classifier.""" + +from absl.testing import absltest +from absl.testing import parameterized + +from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb +from mediapipe.tasks.python.metadata import metadata +from mediapipe.tasks.python.metadata.metadata_writers import image_classifier +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer +from mediapipe.tasks.python.test import test_utils + +_FLOAT_MODEL = test_utils.get_test_data_path( + "mobilenet_v2_1.0_224_without_metadata.tflite") +_QUANT_MODEL = test_utils.get_test_data_path( + "mobilenet_v2_1.0_224_quant_without_metadata.tflite") +_LABEL_FILE = test_utils.get_test_data_path("labels.txt") +_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path("score_calibration.txt") +_SCORE_CALIBRATION_FILENAME = "score_calibration.txt" +_DEFAULT_SCORE_CALIBRATION_VALUE = 0.2 +_NORM_MEAN = 127.5 +_NORM_STD = 127.5 +_FLOAT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224.json") +_QUANT_JSON = test_utils.get_test_data_path("mobilenet_v2_1.0_224_quant.json") + + +class ImageClassifierTest(parameterized.TestCase): + + @parameterized.named_parameters( + { + "testcase_name": "test_float_model", + "model_file": _FLOAT_MODEL, + "golden_json": _FLOAT_JSON + }, { + "testcase_name": "test_quant_model", + "model_file": _QUANT_MODEL, + "golden_json": _QUANT_JSON + }) + def test_write_metadata(self, model_file: str, golden_json: str): + with open(model_file, "rb") as f: + model_buffer = f.read() + writer = image_classifier.MetadataWriter.create( + model_buffer, [_NORM_MEAN], [_NORM_STD], + labels=metadata_writer.Labels().add_from_file(_LABEL_FILE), + score_calibration=metadata_writer.ScoreCalibration.create_from_file( + metadata_fb.ScoreTransformationType.LOG, _SCORE_CALIBRATION_FILE, + _DEFAULT_SCORE_CALIBRATION_VALUE)) + tflite_content, metadata_json = writer.populate() + + with open(golden_json, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content) + file_buffer = displayer.get_associated_file_buffer( + _SCORE_CALIBRATION_FILENAME) + with open(_SCORE_CALIBRATION_FILE, "rb") as f: + expected_file_buffer = f.read() + self.assertEqual(file_buffer, expected_file_buffer) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py index c39b4a555..51b043c7d 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_writer_test.py @@ -13,6 +13,9 @@ # limitations under the License. 
# ============================================================================== """Tests for metadata writer classes.""" +import os +import tempfile + from absl.testing import absltest from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer @@ -20,6 +23,7 @@ from mediapipe.tasks.python.test import test_utils _IMAGE_CLASSIFIER_MODEL = test_utils.get_test_data_path( 'mobilenet_v1_0.25_224_1_default_1.tflite') +_SCORE_CALIBRATION_FILE = test_utils.get_test_data_path('score_calibration.txt') class LabelsTest(absltest.TestCase): @@ -49,6 +53,54 @@ class LabelsTest(absltest.TestCase): ]) +class ScoreCalibrationTest(absltest.TestCase): + + def test_create_from_file_successful(self): + score_calibration = metadata_writer.ScoreCalibration.create_from_file( + metadata_writer.ScoreCalibration.transformation_types.LOG, + _SCORE_CALIBRATION_FILE) + self.assertLen(score_calibration.parameters, 511) + self.assertIsNone(score_calibration.parameters[0]) + self.assertEqual( + score_calibration.parameters[1], + metadata_writer.CalibrationParameter( + scale=0.9876328110694885, + slope=0.36622241139411926, + offset=0.5352765321731567, + min_score=0.71484375)) + self.assertEqual( + score_calibration.parameters[510], + metadata_writer.CalibrationParameter( + scale=0.9901729226112366, + slope=0.8561913371086121, + offset=0.8783953189849854, + min_score=0.5859375)) + + def test_create_from_file_fail(self): + with tempfile.TemporaryDirectory() as temp_dir: + test_file = os.path.join(temp_dir, 'score_calibration.csv') + with open(test_file, 'w') as f: + f.write('0.98,0.5\n') + + with self.assertRaisesRegex( + ValueError, + 'Expected empty lines or 3 or 4 parameters per line in score ' + 'calibration file, but got 2.' + ): + metadata_writer.ScoreCalibration.create_from_file( + metadata_writer.ScoreCalibration.transformation_types.LOG, + test_file) + + with open(test_file, 'w') as f: + f.write('-0.98,0.5,0.34\n') + with self.assertRaisesRegex( + ValueError, + 'Expected scale to be a non-negative value, but got -0.98.'): + metadata_writer.ScoreCalibration.create_from_file( + metadata_writer.ScoreCalibration.transformation_types.LOG, + test_file) + + class MetadataWriterForTaskTest(absltest.TestCase): def setUp(self): @@ -197,7 +249,7 @@ class MetadataWriterForTaskTest(absltest.TestCase): "output_tensor_metadata": [ { "name": "score", - "description": "Score of the labels respectively", + "description": "Score of the labels respectively.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { @@ -298,7 +350,7 @@ class MetadataWriterForTaskTest(absltest.TestCase): "output_tensor_metadata": [ { "name": "score", - "description": "Score of the labels respectively", + "description": "Score of the labels respectively.", "content": { "content_properties_type": "FeatureProperties", "content_properties": { diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 321b33a61..970d4dd8b 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -37,6 +37,26 @@ py_test( ], ) +py_test( + name = "image_classifier_test", + srcs = ["image_classifier_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:classifications", + 
"//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:image_classifier", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) + py_test( name = "image_segmenter_test", srcs = ["image_segmenter_test.py"], diff --git a/mediapipe/tasks/python/test/vision/image_classifier_test.py b/mediapipe/tasks/python/test/vision/image_classifier_test.py new file mode 100644 index 000000000..afaf921a7 --- /dev/null +++ b/mediapipe/tasks/python/test/vision/image_classifier_test.py @@ -0,0 +1,515 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for image classifier.""" + +import enum +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +from mediapipe.python._framework_bindings import image +from mediapipe.tasks.python.components.containers import category +from mediapipe.tasks.python.components.containers import classifications as classifications_module +from mediapipe.tasks.python.components.containers import rect +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import image_classifier +from mediapipe.tasks.python.vision.core import vision_task_running_mode + +_NormalizedRect = rect.NormalizedRect +_BaseOptions = base_options_module.BaseOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_Category = category.Category +_ClassificationEntry = classifications_module.ClassificationEntry +_Classifications = classifications_module.Classifications +_ClassificationResult = classifications_module.ClassificationResult +_Image = image.Image +_ImageClassifier = image_classifier.ImageClassifier +_ImageClassifierOptions = image_classifier.ImageClassifierOptions +_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode + +_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite' +_IMAGE_FILE = 'burger.jpg' +_ALLOW_LIST = ['cheeseburger', 'guacamole'] +_DENY_LIST = ['cheeseburger'] +_SCORE_THRESHOLD = 0.5 +_MAX_RESULTS = 3 + + +# TODO: Port assertProtoEquals +def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument + pass + + +def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult(classifications=[ + _Classifications( + entries=[ + _ClassificationEntry(categories=[], timestamp_ms=timestamp_ms) + ], + head_index=0, + head_name='probability') + ]) + + +def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult(classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=934, + score=0.7939587831497192, + 
display_name='', + category_name='cheeseburger'), + _Category( + index=932, + score=0.02739289402961731, + display_name='', + category_name='bagel'), + _Category( + index=925, + score=0.01934075355529785, + display_name='', + category_name='guacamole'), + _Category( + index=963, + score=0.006327860057353973, + display_name='', + category_name='meat loaf') + ], + timestamp_ms=timestamp_ms) + ], + head_index=0, + head_name='probability') + ]) + + +def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult: + return _ClassificationResult(classifications=[ + _Classifications( + entries=[ + _ClassificationEntry( + categories=[ + _Category( + index=806, + score=0.9965274930000305, + display_name='', + category_name='soccer ball') + ], + timestamp_ms=timestamp_ms) + ], + head_index=0, + head_name='probability') + ]) + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class ImageClassifierTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path(_IMAGE_FILE)) + self.model_path = test_utils.get_test_data_path(_MODEL_FILE) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _ImageClassifier.create_from_model_path(self.model_path) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _ImageClassifierOptions(base_options=base_options) + with _ImageClassifier.create_from_options(options) as classifier: + self.assertIsInstance(classifier, _ImageClassifier) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + ValueError, + r"ExternalFile must specify at least one of 'file_content', " + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): + base_options = _BaseOptions(model_asset_path='') + options = _ImageClassifierOptions(base_options=base_options) + _ImageClassifier.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _ImageClassifierOptions(base_options=base_options) + classifier = _ImageClassifier.create_from_options(options) + self.assertIsInstance(classifier, _ImageClassifier) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) + def test_classify(self, model_file_type, max_results, + expected_classification_result): + # Creates classifier. 
+ if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + custom_classifier_options = _ClassifierOptions(max_results=max_results) + options = _ImageClassifierOptions( + base_options=base_options, classifier_options=custom_classifier_options) + classifier = _ImageClassifier.create_from_options(options) + + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + # Comparing results. + _assert_proto_equals(image_result.to_pb2(), + expected_classification_result.to_pb2()) + # Closes the classifier explicitly when the classifier is not used in + # a context. + classifier.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)), + (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0))) + def test_classify_in_context(self, model_file_type, max_results, + expected_classification_result): + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + custom_classifier_options = _ClassifierOptions(max_results=max_results) + options = _ImageClassifierOptions( + base_options=base_options, classifier_options=custom_classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. + image_result = classifier.classify(self.test_image) + # Comparing results. + _assert_proto_equals(image_result.to_pb2(), + expected_classification_result.to_pb2()) + + def test_classify_succeeds_with_region_of_interest(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + custom_classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=base_options, classifier_options=custom_classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect( + x_center=0.532, y_center=0.521, width=0.164, height=0.427) + # Performs image classification on the input. + image_result = classifier.classify(test_image, roi) + # Comparing results. + _assert_proto_equals(image_result.to_pb2(), + _generate_soccer_ball_results(0).to_pb2()) + + def test_score_threshold_option(self): + custom_classifier_options = _ClassifierOptions( + score_threshold=_SCORE_THRESHOLD) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + classifier_options=custom_classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Performs image classification on the input. 
+      image_result = classifier.classify(self.test_image)
+      classifications = image_result.classifications
+
+      for classification in classifications:
+        for entry in classification.entries:
+          score = entry.categories[0].score
+          self.assertGreaterEqual(
+              score, _SCORE_THRESHOLD,
+              f'Classification with score lower than threshold found. '
+              f'{classification}')
+
+  def test_max_results_option(self):
+    custom_classifier_options = _ClassifierOptions(
+        max_results=_MAX_RESULTS)
+    options = _ImageClassifierOptions(
+        base_options=_BaseOptions(model_asset_path=self.model_path),
+        classifier_options=custom_classifier_options)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Performs image classification on the input.
+      image_result = classifier.classify(self.test_image)
+      categories = image_result.classifications[0].entries[0].categories
+
+      self.assertLessEqual(
+          len(categories), _MAX_RESULTS, 'Too many results returned.')
+
+  def test_allow_list_option(self):
+    custom_classifier_options = _ClassifierOptions(
+        category_allowlist=_ALLOW_LIST)
+    options = _ImageClassifierOptions(
+        base_options=_BaseOptions(model_asset_path=self.model_path),
+        classifier_options=custom_classifier_options)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Performs image classification on the input.
+      image_result = classifier.classify(self.test_image)
+      classifications = image_result.classifications
+
+      for classification in classifications:
+        for entry in classification.entries:
+          label = entry.categories[0].category_name
+          self.assertIn(label, _ALLOW_LIST,
+                        f'Label {label} found but not in label allow list')
+
+  def test_deny_list_option(self):
+    custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
+    options = _ImageClassifierOptions(
+        base_options=_BaseOptions(model_asset_path=self.model_path),
+        classifier_options=custom_classifier_options)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Performs image classification on the input.
+      image_result = classifier.classify(self.test_image)
+      classifications = image_result.classifications
+
+      for classification in classifications:
+        for entry in classification.entries:
+          label = entry.categories[0].category_name
+          self.assertNotIn(label, _DENY_LIST,
+                           f'Label {label} found but in deny list.')
+
+  def test_combined_allowlist_and_denylist(self):
+    # Fails with combined allowlist and denylist
+    with self.assertRaisesRegex(
+        ValueError,
+        r'`category_allowlist` and `category_denylist` are mutually '
+        r'exclusive options.'):
+      custom_classifier_options = _ClassifierOptions(
+          category_allowlist=['foo'], category_denylist=['bar'])
+      options = _ImageClassifierOptions(
+          base_options=_BaseOptions(model_asset_path=self.model_path),
+          classifier_options=custom_classifier_options)
+      with _ImageClassifier.create_from_options(options) as unused_classifier:
+        pass
+
+  def test_empty_classification_outputs(self):
+    custom_classifier_options = _ClassifierOptions(score_threshold=1)
+    options = _ImageClassifierOptions(
+        base_options=_BaseOptions(model_asset_path=self.model_path),
+        classifier_options=custom_classifier_options)
+    with _ImageClassifier.create_from_options(options) as classifier:
+      # Performs image classification on the input.
+ image_result = classifier.classify(self.test_image) + self.assertEmpty(image_result.classifications[0].entries[0].categories) + + def test_missing_result_callback(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM) + with self.assertRaisesRegex(ValueError, + r'result callback must be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock()) + with self.assertRaisesRegex(ValueError, + r'result callback should not be provided'): + with _ImageClassifier.create_from_options(options) as unused_classifier: + pass + + def test_calling_classify_for_video_in_image_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + classifier.classify_for_video(self.test_image, 0) + + def test_calling_classify_async_in_image_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + classifier.classify_async(self.test_image, 0) + + def test_calling_classify_in_video_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + classifier.classify(self.test_image) + + def test_calling_classify_async_in_video_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the live stream mode'): + classifier.classify_async(self.test_image, 0) + + def test_classify_for_video_with_out_of_order_timestamp(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO) + with _ImageClassifier.create_from_options(options) as classifier: + unused_result = classifier.classify_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + classifier.classify_for_video(self.test_image, 0) + + def test_classify_for_video(self): + custom_classifier_options = _ClassifierOptions(max_results=4) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + classifier_options=custom_classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classification_result = classifier.classify_for_video( + self.test_image, timestamp) + _assert_proto_equals(classification_result.to_pb2(), + 
_generate_burger_results(timestamp).to_pb2()) + + def test_classify_for_video_succeeds_with_region_of_interest(self): + custom_classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + classifier_options=custom_classifier_options) + with _ImageClassifier.create_from_options(options) as classifier: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect( + x_center=0.532, y_center=0.521, width=0.164, height=0.427) + for timestamp in range(0, 300, 30): + classification_result = classifier.classify_for_video( + test_image, timestamp, roi) + self.assertEqual(classification_result, + _generate_soccer_ball_results(timestamp)) + + def test_calling_classify_in_live_stream_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the image mode'): + classifier.classify(self.test_image) + + def test_calling_classify_for_video_in_live_stream_mode(self): + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + with self.assertRaisesRegex(ValueError, + r'not initialized with the video mode'): + classifier.classify_for_video(self.test_image, 0) + + def test_classify_async_calls_with_illegal_timestamp(self): + custom_classifier_options = _ClassifierOptions(max_results=4) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=custom_classifier_options, + result_callback=mock.MagicMock()) + with _ImageClassifier.create_from_options(options) as classifier: + classifier.classify_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing'): + classifier.classify_async(self.test_image, 0) + + @parameterized.parameters((0, _generate_burger_results), + (1, _generate_empty_results)) + def test_classify_async_calls(self, threshold, expected_result_fn): + observed_timestamp_ms = -1 + + def check_result(result: _ClassificationResult, output_image: _Image, + timestamp_ms: int): + _assert_proto_equals(result.to_pb2(), + expected_result_fn(timestamp_ms).to_pb2()) + self.assertTrue( + np.array_equal(output_image.numpy_view(), + self.test_image.numpy_view())) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + custom_classifier_options = _ClassifierOptions( + max_results=4, score_threshold=threshold) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=custom_classifier_options, + result_callback=check_result) + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classifier.classify_async(self.test_image, timestamp) + + def test_classify_async_succeeds_with_region_of_interest(self): + # Load the test image. 
+ test_image = _Image.create_from_file( + test_utils.get_test_data_path('multi_objects.jpg')) + # NormalizedRect around the soccer ball. + roi = _NormalizedRect( + x_center=0.532, y_center=0.521, width=0.164, height=0.427) + observed_timestamp_ms = -1 + + def check_result(result: _ClassificationResult, output_image: _Image, + timestamp_ms: int): + _assert_proto_equals(result.to_pb2(), + _generate_soccer_ball_results(timestamp_ms).to_pb2()) + self.assertEqual(output_image.width, test_image.width) + self.assertEqual(output_image.height, test_image.height) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + custom_classifier_options = _ClassifierOptions(max_results=1) + options = _ImageClassifierOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + classifier_options=custom_classifier_options, + result_callback=check_result) + with _ImageClassifier.create_from_options(options) as classifier: + for timestamp in range(0, 300, 30): + classifier.classify_async(test_image, timestamp, roi) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index da9072f18..00fc3268f 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -37,6 +37,28 @@ py_library( ], ) +py_library( + name = "image_classifier", + srcs = [ + "image_classifier.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:classifications", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/components/processors:classifier_options", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) + py_library( name = "image_segmenter", srcs = [ diff --git a/mediapipe/tasks/python/vision/core/BUILD b/mediapipe/tasks/python/vision/core/BUILD index c7422969a..df1b06f4c 100644 --- a/mediapipe/tasks/python/vision/core/BUILD +++ b/mediapipe/tasks/python/vision/core/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. package(default_visibility = ["//mediapipe/tasks:internal"]) diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py new file mode 100644 index 000000000..e41cc77a2 --- /dev/null +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -0,0 +1,294 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe image classifier task.""" + +import dataclasses +from typing import Callable, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +# TODO: Import MPImage directly once we have an alias +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet +from mediapipe.python._framework_bindings import task_runner +from mediapipe.tasks.cc.components.containers.proto import classifications_pb2 +from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2 +from mediapipe.tasks.python.components.containers import classifications +from mediapipe.tasks.python.components.containers import rect +from mediapipe.tasks.python.components.processors import classifier_options +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import vision_task_running_mode + +_NormalizedRect = rect.NormalizedRect +_BaseOptions = base_options_module.BaseOptions +_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions +_ClassifierOptions = classifier_options.ClassifierOptions +_RunningMode = vision_task_running_mode.VisionTaskRunningMode +_TaskInfo = task_info_module.TaskInfo +_TaskRunner = task_runner.TaskRunner + +_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out' +_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_NORM_RECT_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +def _build_full_image_norm_rect() -> _NormalizedRect: + # Builds a NormalizedRect covering the entire image. + return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1) + + +@dataclasses.dataclass +class ImageClassifierOptions: + """Options for the image classifier task. + + Attributes: + base_options: Base options for the image classifier task. + running_mode: The running mode of the task. Defaults to the image mode. The + image classifier task has three running modes: 1) The image mode for + classifying objects on single image inputs. 2) The video mode for classifying + objects on the decoded frames of a video. 3) The live stream mode for + classifying objects on a live stream of input data, such as from a camera. + classifier_options: Options for the image classification task. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode.
+ """ + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + classifier_options: _ClassifierOptions = _ClassifierOptions() + result_callback: Optional[ + Callable[[classifications.ClassificationResult, image_module.Image, int], + None]] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _ImageClassifierGraphOptionsProto: + """Generates an ImageClassifierOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True + classifier_options_proto = self.classifier_options.to_pb2() + + return _ImageClassifierGraphOptionsProto( + base_options=base_options_proto, + classifier_options=classifier_options_proto) + + +class ImageClassifier(base_vision_task_api.BaseVisionTaskApi): + """Class that performs image classification on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'ImageClassifier': + """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`. + + Note that the created `ImageClassifier` instance is in image mode, for + classifying objects on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `ImageClassifier` object that's created from the model file and the + default `ImageClassifierOptions`. + + Raises: + ValueError: If failed to create `ImageClassifier` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = ImageClassifierOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, + options: ImageClassifierOptions) -> 'ImageClassifier': + """Creates the `ImageClassifier` object from image classifier options. + + Args: + options: Options for the image classifier task. + + Returns: + `ImageClassifier` object that's created from `options`. + + Raises: + ValueError: If failed to create `ImageClassifier` object from + `ImageClassifierOptions` such as missing the model. + RuntimeError: If other types of error occurred. 
+ """ + + def packets_callback(output_packets: Mapping[str, packet.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + + classification_result = classifications.ClassificationResult([ + classifications.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp + options.result_callback(classification_result, image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]), + ], + output_streams=[ + ':'.join([ + _CLASSIFICATION_RESULT_TAG, + _CLASSIFICATION_RESULT_OUT_STREAM_NAME + ]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]) + ], + task_options=options) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode == + _RunningMode.LIVE_STREAM), options.running_mode, + packets_callback if options.result_callback else None) + + # TODO: Replace _NormalizedRect with ImageProcessingOption + def classify( + self, + image: image_module.Image, + roi: Optional[_NormalizedRect] = None + ) -> classifications.ClassificationResult: + """Performs image classification on the provided MediaPipe Image. + + Args: + image: MediaPipe Image. + roi: The region of interest. + + Returns: + A classification result object that contains a list of classifications. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If image classification failed to run. + """ + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2()) + }) + + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + + return classifications.ClassificationResult([ + classifications.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + + def classify_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None + ) -> classifications.ClassificationResult: + """Performs image classification on the provided video frames. + + Only use this method when the ImageClassifier is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + roi: The region of interest. + + Returns: + A classification result object that contains a list of classifications. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If image classification failed to run. 
+ """ + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_NAME: + packet_creator.create_proto(norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) + + classification_result_proto = classifications_pb2.ClassificationResult() + classification_result_proto.CopyFrom( + packet_getter.get_proto( + output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME])) + + return classifications.ClassificationResult([ + classifications.Classifications.create_from_pb2(classification) + for classification in classification_result_proto.classifications + ]) + + def classify_async(self, + image: image_module.Image, + timestamp_ms: int, + roi: Optional[_NormalizedRect] = None) -> None: + """Sends live image data (an Image with a unique timestamp) to perform image classification. + + Only use this method when the ImageClassifier is created with the live + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input image is accepted. The results will be available via the + `result_callback` provided in the `ImageClassifierOptions`. The + `classify_async` method is designed to process live stream data such as + camera input. To lower the overall latency, image classifier may drop the + input images if needed. In other words, it's not guaranteed to have output + per input image. + + The `result_callback` provides: + - A classification result object that contains a list of classifications. + - The input image that the image classifier runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + roi: The region of interest. + + Raises: + ValueError: If the current input timestamp is smaller than what the image + classifier has already processed. 
+ """ + norm_rect = roi if roi is not None else _build_full_image_norm_rect() + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: + packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + _NORM_RECT_NAME: + packet_creator.create_proto(norm_rect.to_pb2()).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND) + }) diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 9f50368b8..6d7bbab6a 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -29,6 +29,8 @@ mediapipe_files(srcs = [ "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v2_1.0_224_quant.tflite", + "mobilenet_v2_1.0_224_quant_without_metadata.tflite", + "mobilenet_v2_1.0_224_without_metadata.tflite", ]) exports_files([ @@ -48,6 +50,9 @@ exports_files([ "score_calibration.txt", "score_calibration_file_meta.json", "score_calibration_tensor_meta.json", + "labels.txt", + "mobilenet_v2_1.0_224.json", + "mobilenet_v2_1.0_224_quant.json", ]) filegroup( @@ -59,6 +64,8 @@ filegroup( "mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v2_1.0_224_quant.tflite", + "mobilenet_v2_1.0_224_quant_without_metadata.tflite", + "mobilenet_v2_1.0_224_without_metadata.tflite", ], ) @@ -78,6 +85,9 @@ filegroup( "input_image_tensor_float_meta.json", "input_image_tensor_uint8_meta.json", "input_image_tensor_unsupported_meta.json", + "labels.txt", + "mobilenet_v2_1.0_224.json", + "mobilenet_v2_1.0_224_quant.json", "score_calibration.txt", "score_calibration_file_meta.json", "score_calibration_tensor_meta.json", diff --git a/mediapipe/tasks/testdata/metadata/labels.txt b/mediapipe/tasks/testdata/metadata/labels.txt new file mode 100644 index 000000000..fe811239d --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork 
+spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar 
+ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool 
table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json new file mode 100644 index 000000000..6f01f9f09 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224.json @@ -0,0 +1,82 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ 
+ 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json new file mode 100644 index 000000000..e2ba42e3b --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/mobilenet_v2_1.0_224_quant.json @@ -0,0 +1,82 @@ +{ + "name": "ImageClassifier", + "description": "Identify the most prominent object in the image from a known set of categories.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "score", + "description": "Score of the labels respectively.", + "content": { + "content_properties_type": "FeatureProperties", + "content_properties": { + } + }, + "process_units": [ + { + "options_type": "ScoreCalibrationOptions", + "options": { + "score_transformation": "LOG", + "default_score": 0.2 + } + } + ], + "stats": { + "max": [ + 255.0 + ], + "min": [ + 0.0 + ] + }, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + }, + { + "name": "score_calibration.txt", + "description": "Contains sigmoid-based score calibration parameters. 
The main purposes of score calibration is to make scores across classes comparable, so that a common threshold can be used for all output classes.", + "type": "TENSOR_AXIS_SCORE_CALIBRATION" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 6cce5ae41..14999a03e 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -76,9 +76,10 @@ filegroup( filegroup( name = "text_classifier_models", - srcs = glob([ - "test_model_text_classifier*.tflite", - ]), + srcs = [ + "test_model_text_classifier_bool_output.tflite", + "test_model_text_classifier_with_regex_tokenizer.tflite", + ], ) filegroup( diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index ffb4760d9..c45cc6e69 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -28,6 +28,8 @@ mediapipe_files(srcs = [ "burger_rotated.jpg", "cat.jpg", "cat_mask.jpg", + "cat_rotated.jpg", + "cat_rotated_mask.jpg", "cats_and_dogs.jpg", "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", @@ -35,7 +37,6 @@ mediapipe_files(srcs = [ "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", - "hand_landmark.task", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "left_hands.jpg", @@ -73,6 +74,7 @@ exports_files( "expected_left_up_hand_rotated_landmarks.prototxt", "expected_right_down_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt", + "gesture_recognizer.task", ], ) @@ -84,6 +86,8 @@ filegroup( "burger_rotated.jpg", "cat.jpg", "cat_mask.jpg", + "cat_rotated.jpg", + "cat_rotated_mask.jpg", "cats_and_dogs.jpg", "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", @@ -118,9 +122,9 @@ filegroup( "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", - "hand_landmark.task", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", + "hand_landmarker.task", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", diff --git a/mediapipe/tasks/testdata/vision/hand_landmark.task b/mediapipe/tasks/testdata/vision/hand_landmarker.task similarity index 99% rename from mediapipe/tasks/testdata/vision/hand_landmark.task rename to mediapipe/tasks/testdata/vision/hand_landmarker.task index b6eedf324..1ae9f7f6b 100644 Binary files a/mediapipe/tasks/testdata/vision/hand_landmark.task and b/mediapipe/tasks/testdata/vision/hand_landmarker.task differ diff --git a/mediapipe/util/resource_util_android.cc b/mediapipe/util/resource_util_android.cc index b18354d5f..1e970f212 100644 --- a/mediapipe/util/resource_util_android.cc +++ b/mediapipe/util/resource_util_android.cc @@ -82,7 +82,8 @@ absl::StatusOr PathToResourceAsFile(const std::string& path) { // If that fails, assume it was a relative path, and try just the base name. { const size_t last_slash_idx = path.find_last_of("\\/"); - CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. + RET_CHECK(last_slash_idx != std::string::npos) + << path << " doesn't have a slash in it"; // Make sure it's a path. 
auto base_name = path.substr(last_slash_idx + 1); auto status_or_path = PathToResourceAsFileInternal(base_name); if (status_or_path.ok()) { diff --git a/mediapipe/util/resource_util_apple.cc b/mediapipe/util/resource_util_apple.cc index c812dcb57..f64718348 100644 --- a/mediapipe/util/resource_util_apple.cc +++ b/mediapipe/util/resource_util_apple.cc @@ -71,7 +71,8 @@ absl::StatusOr PathToResourceAsFile(const std::string& path) { // If that fails, assume it was a relative path, and try just the base name. { const size_t last_slash_idx = path.find_last_of("\\/"); - CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. + RET_CHECK(last_slash_idx != std::string::npos) + << path << " doesn't have a slash in it"; // Make sure it's a path. auto base_name = path.substr(last_slash_idx + 1); auto status_or_path = PathToResourceAsFileInternal(base_name); if (status_or_path.ok()) { diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 9d37b60a0..e9b8bfa03 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -84,33 +84,28 @@ cc_library( "//conditions:default": ["tflite_gpu_runner.h"], }), deps = select({ - "//mediapipe:ios": [], - "//mediapipe:macos": [], - "//conditions:default": [ - "@com_google_absl//absl/strings", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:statusor", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/delegates/gpu:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", - ], - "//mediapipe:android": [ - "@com_google_absl//absl/strings", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:statusor", - "@org_tensorflow//tensorflow/lite:framework", - "@org_tensorflow//tensorflow/lite/delegates/gpu:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", - ], - }) + [ + "//mediapipe:ios": [], + "//mediapipe:macos": [], + "//conditions:default": [ + "@com_google_absl//absl/strings", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/delegates/gpu:api", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:model_builder", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:api2", + ], + }) + + select({ + "//mediapipe:android": [ + "@org_tensorflow//tensorflow/lite/delegates/gpu/cl:api", + ], + "//conditions:default": [], + }) + [ "@com_google_absl//absl/status", + "//mediapipe/framework:port", "@org_tensorflow//tensorflow/lite/core/api", ], ) diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc index 4c422835a..4e40975cb 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.cc +++ b/mediapipe/util/tflite/tflite_gpu_runner.cc @@ -34,7 +34,7 @@ // This code should be enabled as soon as TensorFlow version, which mediapipe // uses, will include this module. 
-#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) #include "tensorflow/lite/delegates/gpu/cl/api.h" #endif @@ -82,7 +82,7 @@ ObjectDef GetSSBOObjectDef(int channels) { return gpu_object_def; } -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) { cl::InferenceOptions result{}; @@ -106,7 +106,7 @@ absl::Status VerifyShapes(const std::vector& actual, return absl::OkStatus(); } -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) } // namespace @@ -225,7 +225,7 @@ absl::Status TFLiteGPURunner::InitializeOpenGL( absl::Status TFLiteGPURunner::InitializeOpenCL( std::unique_ptr* builder) { -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) cl::InferenceEnvironmentOptions env_options; if (!serialized_binary_cache_.empty()) { env_options.serialized_binary_cache = serialized_binary_cache_; @@ -254,11 +254,12 @@ absl::Status TFLiteGPURunner::InitializeOpenCL( return absl::OkStatus(); #else - return mediapipe::UnimplementedError("Currently only Android is supported"); -#endif // __ANDROID__ + return mediapipe::UnimplementedError( + "Currently only Android & ChromeOS are supported"); +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) } -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel( std::unique_ptr* builder) { @@ -283,7 +284,7 @@ absl::StatusOr> TFLiteGPURunner::GetSerializedModel() { return serialized_model; } -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) } // namespace gpu } // namespace tflite diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h index 88d3914f7..dfbc8d659 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.h +++ b/mediapipe/util/tflite/tflite_gpu_runner.h @@ -20,6 +20,7 @@ #include #include "absl/status/status.h" +#include "mediapipe/framework/port.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -28,9 +29,9 @@ #include "tensorflow/lite/delegates/gpu/gl/api2.h" #include "tensorflow/lite/model.h" -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) #include "tensorflow/lite/delegates/gpu/cl/api.h" -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) namespace tflite { namespace gpu { @@ -83,7 +84,7 @@ class TFLiteGPURunner { return output_shape_from_model_; } -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) void SetSerializedBinaryCache(std::vector&& cache) { serialized_binary_cache_ = std::move(cache); } @@ -98,26 +99,26 @@ class TFLiteGPURunner { } absl::StatusOr> GetSerializedModel(); -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) private: absl::Status InitializeOpenGL(std::unique_ptr* builder); absl::Status InitializeOpenCL(std::unique_ptr* builder); -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) absl::Status InitializeOpenCLFromSerializedModel( std::unique_ptr* builder); -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) InferenceOptions options_; std::unique_ptr gl_environment_; -#ifdef __ANDROID__ +#if defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) std::unique_ptr cl_environment_; 
std::vector serialized_binary_cache_; std::vector serialized_model_; bool serialized_model_used_ = false; -#endif // __ANDROID__ +#endif // defined(__ANDROID__) || defined(MEDIAPIPE_CHROMIUMOS) // graph_gl_ is maintained temporarily and becomes invalid after runner_ is // ready diff --git a/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff b/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff new file mode 100644 index 000000000..a084d9262 --- /dev/null +++ b/third_party/com_google_sentencepiece_no_gflag_no_gtest.diff @@ -0,0 +1,34 @@ +diff --git a/src/BUILD b/src/BUILD +index b4298d2..f3877a3 100644 +--- a/src/BUILD ++++ b/src/BUILD +@@ -71,9 +71,7 @@ cc_library( + ":common", + ":sentencepiece_cc_proto", + ":sentencepiece_model_cc_proto", +- "@com_github_gflags_gflags//:gflags", + "@com_google_glog//:glog", +- "@com_google_googletest//:gtest", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:flat_hash_map", +diff --git a/src/normalizer.h b/src/normalizer.h +index c16ac16..2af58be 100644 +--- a/src/normalizer.h ++++ b/src/normalizer.h +@@ -21,7 +21,6 @@ + #include + #include + +-#include "gtest/gtest_prod.h" + #include "absl/strings/string_view.h" + #include "third_party/darts_clone/include/darts.h" + #include "src/common.h" +@@ -97,7 +96,6 @@ class Normalizer { + friend class Builder; + + private: +- FRIEND_TEST(NormalizerTest, EncodeDecodePrecompiledCharsMapTest); + + void Init(); + diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 2c92293ff..4b7309eef 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -31,7 +31,7 @@ def external_files(): http_file( name = "com_google_mediapipe_bert_text_classifier_tflite", sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600", - urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1663009542017720"], + urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"], ) http_file( @@ -46,12 +46,6 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"], ) - http_file( - name = "com_google_mediapipe_BUILD_orig", - sha256 = "650df617b3e125e0890f1b8c936cc64c9d975707f57e616b6430fc667ce315d4", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1665609930388174"], - ) - http_file( name = "com_google_mediapipe_burger_crop_jpg", sha256 = "8f58de573f0bf59a49c3d86cfabb9ad4061481f574aa049177e8da3963dddc50", @@ -82,6 +76,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/cat_mask.jpg?generation=1661875677203533"], ) + http_file( + name = "com_google_mediapipe_cat_rotated_jpg", + sha256 = "b78cee5ad14c9f36b1c25d103db371d81ca74d99030063c46a38e80bb8f38649", + urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated.jpg?generation=1666304165042123"], + ) + + http_file( + name = "com_google_mediapipe_cat_rotated_mask_jpg", + sha256 = "f336973e7621d602f2ebc9a6ab1c62d8502272d391713f369d3b99541afda861", + urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated_mask.jpg?generation=1666304167148173"], + ) + http_file( name = "com_google_mediapipe_cats_and_dogs_jpg", sha256 = "a2eaa7ad3a1aae4e623dd362a5f737e8a88d122597ecd1a02b3e1444db56df9c", @@ -127,7 +133,7 @@ def external_files(): http_file( name = "com_google_mediapipe_coco_ssd_mobilenet_v1_1_0_quant_2018_06_29_tflite", 
sha256 = "61d598093ed03ed41aa47c3a39a28ac01e960d6a810a5419b9a5016a1e9c469b", - urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite?generation=1661875702588267"], + urls = ["https://storage.googleapis.com/mediapipe-assets/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite?generation=1666144700870810"], ) http_file( @@ -168,8 +174,8 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt", - sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"], + sha256 = "c4dfdcc2e4cd366eb5f8ad227be94049eb593e3a528564611094687912463687", + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666629474155924"], ) http_file( @@ -180,8 +186,8 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt", - sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"], + sha256 = "7fb2d33cf69d2da50952a45bad0c0618f30859e608958fee95948a6e0de63ccb", + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"], ) http_file( @@ -264,8 +270,8 @@ def external_files(): http_file( name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt", - sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370", - urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"], + sha256 = "555079c274ea91699757a0b9888c9993a8ab450069103b1bcd4ebb805a8e023c", + urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666629478777955"], ) http_file( @@ -274,6 +280,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_two_hands.pbtxt?generation=1662745353586157"], ) + http_file( + name = "com_google_mediapipe_hand_landmarker_task", + sha256 = "2ed44f10872e87a5834b9b1130fb9ada30e107af2c6fcc4562ad788aca4e7bc4", + urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmarker.task?generation=1666153732577904"], + ) + http_file( name = "com_google_mediapipe_hand_landmark_full_tflite", sha256 = "11c272b891e1a99ab034208e23937a8008388cf11ed2a9d776ed3d01d0ba00e3", @@ -287,9 +299,9 @@ def external_files(): ) http_file( - name = "com_google_mediapipe_hand_landmark_task", - sha256 = "dd830295598e48e6bbbdf22fd9e69538fa07768106cd9ceb04d5462ca7e38c95", - urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.task?generation=1665707323647357"], + name = "com_google_mediapipe_hand_landmark_tflite", + sha256 = "bad88ac1fd144f034e00f075afcade4f3a21d0d09c41bee8dd50504dacd70efd", + urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.tflite?generation=1666153735814956"], ) http_file( @@ -364,6 +376,12 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/knift_labelmap.txt?generation=1661875792821628"], ) + http_file( + name = "com_google_mediapipe_labels_txt", + sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9", + urls = 
["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1665988394538324"], + ) + http_file( name = "com_google_mediapipe_left_hands_jpg", sha256 = "4b5134daa4cb60465535239535f9f74c2842aba3aa5fd30bf04ef5678f93d87f", @@ -448,18 +466,42 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite?generation=1661875836078124"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_json", + sha256 = "94613ea9539a20a3352604004be6d4d64d4d76250bc9042fcd8685c9a8498517", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1666633416316646"], + ) + + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json", + sha256 = "3703eadcf838b65bbc2b2aa11dbb1f1bc654c7a09a7aba5ca75a26096484a8ac", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1666633418665507"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_tflite", sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d", urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.tflite?generation=1664340173966530"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_without_metadata_tflite", + sha256 = "f08d447cde49b4e0446428aa921aff0a14ea589fa9c5817b31f83128e9a43c1d", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant_without_metadata.tflite?generation=1665988405130772"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v2_1_0_224_tflite", sha256 = "ff5cb7f9e62c92ebdad971f8a98aa6b3106d82a64587a7787c6a385c9e791339", urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.tflite?generation=1661875840611150"], ) + http_file( + name = "com_google_mediapipe_mobilenet_v2_1_0_224_without_metadata_tflite", + sha256 = "9f3bc29e38e90842a852bfed957dbf5e36f2d97a91dd17736b1e5c0aca8d3303", + urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_without_metadata.tflite?generation=1665988408360823"], + ) + http_file( name = "com_google_mediapipe_mobilenet_v3_small_100_224_embedder_tflite", sha256 = "f7b9a563cb803bdcba76e8c7e82abde06f5c7a8e67b5e54e43e23095dfe79a78", @@ -475,7 +517,7 @@ def external_files(): http_file( name = "com_google_mediapipe_mobile_object_labeler_v1_tflite", sha256 = "9400671e04685f5277edd3052a311cc51533de9da94255c52ebde1e18484c77c", - urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_object_labeler_v1.tflite?generation=1661875846924538"], + urls = ["https://storage.googleapis.com/mediapipe-assets/mobile_object_labeler_v1.tflite?generation=1666144701839813"], ) http_file( @@ -576,8 +618,8 @@ def external_files(): http_file( name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt", - sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de", - urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"], + sha256 = "5ec37218d8b613436f5c10121dc689bf9ee69af0656a6ccf8c2e3e8b652e2ad6", + urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"], ) http_file( @@ -768,8 +810,8 @@ def external_files(): http_file( name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt", - sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102", - urls = 
["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"], + sha256 = "6645bbd98ea7f90b3e1ba297e16ea5280847fc5bf5400726d98c282f6c597257", + urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666629489421733"], ) http_file(