Merge branch 'master' into image-segmenter-python-impl

This commit is contained in:
Kinar R 2022-10-28 22:50:33 +05:30 committed by GitHub
commit 334f641463
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
142 changed files with 6106 additions and 1291 deletions

View File

@ -172,6 +172,10 @@ http_archive(
urls = [ urls = [
"https://github.com/google/sentencepiece/archive/1.0.0.zip", "https://github.com/google/sentencepiece/archive/1.0.0.zip",
], ],
patches = [
"//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
],
patch_args = ["-p1"],
repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"}, repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
) )

14
docs/BUILD Normal file
View File

@ -0,0 +1,14 @@
# Placeholder for internal Python strict binary compatibility macro.
py_binary(
name = "build_py_api_docs",
srcs = ["build_py_api_docs.py"],
deps = [
"//mediapipe",
"//third_party/py/absl:app",
"//third_party/py/absl/flags",
"//third_party/py/tensorflow_docs",
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
"//third_party/py/tensorflow_docs/api_generator:public_api",
],
)

85
docs/build_py_api_docs.py Normal file
View File

@ -0,0 +1,85 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""MediaPipe reference docs generation script.
This script generates API reference docs for the `mediapipe` PIP package.
$> pip install -U git+https://github.com/tensorflow/docs mediapipe
$> python build_py_api_docs.py
"""
import os
from absl import app
from absl import flags
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
try:
# mediapipe has not been set up to work with bazel yet, so catch & report.
import mediapipe # pytype: disable=import-error
except ImportError as e:
raise ImportError('Please `pip install mediapipe`.') from e
PROJECT_SHORT_NAME = 'mp'
PROJECT_FULL_NAME = 'MediaPipe'
_OUTPUT_DIR = flags.DEFINE_string(
'output_dir',
default='/tmp/generated_docs',
help='Where to write the resulting docs.')
_URL_PREFIX = flags.DEFINE_string(
'code_url_prefix',
'https://github.com/google/mediapipe/tree/master/mediapipe',
'The url prefix for links to code.')
_SEARCH_HINTS = flags.DEFINE_bool(
'search_hints', True,
'Include metadata search hints in the generated files')
_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
'Path prefix in the _toc.yaml')
def gen_api_docs():
"""Generates API docs for the mediapipe package."""
doc_generator = generate_lib.DocGenerator(
root_title=PROJECT_FULL_NAME,
py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
base_dir=os.path.dirname(mediapipe.__file__),
code_url_prefix=_URL_PREFIX.value,
search_hints=_SEARCH_HINTS.value,
site_path=_SITE_PATH.value,
# This callback ensures that docs are only generated for objects that
# are explicitly imported in your __init__.py files. There are other
# options but this is a good starting point.
callbacks=[public_api.explicit_package_contents_filter],
)
doc_generator.build(_OUTPUT_DIR.value)
print('Docs output to:', _OUTPUT_DIR.value)
def main(_):
gen_api_docs()
if __name__ == '__main__':
app.run(main)

View File

@ -253,6 +253,26 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "regex_preprocessor_calculator_test",
srcs = ["regex_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":regex_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:sink",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "text_to_tensor_calculator", name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"], srcs = ["text_to_tensor_calculator.cc"],
@ -307,6 +327,27 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "universal_sentence_encoder_preprocessor_calculator_test",
srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
deps = [
":universal_sentence_encoder_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "inference_calculator_proto", name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"], srcs = ["inference_calculator.proto"],

View File

@ -26,6 +26,8 @@
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
CalculatorContext* cc, const std::vector<Tensor>& input_tensors, CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
std::vector<Tensor>& output_tensors) { std::vector<Tensor>& output_tensors) {
return gpu_helper_.RunInGlContext( return gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status { [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
// Explicitly copy input. // Explicitly copy input.
for (int i = 0; i < input_tensors.size(); ++i) { for (int i = 0; i < input_tensors.size(); ++i) {
glBindBuffer(GL_COPY_READ_BUFFER, glBindBuffer(GL_COPY_READ_BUFFER,
@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
} }
// Run inference. // Run inference.
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); {
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
output_tensors.reserve(output_size_); output_tensors.reserve(output_size_);
for (int i = 0; i < output_size_; ++i) { for (int i = 0; i < output_size_; ++i) {

View File

@ -32,6 +32,8 @@
#include "mediapipe/util/android/file/base/helpers.h" #include "mediapipe/util/android/file/base/helpers.h"
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
const mediapipe::InferenceCalculatorOptions::Delegate& delegate); const mediapipe::InferenceCalculatorOptions::Delegate& delegate);
absl::StatusOr<std::vector<Tensor>> Process( absl::StatusOr<std::vector<Tensor>> Process(
const std::vector<Tensor>& input_tensors); CalculatorContext* cc, const std::vector<Tensor>& input_tensors);
absl::Status Close(); absl::Status Close();
@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
absl::StatusOr<std::vector<Tensor>> absl::StatusOr<std::vector<Tensor>>
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
const std::vector<Tensor>& input_tensors) { CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
std::vector<Tensor> output_tensors; std::vector<Tensor> output_tensors;
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status { [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
for (int i = 0; i < input_tensors.size(); ++i) { for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
input_tensors[i].GetOpenGlBufferReadView().name(), i)); input_tensors[i].GetOpenGlBufferReadView().name(), i));
@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
output_tensors.back().GetOpenGlBufferWriteView().name(), i)); output_tensors.back().GetOpenGlBufferWriteView().name(), i));
} }
// Run inference. // Run inference.
return tflite_gpu_runner_->Invoke(); {
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
return tflite_gpu_runner_->Invoke();
}
})); }));
return output_tensors; return output_tensors;
@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
auto output_tensors = absl::make_unique<std::vector<Tensor>>(); auto output_tensors = absl::make_unique<std::vector<Tensor>>();
ASSIGN_OR_RETURN(*output_tensors, ASSIGN_OR_RETURN(*output_tensors,
gpu_inference_runner_->Process(input_tensors)); gpu_inference_runner_->Process(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors)); kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus(); return absl::OkStatus();

View File

@ -289,8 +289,15 @@ class NodeBase {
template <typename T> template <typename T>
T& GetOptions() { T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true; options_used_ = true;
return *options_.MutableExtension(T::ext); return *options_.MutableExtension(extension);
} }
protected: protected:
@ -386,8 +393,15 @@ class PacketGenerator {
template <typename T> template <typename T>
T& GetOptions() { T& GetOptions() {
return GetOptions(T::ext);
}
// Use this API when the proto extension does not follow the "ext" naming
// convention.
template <typename E>
auto& GetOptions(const E& extension) {
options_used_ = true; options_used_ = true;
return *options_.MutableExtension(T::ext); return *options_.MutableExtension(extension);
} }
template <typename B, typename T, bool kIsOptional, bool kIsMultiple> template <typename B, typename T, bool kIsOptional, bool kIsMultiple>

View File

@ -185,7 +185,7 @@ class CalculatorBaseFactory {
// Functions for checking that the calculator has the required GetContract. // Functions for checking that the calculator has the required GetContract.
template <class T> template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) { constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
typedef absl::Status (*GetContractType)(CalculatorContract * cc); typedef absl::Status (*GetContractType)(CalculatorContract* cc);
return std::is_same<decltype(&T::GetContract), GetContractType>::value; return std::is_same<decltype(&T::GetContract), GetContractType>::value;
} }
template <class T> template <class T>

View File

@ -133,7 +133,12 @@ message GraphTrace {
TPU_TASK = 13; TPU_TASK = 13;
GPU_CALIBRATION = 14; GPU_CALIBRATION = 14;
PACKET_QUEUED = 15; PACKET_QUEUED = 15;
GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17;
} }
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
// )
// The timing for one packet set being processed at one caclulator node. // The timing for one packet set being processed at one caclulator node.
message CalculatorTrace { message CalculatorTrace {

View File

@ -293,7 +293,6 @@ mediapipe_proto_library(
name = "rect_proto", name = "rect_proto",
srcs = ["rect.proto"], srcs = ["rect.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/formats:location_data_proto"],
) )
mediapipe_register_type( mediapipe_register_type(

View File

@ -109,6 +109,11 @@ struct TraceEvent {
static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK; static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION; static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED; static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/calculator_profile.proto:event_type,
// )
}; };
// Packet trace log buffer. // Packet trace log buffer.

View File

@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT; static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
static void EglThreadExitCallback(void* key_value) { static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__)
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
EGL_NO_CONTEXT);
#else
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
// parameter for eglMakeCurrent. This behavior is not portable to all EGL // parameter for eglMakeCurrent. This behavior is not portable to all EGL
// implementations, and should be considered as an undocumented vendor // implementations, and should be considered as an undocumented vendor
// extension. // extension.
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
//
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
EGL_NO_SURFACE, EGL_NO_CONTEXT); EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
eglReleaseThread(); eglReleaseThread();
} }

View File

@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.ImageProperties; import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
// TODO: use Preconditions in this file. // TODO: use Preconditions in this file.
@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
} }
/** /**
* Creates an Image packet from an {@link Image}. * Creates a MediaPipe Image packet from a {@link MPImage}.
* *
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. * <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
*/ */
public Packet createImage(Image image) { public Packet createImage(MPImage image) {
// TODO: Choose the best storage from multiple containers. // TODO: Choose the best storage from multiple containers.
ImageProperties properties = image.getContainedImageProperties().get(0); MPImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
ByteBuffer buffer = ByteBufferExtractor.extract(image); ByteBuffer buffer = ByteBufferExtractor.extract(image);
int numChannels = 0; int numChannels = 0;
switch (properties.getImageFormat()) { switch (properties.getImageFormat()) {
case Image.IMAGE_FORMAT_RGBA: case MPImage.IMAGE_FORMAT_RGBA:
numChannels = 4; numChannels = 4;
break; break;
case Image.IMAGE_FORMAT_RGB: case MPImage.IMAGE_FORMAT_RGB:
numChannels = 3; numChannels = 3;
break; break;
case Image.IMAGE_FORMAT_ALPHA: case MPImage.IMAGE_FORMAT_ALPHA:
numChannels = 1; numChannels = 1;
break; break;
default: // fall out default: // fall out
@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
int height = image.getHeight(); int height = image.getHeight();
return createImage(buffer, width, height, numChannels); return createImage(buffer, width, height, numChannels);
} }
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) {
Bitmap bitmap = BitmapExtractor.extract(image); Bitmap bitmap = BitmapExtractor.extract(image);
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");

View File

@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
import android.graphics.Bitmap; import android.graphics.Bitmap;
/** /**
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
* {@link IllegalArgumentException} will be thrown. * {@link IllegalArgumentException} will be thrown.
*/ */
public final class BitmapExtractor { public final class BitmapExtractor {
/** /**
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}. * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
* *
* @param image the image to extract {@link android.graphics.Bitmap} from. * @param image the image to extract {@link android.graphics.Bitmap} from.
* @return the {@link android.graphics.Bitmap} stored in {@link Image} * @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type * @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions. * conversions.
*/ */
public static Bitmap extract(Image image) { public static Bitmap extract(MPImage image) {
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
if (imageContainer != null) { if (imageContainer != null) {
return ((BitmapImageContainer) imageContainer).getBitmap(); return ((BitmapImageContainer) imageContainer).getBitmap();
} else { } else {
// TODO: Support ByteBuffer -> Bitmap conversion. // TODO: Support ByteBuffer -> Bitmap conversion.
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting Bitmap from an Image created by objects other than Bitmap is not" "Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
+ " supported"); + " supported");
} }
} }

View File

@ -22,7 +22,7 @@ import android.provider.MediaStore;
import java.io.IOException; import java.io.IOException;
/** /**
* Builds {@link Image} from {@link android.graphics.Bitmap}. * Builds {@link MPImage} from {@link android.graphics.Bitmap}.
* *
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
@ -49,7 +49,7 @@ public class BitmapImageBuilder {
} }
/** /**
* Creates the builder to build {@link Image} from a file. * Creates the builder to build {@link MPImage} from a file.
* *
* @param context the application context. * @param context the application context.
* @param uri the path to the resource file. * @param uri the path to the resource file.
@ -58,15 +58,15 @@ public class BitmapImageBuilder {
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
BitmapImageBuilder setTimestamp(long timestamp) { BitmapImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image( return new MPImage(
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
} }
} }

View File

@ -16,19 +16,19 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
class BitmapImageContainer implements ImageContainer { class BitmapImageContainer implements MPImageContainer {
private final Bitmap bitmap; private final Bitmap bitmap;
private final ImageProperties properties; private final MPImageProperties properties;
public BitmapImageContainer(Bitmap bitmap) { public BitmapImageContainer(Bitmap bitmap) {
this.bitmap = bitmap; this.bitmap = bitmap;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setImageFormat(convertFormatCode(bitmap.getConfig())) .setImageFormat(convertFormatCode(bitmap.getConfig()))
.setStorageType(Image.STORAGE_TYPE_BITMAP) .setStorageType(MPImage.STORAGE_TYPE_BITMAP)
.build(); .build();
} }
@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
bitmap.recycle(); bitmap.recycle();
} }
@ImageFormat @MPImageFormat
static int convertFormatCode(Bitmap.Config config) { static int convertFormatCode(Bitmap.Config config) {
switch (config) { switch (config) {
case ALPHA_8: case ALPHA_8:
return Image.IMAGE_FORMAT_ALPHA; return MPImage.IMAGE_FORMAT_ALPHA;
case ARGB_8888: case ARGB_8888:
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
default: default:
return Image.IMAGE_FORMAT_UNKNOWN; return MPImage.IMAGE_FORMAT_UNKNOWN;
} }
} }
} }

View File

@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
import android.os.Build.VERSION; import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.util.Locale; import java.util.Locale;
/** /**
* Utility for extracting {@link ByteBuffer} from {@link Image}. * Utility for extracting {@link ByteBuffer} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
* {@link IllegalArgumentException} will be thrown. * otherwise {@link IllegalArgumentException} will be thrown.
*/ */
public class ByteBufferExtractor { public class ByteBufferExtractor {
/** /**
* Extracts a {@link ByteBuffer} from an {@link Image}. * Extracts a {@link ByteBuffer} from a {@link MPImage}.
* *
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
* *
* @see Image#getContainedImageProperties() * @see MPImage#getContainedImageProperties()
* @return A read-only {@link ByteBuffer}. * @return A read-only {@link ByteBuffer}.
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
*/ */
@SuppressLint("SwitchIntDef") @SuppressLint("SwitchIntDef")
public static ByteBuffer extract(Image image) { public static ByteBuffer extract(MPImage image) {
ImageContainer container = image.getContainer(); MPImageContainer container = image.getContainer();
switch (container.getImageProperties().getStorageType()) { switch (container.getImageProperties().getStorageType()) {
case Image.STORAGE_TYPE_BYTEBUFFER: case MPImage.STORAGE_TYPE_BYTEBUFFER:
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
default: default:
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
+ " supported"); + " supported");
} }
} }
/** /**
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
* *
* <p>Format conversion spec: * <p>Format conversion spec:
* *
@ -70,26 +70,26 @@ public class ByteBufferExtractor {
* *
* @param image the image to extract buffer from. * @param image the image to extract buffer from.
* @param targetFormat the image format of the result bytebuffer. * @param targetFormat the image format of the result bytebuffer.
* @return the readonly {@link ByteBuffer} stored in {@link Image} * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type * @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions. * conversions.
*/ */
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
ImageContainer container; MPImageContainer container;
ImageProperties byteBufferProperties = MPImageProperties byteBufferProperties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(targetFormat) .setImageFormat(targetFormat)
.build(); .build();
if ((container = image.getContainer(byteBufferProperties)) != null) { if ((container = image.getContainer(byteBufferProperties)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
.asReadOnlyBuffer(); .asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
ByteBuffer byteBuffer = ByteBuffer byteBuffer =
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
@ -98,85 +98,89 @@ public class ByteBufferExtractor {
return byteBuffer; return byteBuffer;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by objects other than Bitmap or" "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
+ " Bytebuffer is not supported"); + " Bytebuffer is not supported");
} }
} }
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
@AutoValue @AutoValue
abstract static class Result { abstract static class Result {
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ /**
* Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
public abstract ByteBuffer buffer(); public abstract ByteBuffer buffer();
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ /**
@ImageFormat * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
@MPImageFormat
public abstract int format(); public abstract int format();
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
} }
} }
/** /**
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
* *
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
* *
* @return the readonly {@link ByteBuffer} stored in {@link Image} * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
* given {@code imageFormat} * given {@code imageFormat}
*/ */
static Result extractInRecommendedFormat(Image image) { static Result extractInRecommendedFormat(MPImage image) {
ImageContainer container; MPImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
@ImageFormat int format = adviseImageFormat(bitmap); @MPImageFormat int format = adviseImageFormat(bitmap);
Result result = Result result =
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
boolean unused = boolean unused =
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
return result; return result;
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return Result.create( return Result.create(
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
byteBufferImageContainer.getImageFormat()); byteBufferImageContainer.getImageFormat());
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
+ " is not supported"); + " is not supported");
} }
} }
@ImageFormat @MPImageFormat
private static int adviseImageFormat(Bitmap bitmap) { private static int adviseImageFormat(Bitmap bitmap) {
if (bitmap.getConfig() == Config.ARGB_8888) { if (bitmap.getConfig() == Config.ARGB_8888) {
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
String.format( String.format(
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
+ " supported", + " supported",
bitmap.getConfig())); bitmap.getConfig()));
} }
} }
private static ByteBuffer extractByteBufferFromBitmap( private static ByteBuffer extractByteBufferFromBitmap(
Bitmap bitmap, @ImageFormat int imageFormat) { Bitmap bitmap, @MPImageFormat int imageFormat) {
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
+ " supported"); + " supported");
} }
if (bitmap.getConfig() == Config.ARGB_8888) { if (bitmap.getConfig() == Config.ARGB_8888) {
if (imageFormat == Image.IMAGE_FORMAT_RGBA) { if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
bitmap.copyPixelsToBuffer(buffer); bitmap.copyPixelsToBuffer(buffer);
buffer.rewind(); buffer.rewind();
return buffer; return buffer;
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) { } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster. // TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
int w = bitmap.getWidth(); int w = bitmap.getWidth();
int h = bitmap.getHeight(); int h = bitmap.getHeight();
@ -196,14 +200,14 @@ public class ByteBufferExtractor {
} }
throw new IllegalArgumentException( throw new IllegalArgumentException(
String.format( String.format(
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
+ " %d is not supported", + " %d is not supported",
bitmap.getConfig(), imageFormat)); bitmap.getConfig(), imageFormat));
} }
private static ByteBuffer convertByteBuffer( private static ByteBuffer convertByteBuffer(
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
// array reversely to convert in-place. // array reversely to convert in-place.
@ -221,7 +225,8 @@ public class ByteBufferExtractor {
target.put(array, 0, target.capacity()); target.put(array, 0, target.capacity());
target.rewind(); target.rewind();
return target; return target;
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
&& targetFormat == MPImage.IMAGE_FORMAT_RGB) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
// array to convert in-place. // array to convert in-place.

View File

@ -15,11 +15,11 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
/** /**
* Builds a {@link Image} from a {@link ByteBuffer}. * Builds a {@link MPImage} from a {@link ByteBuffer}.
* *
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
private final ByteBuffer buffer; private final ByteBuffer buffer;
private final int width; private final int width;
private final int height; private final int height;
@ImageFormat private final int imageFormat; @MPImageFormat private final int imageFormat;
// Optional fields. // Optional fields.
private long timestamp; private long timestamp;
@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
* @param imageFormat how the data encode the image. * @param imageFormat how the data encode the image.
*/ */
public ByteBufferImageBuilder( public ByteBufferImageBuilder(
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
this.buffer = byteBuffer; this.buffer = byteBuffer;
this.width = width; this.width = width;
this.height = height; this.height = height;
@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
this.timestamp = 0; this.timestamp = 0;
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
ByteBufferImageBuilder setTimestamp(long timestamp) { ByteBufferImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
} }
} }

View File

@ -15,21 +15,19 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
class ByteBufferImageContainer implements ImageContainer { class ByteBufferImageContainer implements MPImageContainer {
private final ByteBuffer buffer; private final ByteBuffer buffer;
private final ImageProperties properties; private final MPImageProperties properties;
public ByteBufferImageContainer( public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
ByteBuffer buffer,
@ImageFormat int imageFormat) {
this.buffer = buffer; this.buffer = buffer;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(imageFormat) .setImageFormat(imageFormat)
.build(); .build();
} }
@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
/** /** Returns the image format. */
* Returns the image format. @MPImageFormat
*/
@ImageFormat
public int getImageFormat() { public int getImageFormat() {
return properties.getImageFormat(); return properties.getImageFormat();
} }

View File

@ -29,10 +29,10 @@ import java.util.Map.Entry;
/** /**
* The wrapper class for image objects. * The wrapper class for image objects.
* *
* <p>{@link Image} is designed to be an immutable image container, which could be shared * <p>{@link MPImage} is designed to be an immutable image container, which could be shared
* cross-platforms. * cross-platforms.
* *
* <p>To construct an {@link Image}, use the provided builders: * <p>To construct a {@link MPImage}, use the provided builders:
* *
* <ul> * <ul>
* <li>{@link ByteBufferImageBuilder} * <li>{@link ByteBufferImageBuilder}
@ -40,7 +40,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageBuilder} * <li>{@link MediaImageBuilder}
* </ul> * </ul>
* *
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the * <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release * reference count is 1. Developer can call {@link #close()} to reduce reference count to release
* internal storage earlier, otherwise Java garbage collection will release the storage eventually. * internal storage earlier, otherwise Java garbage collection will release the storage eventually.
* *
@ -53,7 +53,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageExtractor} * <li>{@link MediaImageExtractor}
* </ul> * </ul>
*/ */
public class Image implements Closeable { public class MPImage implements Closeable {
/** Specifies the image format of an image. */ /** Specifies the image format of an image. */
@IntDef({ @IntDef({
@ -69,7 +69,7 @@ public class Image implements Closeable {
IMAGE_FORMAT_JPEG, IMAGE_FORMAT_JPEG,
}) })
@Retention(RetentionPolicy.SOURCE) @Retention(RetentionPolicy.SOURCE)
public @interface ImageFormat {} public @interface MPImageFormat {}
public static final int IMAGE_FORMAT_UNKNOWN = 0; public static final int IMAGE_FORMAT_UNKNOWN = 0;
public static final int IMAGE_FORMAT_RGBA = 1; public static final int IMAGE_FORMAT_RGBA = 1;
@ -98,14 +98,14 @@ public class Image implements Closeable {
public static final int STORAGE_TYPE_IMAGE_PROXY = 4; public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
/** /**
* Returns a list of supported image properties for this {@link Image}. * Returns a list of supported image properties for this {@link MPImage}.
* *
* <p>Currently {@link Image} only support single storage type so the size of return list will * <p>Currently {@link MPImage} only support single storage type so the size of return list will
* always be 1. * always be 1.
* *
* @see ImageProperties * @see MPImageProperties
*/ */
public List<ImageProperties> getContainedImageProperties() { public List<MPImageProperties> getContainedImageProperties() {
return Collections.singletonList(getContainer().getImageProperties()); return Collections.singletonList(getContainer().getImageProperties());
} }
@ -124,7 +124,7 @@ public class Image implements Closeable {
return height; return height;
} }
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
private synchronized void acquire() { private synchronized void acquire() {
referenceCount += 1; referenceCount += 1;
} }
@ -132,7 +132,7 @@ public class Image implements Closeable {
/** /**
* Removes a reference that was previously acquired or init. * Removes a reference that was previously acquired or init.
* *
* <p>When {@link Image} is created, it has 1 reference count. * <p>When {@link MPImage} is created, it has 1 reference count.
* *
* <p>When the reference count becomes 0, it will release the resource under the hood. * <p>When the reference count becomes 0, it will release the resource under the hood.
*/ */
@ -141,24 +141,24 @@ public class Image implements Closeable {
public synchronized void close() { public synchronized void close() {
referenceCount -= 1; referenceCount -= 1;
if (referenceCount == 0) { if (referenceCount == 0) {
for (ImageContainer imageContainer : containerMap.values()) { for (MPImageContainer imageContainer : containerMap.values()) {
imageContainer.close(); imageContainer.close();
} }
} }
} }
/** Advanced API access for {@link Image}. */ /** Advanced API access for {@link MPImage}. */
static final class Internal { static final class Internal {
/** /**
* Acquires a reference on this {@link Image}. This will increase the reference count by 1. * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
* *
* <p>This method is more useful for image consumer to acquire a reference so image resource * <p>This method is more useful for image consumer to acquire a reference so image resource
* will not be closed accidentally. As image creator, normal developer doesn't need to call this * will not be closed accidentally. As image creator, normal developer doesn't need to call this
* method. * method.
* *
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link * <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link Image} anymore. * #close()} to indicate it doesn't need this {@link MPImage} anymore.
* *
* @see #close() * @see #close()
*/ */
@ -166,10 +166,10 @@ public class Image implements Closeable {
image.acquire(); image.acquire();
} }
private final Image image; private final MPImage image;
// Only Image creates the internal helper. // Only MPImage creates the internal helper.
private Internal(Image image) { private Internal(MPImage image) {
this.image = image; this.image = image;
} }
} }
@ -179,15 +179,15 @@ public class Image implements Closeable {
return new Internal(this); return new Internal(this);
} }
private final Map<ImageProperties, ImageContainer> containerMap; private final Map<MPImageProperties, MPImageContainer> containerMap;
private final long timestamp; private final long timestamp;
private final int width; private final int width;
private final int height; private final int height;
private int referenceCount; private int referenceCount;
/** Constructs an {@link Image} with a built container. */ /** Constructs a {@link MPImage} with a built container. */
Image(ImageContainer container, long timestamp, int width, int height) { MPImage(MPImageContainer container, long timestamp, int width, int height) {
this.containerMap = new HashMap<>(); this.containerMap = new HashMap<>();
containerMap.put(container.getImageProperties(), container); containerMap.put(container.getImageProperties(), container);
this.timestamp = timestamp; this.timestamp = timestamp;
@ -201,10 +201,10 @@ public class Image implements Closeable {
* *
* @return the current container. * @return the current container.
*/ */
ImageContainer getContainer() { MPImageContainer getContainer() {
// According to the design, in the future we will support multiple containers in one image. // According to the design, in the future we will support multiple containers in one image.
// Currently just return the original container. // Currently just return the original container.
// TODO: Cache multiple containers in Image. // TODO: Cache multiple containers in MPImage.
return containerMap.values().iterator().next(); return containerMap.values().iterator().next();
} }
@ -214,8 +214,8 @@ public class Image implements Closeable {
* <p>If there are multiple containers with required {@code storageType}, returns the first one. * <p>If there are multiple containers with required {@code storageType}, returns the first one.
*/ */
@Nullable @Nullable
ImageContainer getContainer(@StorageType int storageType) { MPImageContainer getContainer(@StorageType int storageType) {
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) { for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
if (entry.getKey().getStorageType() == storageType) { if (entry.getKey().getStorageType() == storageType) {
return entry.getValue(); return entry.getValue();
} }
@ -225,13 +225,13 @@ public class Image implements Closeable {
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
@Nullable @Nullable
ImageContainer getContainer(ImageProperties imageProperties) { MPImageContainer getContainer(MPImageProperties imageProperties) {
return containerMap.get(imageProperties); return containerMap.get(imageProperties);
} }
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
boolean addContainer(ImageContainer container) { boolean addContainer(MPImageContainer container) {
ImageProperties imageProperties = container.getImageProperties(); MPImageProperties imageProperties = container.getImageProperties();
if (containerMap.containsKey(imageProperties)) { if (containerMap.containsKey(imageProperties)) {
return false; return false;
} }

View File

@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that can receive {@link Image} */ /** Lightweight abstraction for an object that can receive {@link MPImage} */
public interface ImageConsumer { public interface MPImageConsumer {
/** /**
* Called when an {@link Image} is available. * Called when a {@link MPImage} is available.
* *
* <p>The argument is only guaranteed to be available until this method returns. if you need to * <p>The argument is only guaranteed to be available until this method returns. if you need to
* extend its life time, acquire it, then release it when done. * extend its life time, acquire it, then release it when done.
*/ */
void onNewImage(Image image); void onNewMPImage(MPImage image);
} }

View File

@ -16,9 +16,9 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Manages internal image data storage. The interface is package-private. */ /** Manages internal image data storage. The interface is package-private. */
interface ImageContainer { interface MPImageContainer {
/** Returns the properties of the contained image. */ /** Returns the properties of the contained image. */
ImageProperties getImageProperties(); MPImageProperties getImageProperties();
/** Close the image container and releases the image resource inside. */ /** Close the image container and releases the image resource inside. */
void close(); void close();

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that produce {@link Image} */ /** Lightweight abstraction for an object that produce {@link MPImage} */
public interface ImageProducer { public interface MPImageProducer {
/** Sets the consumer that receives the {@link Image}. */ /** Sets the consumer that receives the {@link MPImage}. */
void setImageConsumer(ImageConsumer imageConsumer); void setMPImageConsumer(MPImageConsumer imageConsumer);
} }

View File

@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.auto.value.extension.memoized.Memoized; import com.google.auto.value.extension.memoized.Memoized;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import com.google.mediapipe.framework.image.Image.StorageType; import com.google.mediapipe.framework.image.MPImage.StorageType;
/** Groups a set of properties to describe how an image is stored. */ /** Groups a set of properties to describe how an image is stored. */
@AutoValue @AutoValue
public abstract class ImageProperties { public abstract class MPImageProperties {
/** /**
* Gets the pixel format of the image. * Gets the pixel format of the image.
* *
* @see Image.ImageFormat * @see MPImage.MPImageFormat
*/ */
@ImageFormat @MPImageFormat
public abstract int getImageFormat(); public abstract int getImageFormat();
/** /**
* Gets the storage type of the image. * Gets the storage type of the image.
* *
* @see Image.StorageType * @see MPImage.StorageType
*/ */
@StorageType @StorageType
public abstract int getStorageType(); public abstract int getStorageType();
@ -45,36 +45,36 @@ public abstract class ImageProperties {
public abstract int hashCode(); public abstract int hashCode();
/** /**
* Creates a builder of {@link ImageProperties}. * Creates a builder of {@link MPImageProperties}.
* *
* @see ImageProperties.Builder * @see MPImageProperties.Builder
*/ */
static Builder builder() { static Builder builder() {
return new AutoValue_ImageProperties.Builder(); return new AutoValue_MPImageProperties.Builder();
} }
/** Builds a {@link ImageProperties}. */ /** Builds a {@link MPImageProperties}. */
@AutoValue.Builder @AutoValue.Builder
abstract static class Builder { abstract static class Builder {
/** /**
* Sets the {@link Image.ImageFormat}. * Sets the {@link MPImage.MPImageFormat}.
* *
* @see ImageProperties#getImageFormat * @see MPImageProperties#getImageFormat
*/ */
abstract Builder setImageFormat(@ImageFormat int value); abstract Builder setImageFormat(@MPImageFormat int value);
/** /**
* Sets the {@link Image.StorageType}. * Sets the {@link MPImage.StorageType}.
* *
* @see ImageProperties#getStorageType * @see MPImageProperties#getStorageType
*/ */
abstract Builder setStorageType(@StorageType int value); abstract Builder setStorageType(@StorageType int value);
/** Builds the {@link ImageProperties}. */ /** Builds the {@link MPImageProperties}. */
abstract ImageProperties build(); abstract MPImageProperties build();
} }
// Hide the constructor. // Hide the constructor.
ImageProperties() {} MPImageProperties() {}
} }

View File

@ -15,11 +15,12 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
/** /**
* Builds {@link Image} from {@link android.media.Image}. * Builds {@link MPImage} from {@link android.media.Image}.
* *
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
* content in it. * content in it.
@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
public class MediaImageBuilder { public class MediaImageBuilder {
// Mandatory fields. // Mandatory fields.
private final android.media.Image mediaImage; private final Image mediaImage;
// Optional fields. // Optional fields.
private long timestamp; private long timestamp;
@ -40,20 +41,20 @@ public class MediaImageBuilder {
* *
* @param mediaImage image data object. * @param mediaImage image data object.
*/ */
public MediaImageBuilder(android.media.Image mediaImage) { public MediaImageBuilder(Image mediaImage) {
this.mediaImage = mediaImage; this.mediaImage = mediaImage;
this.timestamp = 0; this.timestamp = 0;
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
MediaImageBuilder setTimestamp(long timestamp) { MediaImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image( return new MPImage(
new MediaImageContainer(mediaImage), new MediaImageContainer(mediaImage),
timestamp, timestamp,
mediaImage.getWidth(), mediaImage.getWidth(),

View File

@ -15,33 +15,34 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build; import android.os.Build;
import android.os.Build.VERSION; import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
@RequiresApi(VERSION_CODES.KITKAT) @RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer { class MediaImageContainer implements MPImageContainer {
private final android.media.Image mediaImage; private final Image mediaImage;
private final ImageProperties properties; private final MPImageProperties properties;
public MediaImageContainer(android.media.Image mediaImage) { public MediaImageContainer(Image mediaImage) {
this.mediaImage = mediaImage; this.mediaImage = mediaImage;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
.setImageFormat(convertFormatCode(mediaImage.getFormat())) .setImageFormat(convertFormatCode(mediaImage.getFormat()))
.build(); .build();
} }
public android.media.Image getImage() { public Image getImage() {
return mediaImage; return mediaImage;
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
mediaImage.close(); mediaImage.close();
} }
@ImageFormat @MPImageFormat
static int convertFormatCode(int graphicsFormat) { static int convertFormatCode(int graphicsFormat) {
// We only cover the format mentioned in // We only cover the format mentioned in
// https://developer.android.com/reference/android/media/Image#getFormat() // https://developer.android.com/reference/android/media/Image#getFormat()
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
return Image.IMAGE_FORMAT_RGB; return MPImage.IMAGE_FORMAT_RGB;
} }
} }
switch (graphicsFormat) { switch (graphicsFormat) {
case android.graphics.ImageFormat.JPEG: case android.graphics.ImageFormat.JPEG:
return Image.IMAGE_FORMAT_JPEG; return MPImage.IMAGE_FORMAT_JPEG;
case android.graphics.ImageFormat.YUV_420_888: case android.graphics.ImageFormat.YUV_420_888:
return Image.IMAGE_FORMAT_YUV_420_888; return MPImage.IMAGE_FORMAT_YUV_420_888;
default: default:
return Image.IMAGE_FORMAT_UNKNOWN; return MPImage.IMAGE_FORMAT_UNKNOWN;
} }
} }
} }

View File

@ -15,13 +15,14 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
/** /**
* Utility for extracting {@link android.media.Image} from {@link Image}. * Utility for extracting {@link android.media.Image} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
* otherwise {@link IllegalArgumentException} will be thrown. * otherwise {@link IllegalArgumentException} will be thrown.
*/ */
@RequiresApi(VERSION_CODES.KITKAT) @RequiresApi(VERSION_CODES.KITKAT)
@ -30,20 +31,20 @@ public class MediaImageExtractor {
private MediaImageExtractor() {} private MediaImageExtractor() {}
/** /**
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
* {@link Image} that built from {@link MediaImageBuilder}. * {@link MPImage} that built from {@link MediaImageBuilder}.
* *
* @param image the image to extract {@link android.media.Image} from. * @param image the image to extract {@link android.media.Image} from.
* @return {@link android.media.Image} that stored in {@link Image}. * @return {@link android.media.Image} that stored in {@link MPImage}.
* @throws IllegalArgumentException if the extraction failed. * @throws IllegalArgumentException if the extraction failed.
*/ */
public static android.media.Image extract(Image image) { public static Image extract(MPImage image) {
ImageContainer container; MPImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
return ((MediaImageContainer) container).getImage(); return ((MediaImageContainer) container).getImage();
} }
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract Media Image from an Image created by objects other than Media Image" "Extract Media Image from a MPImage created by objects other than Media Image"
+ " is not supported"); + " is not supported");
} }
} }

View File

@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors. # Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -209,9 +209,9 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
def mediapipe_build_aar_with_jni(name, android_library): def mediapipe_build_aar_with_jni(name, android_library):
"""Builds MediaPipe AAR with jni. """Builds MediaPipe AAR with jni.
Args: Args:
name: The bazel target name. name: The bazel target name.
android_library: the android library that contains jni. android_library: the android library that contains jni.
""" """
# Generates dummy AndroidManifest.xml for dummy apk usage # Generates dummy AndroidManifest.xml for dummy apk usage
@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""):
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java", src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite", target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite", target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""):
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor( proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite", target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
)) ))
return proto_src_list return proto_src_list

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
package( package(
default_visibility = ["//mediapipe:__subpackages__"], default_visibility = ["//mediapipe:__subpackages__"],

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"]) licenses(["notice"])
@ -23,15 +24,12 @@ package(
py_library( py_library(
name = "data_util", name = "data_util",
srcs = ["data_util.py"], srcs = ["data_util.py"],
srcs_version = "PY3",
) )
py_test( py_test(
name = "data_util_test", name = "data_util_test",
srcs = ["data_util_test.py"], srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"], data = ["//mediapipe/model_maker/python/core/data/testdata"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":data_util"], deps = [":data_util"],
) )
@ -44,8 +42,6 @@ py_library(
py_test( py_test(
name = "dataset_test", name = "dataset_test",
srcs = ["dataset_test.py"], srcs = ["dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":dataset", ":dataset",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",
@ -55,14 +51,11 @@ py_test(
py_library( py_library(
name = "classification_dataset", name = "classification_dataset",
srcs = ["classification_dataset.py"], srcs = ["classification_dataset.py"],
srcs_version = "PY3",
deps = [":dataset"], deps = [":dataset"],
) )
py_test( py_test(
name = "classification_dataset_test", name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"], srcs = ["classification_dataset_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":classification_dataset"], deps = [":classification_dataset"],
) )

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
package( package(
default_visibility = ["//mediapipe:__subpackages__"], default_visibility = ["//mediapipe:__subpackages__"],
@ -23,7 +24,6 @@ licenses(["notice"])
py_library( py_library(
name = "custom_model", name = "custom_model",
srcs = ["custom_model.py"], srcs = ["custom_model.py"],
srcs_version = "PY3",
deps = [ deps = [
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
"//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:model_util",
@ -34,8 +34,6 @@ py_library(
py_test( py_test(
name = "custom_model_test", name = "custom_model_test",
srcs = ["custom_model_test.py"], srcs = ["custom_model_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":custom_model", ":custom_model",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",
@ -45,7 +43,6 @@ py_test(
py_library( py_library(
name = "classifier", name = "classifier",
srcs = ["classifier.py"], srcs = ["classifier.py"],
srcs_version = "PY3",
deps = [ deps = [
":custom_model", ":custom_model",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
@ -55,8 +52,6 @@ py_library(
py_test( py_test(
name = "classifier_test", name = "classifier_test",
srcs = ["classifier_test.py"], srcs = ["classifier_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":classifier", ":classifier",
"//mediapipe/model_maker/python/core/utils:test_util", "//mediapipe/model_maker/python/core/utils:test_util",

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"]) licenses(["notice"])
@ -24,7 +25,6 @@ py_library(
name = "test_util", name = "test_util",
testonly = 1, testonly = 1,
srcs = ["test_util.py"], srcs = ["test_util.py"],
srcs_version = "PY3",
deps = [ deps = [
":model_util", ":model_util",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
@ -34,7 +34,6 @@ py_library(
py_library( py_library(
name = "model_util", name = "model_util",
srcs = ["model_util.py"], srcs = ["model_util.py"],
srcs_version = "PY3",
deps = [ deps = [
":quantization", ":quantization",
"//mediapipe/model_maker/python/core/data:dataset", "//mediapipe/model_maker/python/core/data:dataset",
@ -44,8 +43,6 @@ py_library(
py_test( py_test(
name = "model_util_test", name = "model_util_test",
srcs = ["model_util_test.py"], srcs = ["model_util_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":model_util", ":model_util",
":quantization", ":quantization",
@ -62,8 +59,6 @@ py_library(
py_test( py_test(
name = "loss_functions_test", name = "loss_functions_test",
srcs = ["loss_functions_test.py"], srcs = ["loss_functions_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [":loss_functions"], deps = [":loss_functions"],
) )
@ -77,8 +72,6 @@ py_library(
py_test( py_test(
name = "quantization_test", name = "quantization_test",
srcs = ["quantization_test.py"], srcs = ["quantization_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [ deps = [
":quantization", ":quantization",
":test_util", ":test_util",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python strict library compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro. # Placeholder for internal Python strict test compatibility macro.
licenses(["notice"]) licenses(["notice"])

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Placeholder for internal Python library rule.
# Placeholder for internal Python strict library and test compatibility macro. # Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python library rule.
licenses(["notice"]) licenses(["notice"])

View File

@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
return model.fit( return model.fit(
x=train_ds, x=train_ds,
epochs=hparams.train_epochs, epochs=hparams.train_epochs,
steps_per_epoch=hparams.steps_per_epoch,
validation_data=validation_ds, validation_data=validation_ds,
callbacks=callbacks) callbacks=callbacks)

View File

@ -161,7 +161,7 @@ class Texture {
~Texture() { ~Texture() {
if (is_owned_) { if (is_owned_) {
glDeleteProgram(handle_); glDeleteTextures(1, &handle_);
} }
} }

View File

@ -87,6 +87,7 @@ cc_library(
cc_library( cc_library(
name = "builtin_task_graphs", name = "builtin_task_graphs",
deps = [ deps = [
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
], ],

View File

@ -14,7 +14,7 @@
"""The public facing packet getter APIs.""" """The public facing packet getter APIs."""
from typing import List, Type from typing import List
from google.protobuf import message from google.protobuf import message
from google.protobuf import symbol_database from google.protobuf import symbol_database
@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
get_matrix = _packet_getter.get_matrix get_matrix = _packet_getter.get_matrix
def get_proto(packet: mp_packet.Packet) -> Type[message.Message]: def get_proto(packet: mp_packet.Packet) -> message.Message:
"""Get the content of a MediaPipe proto Packet as a proto message. """Get the content of a MediaPipe proto Packet as a proto message.
Args: Args:

View File

@ -46,8 +46,10 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",

View File

@ -17,7 +17,7 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto; package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.container.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto"; option java_outer_classname = "CategoryProto";
// A single classification result. // A single classification result.

View File

@ -19,7 +19,7 @@ package mediapipe.tasks.components.containers.proto;
import "mediapipe/tasks/cc/components/containers/proto/category.proto"; import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.container.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto"; option java_outer_classname = "ClassificationsProto";
// List of predicted categories with an optional timestamp. // List of predicted categories with an optional timestamp.

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto; package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "EmbeddingsProto";
// Defines a dense floating-point embedding. // Defines a dense floating-point embedding.
message FloatEmbedding { message FloatEmbedding {
repeated float values = 1 [packed = true]; repeated float values = 1 [packed = true];

View File

@ -30,9 +30,11 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
@ -128,12 +130,21 @@ absl::Status ConfigureImageToTensorCalculator(
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) / options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
std); std);
} }
// TODO: need to support different GPU origin on differnt
// platforms or applications.
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
return absl::OkStatus(); return absl::OkStatus();
} }
} // namespace } // namespace
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration) {
return acceleration.has_gpu();
}
absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources, absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
bool use_gpu,
ImagePreprocessingOptions* options) { ImagePreprocessingOptions* options) {
ASSIGN_OR_RETURN(auto image_tensor_specs, ASSIGN_OR_RETURN(auto image_tensor_specs,
BuildImageTensorSpecs(model_resources)); BuildImageTensorSpecs(model_resources));
@ -141,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
image_tensor_specs, options->mutable_image_to_tensor_options())); image_tensor_specs, options->mutable_image_to_tensor_options()));
// The GPU backend isn't able to process int data. If the input tensor is // The GPU backend isn't able to process int data. If the input tensor is
// quantized, forces the image preprocessing graph to use CPU backend. // quantized, forces the image preprocessing graph to use CPU backend.
if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) { if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
} else {
options->set_backend(ImagePreprocessingOptions::CPU_BACKEND); options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -19,20 +19,26 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
// Configures an ImagePreprocessing subgraph using the provided model resources. // Configures an ImagePreprocessing subgraph using the provided model resources
// When use_gpu is true, use GPU as backend to convert image to tensor.
// - Accepts CPU input images and outputs CPU tensors. // - Accepts CPU input images and outputs CPU tensors.
// //
// Example usage: // Example usage:
// //
// auto& preprocessing = // auto& preprocessing =
// graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); // graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
// core::proto::Acceleration acceleration;
// acceleration.mutable_xnnpack();
// bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
// MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( // MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
// model_resources, // model_resources,
// use_gpu,
// &preprocessing.GetOptions<ImagePreprocessingOptions>())); // &preprocessing.GetOptions<ImagePreprocessingOptions>()));
// //
// The resulting ImagePreprocessing subgraph has the following I/O: // The resulting ImagePreprocessing subgraph has the following I/O:
@ -56,9 +62,14 @@ namespace components {
// The image that has the pixel data stored on the target storage (CPU vs // The image that has the pixel data stored on the target storage (CPU vs
// GPU). // GPU).
absl::Status ConfigureImagePreprocessing( absl::Status ConfigureImagePreprocessing(
const core::ModelResources& model_resources, const core::ModelResources& model_resources, bool use_gpu,
ImagePreprocessingOptions* options); ImagePreprocessingOptions* options);
// Determine if the image preprocessing subgraph should use GPU as the backend
// according to the given acceleration setting.
bool DetermineImagePreprocessingGpuBackend(
const core::proto::Acceleration& acceleration);
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -156,21 +156,24 @@ absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
} }
absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources( absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) { SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix) {
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
if (!model_resources_cache_service.IsAvailable()) { if (!model_resources_cache_service.IsAvailable()) {
ASSIGN_OR_RETURN(local_model_resources_, ASSIGN_OR_RETURN(auto local_model_resource,
ModelResources::Create("", std::move(external_file))); ModelResources::Create("", std::move(external_file)));
LOG(WARNING) LOG(WARNING)
<< "A local ModelResources object is created. Please consider using " << "A local ModelResources object is created. Please consider using "
"ModelResourcesCacheService to cache the created ModelResources " "ModelResourcesCacheService to cache the created ModelResources "
"object in the CalculatorGraph."; "object in the CalculatorGraph.";
return local_model_resources_.get(); local_model_resources_.push_back(std::move(local_model_resource));
return local_model_resources_.back().get();
} }
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto op_resolver_packet, auto op_resolver_packet,
model_resources_cache_service.GetObject().GetGraphOpResolverPacket()); model_resources_cache_service.GetObject().GetGraphOpResolverPacket());
const std::string tag = CreateModelResourcesTag(sc->OriginalNode()); const std::string tag =
absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix);
ASSIGN_OR_RETURN(auto model_resources, ASSIGN_OR_RETURN(auto model_resources,
ModelResources::Create(tag, std::move(external_file), ModelResources::Create(tag, std::move(external_file),
op_resolver_packet)); op_resolver_packet));
@ -182,7 +185,8 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
absl::StatusOr<const ModelAssetBundleResources*> absl::StatusOr<const ModelAssetBundleResources*>
ModelTaskGraph::CreateModelAssetBundleResources( ModelTaskGraph::CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) { SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix) {
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
bool has_file_pointer_meta = external_file->has_file_pointer_meta(); bool has_file_pointer_meta = external_file->has_file_pointer_meta();
// if external file is set by file pointer, no need to add the model asset // if external file is set by file pointer, no need to add the model asset
@ -190,7 +194,7 @@ ModelTaskGraph::CreateModelAssetBundleResources(
// not owned by this model asset bundle resources. // not owned by this model asset bundle resources.
if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) { if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
local_model_asset_bundle_resources_, auto local_model_asset_bundle_resource,
ModelAssetBundleResources::Create("", std::move(external_file))); ModelAssetBundleResources::Create("", std::move(external_file)));
if (!has_file_pointer_meta) { if (!has_file_pointer_meta) {
LOG(WARNING) LOG(WARNING)
@ -198,10 +202,12 @@ ModelTaskGraph::CreateModelAssetBundleResources(
"ModelResourcesCacheService to cache the created ModelResources " "ModelResourcesCacheService to cache the created ModelResources "
"object in the CalculatorGraph."; "object in the CalculatorGraph.";
} }
return local_model_asset_bundle_resources_.get(); local_model_asset_bundle_resources_.push_back(
std::move(local_model_asset_bundle_resource));
return local_model_asset_bundle_resources_.back().get();
} }
const std::string tag = const std::string tag = absl::StrCat(
CreateModelAssetBundleResourcesTag(sc->OriginalNode()); CreateModelAssetBundleResourcesTag(sc->OriginalNode()), tag_suffix);
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(tag, std::move(external_file))); ModelAssetBundleResources::Create(tag, std::move(external_file)));

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
@ -75,9 +76,14 @@ class ModelTaskGraph : public Subgraph {
// construction stage. Note that the external file contents will be moved // construction stage. Note that the external file contents will be moved
// into the model resources object on creation. The returned model resources // into the model resources object on creation. The returned model resources
// pointer will provide graph authors with the access to the metadata // pointer will provide graph authors with the access to the metadata
// extractor and the tflite model. // extractor and the tflite model. When the model resources graph service is
// available, a tag is generated internally asscoiated with the created model
// resource. If more than one model resources are created in a graph, the
// model resources graph service add the tag_suffix to support multiple
// resources.
absl::StatusOr<const ModelResources*> CreateModelResources( absl::StatusOr<const ModelResources*> CreateModelResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file); SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix = "");
// If the model resources graph service is available, creates a model asset // If the model resources graph service is available, creates a model asset
// bundle resources object from the subgraph context, and caches the created // bundle resources object from the subgraph context, and caches the created
@ -103,10 +109,15 @@ class ModelTaskGraph : public Subgraph {
// that can only be used in the graph construction stage. Note that the // that can only be used in the graph construction stage. Note that the
// external file contents will be moved into the model asset bundle resources // external file contents will be moved into the model asset bundle resources
// object on creation. The returned model asset bundle resources pointer will // object on creation. The returned model asset bundle resources pointer will
// provide graph authors with the access to extracted model files. // provide graph authors with the access to extracted model files. When the
// model resources graph service is available, a tag is generated internally
// asscoiated with the created model asset bundle resource. If more than one
// model asset bundle resources are created in a graph, the model resources
// graph service add the tag_suffix to support multiple resources.
absl::StatusOr<const ModelAssetBundleResources*> absl::StatusOr<const ModelAssetBundleResources*>
CreateModelAssetBundleResources( CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file); SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
const std::string tag_suffix = "");
// Inserts a mediapipe task inference subgraph into the provided // Inserts a mediapipe task inference subgraph into the provided
// GraphBuilder. The returned node provides the following interfaces to the // GraphBuilder. The returned node provides the following interfaces to the
@ -124,9 +135,9 @@ class ModelTaskGraph : public Subgraph {
api2::builder::Graph& graph) const; api2::builder::Graph& graph) const;
private: private:
std::unique_ptr<ModelResources> local_model_resources_; std::vector<std::unique_ptr<ModelResources>> local_model_resources_;
std::unique_ptr<ModelAssetBundleResources> std::vector<std::unique_ptr<ModelAssetBundleResources>>
local_model_asset_bundle_resources_; local_model_asset_bundle_resources_;
}; };

View File

@ -63,6 +63,29 @@ cc_library(
], ],
) )
cc_test(
name = "text_classifier_test",
srcs = ["text_classifier_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:bert_text_classifier_models",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
deps = [
":text_classifier",
":text_classifier_test_utils",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)
cc_library( cc_library(
name = "text_classifier_test_utils", name = "text_classifier_test_utils",
srcs = ["text_classifier_test_utils.cc"], srcs = ["text_classifier_test_utils.cc"],

View File

@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
using ::testing::proto::Approximately;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
using ::testing::proto::Partially;
constexpr float kEpsilon = 0.001; constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128; constexpr int kMaxSeqLen = 128;
@ -110,127 +107,6 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
MP_ASSERT_OK(TextClassifier::Create(std::move(options))); MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
} }
TEST_F(TextClassifierTest, TextClassifierWithBert) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult negative_result,
classifier->Classify("unflinchingly bleak and desperate"));
ASSERT_THAT(negative_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.956 }
categories { category_name: "positive" score: 0.044 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult positive_result,
classifier->Classify("it's a charming and often affecting journey"));
ASSERT_THAT(positive_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.0 }
categories { category_name: "positive" score: 1.0 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
classifier->Classify("What a waste of my time."));
ASSERT_THAT(negative_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "Negative" score: 0.813 }
categories { category_name: "Positive" score: 0.187 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult positive_result,
classifier->Classify("This is the best movie Ive seen in recent years. "
"Strongly recommend it!"));
ASSERT_THAT(positive_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "Negative" score: 0.487 }
categories { category_name: "Positive" score: 0.513 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
options->base_options.op_resolver = CreateCustomResolver();
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
classifier->Classify("hello"));
ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
classifications {
entries {
categories { index: 1 score: 1 }
categories { index: 0 score: 1 }
categories { index: 2 score: 0 }
}
}
)pb"))));
}
TEST_F(TextClassifierTest, BertLongPositive) {
std::stringstream ss_for_positive_review;
ss_for_positive_review
<< "it's a charming and often affecting journey and this is a long";
for (int i = 0; i < kMaxSeqLen; ++i) {
ss_for_positive_review << " long";
}
ss_for_positive_review << " movie review";
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
classifier->Classify(ss_for_positive_review.str()));
ASSERT_THAT(result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.014 }
categories { category_name: "positive" score: 0.986 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
} // namespace } // namespace
} // namespace text_classifier } // namespace text_classifier
} // namespace text } // namespace text

View File

@ -73,7 +73,18 @@ cc_library(
], ],
) )
# TODO: This test fails in OSS cc_test(
name = "sentencepiece_tokenizer_test",
srcs = ["sentencepiece_tokenizer_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:albert_model",
],
deps = [
":sentencepiece_tokenizer",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/core:utils",
],
)
cc_library( cc_library(
name = "tokenizer_utils", name = "tokenizer_utils",
@ -97,7 +108,32 @@ cc_library(
], ],
) )
# TODO: This test fails in OSS cc_test(
name = "tokenizer_utils_test",
srcs = ["tokenizer_utils_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:albert_model",
"//mediapipe/tasks/testdata/text:mobile_bert_model",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
linkopts = ["-ldl"],
deps = [
":bert_tokenizer",
":regex_tokenizer",
":sentencepiece_tokenizer",
":tokenizer_utils",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
],
)
cc_library( cc_library(
name = "regex_tokenizer", name = "regex_tokenizer",

View File

@ -21,12 +21,23 @@ cc_library(
hdrs = ["running_mode.h"], hdrs = ["running_mode.h"],
) )
cc_library(
name = "image_processing_options",
hdrs = ["image_processing_options.h"],
deps = [
"//mediapipe/tasks/cc/components/containers:rect",
],
)
cc_library( cc_library(
name = "base_vision_task_api", name = "base_vision_task_api",
hdrs = ["base_vision_task_api.h"], hdrs = ["base_vision_task_api.h"],
deps = [ deps = [
":image_processing_options",
":running_mode", ":running_mode",
"//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:rect",
"//mediapipe/tasks/cc/core:base_task_api", "//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",

View File

@ -16,15 +16,20 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_ #ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_
#define MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_ #define MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_
#include <cmath>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <utility> #include <utility>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe { namespace mediapipe {
@ -87,6 +92,60 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
return runner_->Send(std::move(inputs)); return runner_->Send(std::move(inputs));
} }
// Convert from ImageProcessingOptions to NormalizedRect, performing sanity
// checks on-the-fly. If the input ImageProcessingOptions is not present,
// returns a default NormalizedRect covering the whole image with rotation set
// to 0. If 'roi_allowed' is false, an error will be returned if the input
// ImageProcessingOptions has its 'region_or_interest' field set.
static absl::StatusOr<mediapipe::NormalizedRect> ConvertToNormalizedRect(
std::optional<ImageProcessingOptions> options, bool roi_allowed = true) {
mediapipe::NormalizedRect normalized_rect;
normalized_rect.set_rotation(0);
normalized_rect.set_x_center(0.5);
normalized_rect.set_y_center(0.5);
normalized_rect.set_width(1.0);
normalized_rect.set_height(1.0);
if (!options.has_value()) {
return normalized_rect;
}
if (options->rotation_degrees % 90 != 0) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Expected rotation to be a multiple of 90°.",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
}
// Convert to radians counter-clockwise.
normalized_rect.set_rotation(-options->rotation_degrees * M_PI / 180.0);
if (options->region_of_interest.has_value()) {
if (!roi_allowed) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"This task doesn't support region-of-interest.",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
}
auto& roi = *options->region_of_interest;
if (roi.left >= roi.right || roi.top >= roi.bottom) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Expected Rect with left < right and top < bottom.",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
}
if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"Expected Rect values to be in [0,1].",
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
}
normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
normalized_rect.set_y_center((roi.top + roi.bottom) / 2.0);
normalized_rect.set_width(roi.right - roi.left);
normalized_rect.set_height(roi.bottom - roi.top);
}
return normalized_rect;
}
private: private:
RunningMode running_mode_; RunningMode running_mode_;
}; };

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
#include <optional>
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace core {
// Options for image processing.
//
// If both region-or-interest and rotation are specified, the crop around the
// region-of-interest is extracted first, the the specified rotation is applied
// to the crop.
struct ImageProcessingOptions {
// The optional region-of-interest to crop from the image. If not specified,
// the full image is used.
//
// Coordinates must be in [0,1] with 'left' < 'right' and 'top' < bottom.
std::optional<components::containers::Rect> region_of_interest = std::nullopt;
// The rotation to apply to the image (or cropped region-of-interest), in
// degrees clockwise.
//
// The rotation must be a multiple (positive or negative) of 90°.
int rotation_degrees = 0;
};
} // namespace core
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_

View File

@ -62,13 +62,19 @@ cc_library(
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator",
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
@ -93,10 +99,14 @@ cc_library(
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
@ -137,8 +147,10 @@ cc_library(
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",

View File

@ -93,3 +93,46 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
mediapipe_proto_library(
name = "combined_prediction_calculator_proto",
srcs = ["combined_prediction_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "combined_prediction_calculator",
srcs = ["combined_prediction_calculator.cc"],
deps = [
":combined_prediction_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
alwayslink = 1,
)
cc_test(
name = "combined_prediction_calculator_test",
srcs = ["combined_prediction_calculator_test.cc"],
deps = [
":combined_prediction_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/strings",
],
)

View File

@ -0,0 +1,187 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/btree_map.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
constexpr char kPredictionTag[] = "PREDICTION";
Classification GetMaxScoringClassification(
const ClassificationList& classifications) {
Classification max_classification;
max_classification.set_score(0);
for (const auto& input : classifications.classification()) {
if (max_classification.score() < input.score()) {
max_classification = input;
}
}
return max_classification;
}
float GetScoreThreshold(
const std::string& input_label,
const absl::btree_map<std::string, float>& classwise_thresholds,
const std::string& background_label, const float default_threshold) {
float threshold = default_threshold;
auto it = classwise_thresholds.find(input_label);
if (it != classwise_thresholds.end()) {
threshold = it->second;
}
return threshold;
}
std::unique_ptr<ClassificationList> GetWinningPrediction(
const ClassificationList& classification_list,
const absl::btree_map<std::string, float>& classwise_thresholds,
const std::string& background_label, const float default_threshold) {
auto prediction_list = std::make_unique<ClassificationList>();
if (classification_list.classification().empty()) {
return prediction_list;
}
Classification& prediction = *prediction_list->add_classification();
auto argmax_prediction = GetMaxScoringClassification(classification_list);
float argmax_prediction_thresh =
GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
background_label, default_threshold);
if (argmax_prediction.score() >= argmax_prediction_thresh) {
prediction.set_label(argmax_prediction.label());
prediction.set_score(argmax_prediction.score());
} else {
for (const auto& input : classification_list.classification()) {
if (input.label() == background_label) {
prediction.set_label(input.label());
prediction.set_score(input.score());
break;
}
}
}
return prediction_list;
}
} // namespace
// This calculator accepts multiple ClassificationList input streams. Each
// ClassificationList should contain classifications with labels and
// corresponding softmax scores. The calculator computes the best prediction for
// each ClassificationList input stream via argmax and thresholding. Thresholds
// for all classes can be specified in the
// `CombinedPredictionCalculatorOptions`, along with a default global
// threshold.
// Please note that for this calculator to work as designed, the class names
// other than the background class in the ClassificationList objects must be
// different, but the background class name has to be the same. This background
// label name can be set via `background_label` in
// `CombinedPredictionCalculatorOptions`.
// The ClassificationList in the PREDICTION output stream contains the label of
// the winning class and corresponding softmax score. If none of the
// ClassificationList objects has a non-background winning class, the output
// contains the background class and score of the background class in the first
// ClassificationList. If multiple ClassificationList objects have a
// non-background winning class, the output contains the winning prediction from
// the ClassificationList with the highest priority. Priority is in decreasing
// order of input streams to the graph node using this calculator.
// Input:
// At least one stream with ClassificationList.
// Output:
// PREDICTION - A ClassificationList with the winning label as the only item.
//
// Usage example:
// node {
// calculator: "CombinedPredictionCalculator"
// input_stream: "classification_list_0"
// input_stream: "classification_list_1"
// output_stream: "PREDICTION:prediction"
// options {
// [mediapipe.CombinedPredictionCalculatorOptions.ext] {
// class {
// label: "A"
// score_threshold: 0.7
// }
// default_global_threshold: 0.1
// background_label: "B"
// }
// }
// }
class CombinedPredictionCalculator : public Node {
public:
static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
""};
static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
absl::Status Open(CalculatorContext* cc) override {
options_ = cc->Options<CombinedPredictionCalculatorOptions>();
for (const auto& input : options_.class_()) {
classwise_thresholds_[input.label()] = input.score_threshold();
}
classwise_thresholds_[options_.background_label()] = 0;
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
// After loop, if have winning prediction return. Otherwise empty packet.
std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
auto collection = kClassificationListIn(cc);
for (int idx = 0; idx < collection.Count(); ++idx) {
const auto& packet = collection[idx];
if (packet.IsEmpty()) {
continue;
}
auto prediction = GetWinningPrediction(
packet.Get(), classwise_thresholds_, options_.background_label(),
options_.default_global_threshold());
if (prediction->classification(0).label() !=
options_.background_label()) {
kPredictionOut(cc).Send(std::move(prediction));
return absl::OkStatus();
}
if (first_winning_prediction == nullptr) {
first_winning_prediction = std::move(prediction);
}
}
if (first_winning_prediction != nullptr) {
kPredictionOut(cc).Send(std::move(first_winning_prediction));
}
return absl::OkStatus();
}
private:
CombinedPredictionCalculatorOptions options_;
absl::btree_map<std::string, float> classwise_thresholds_;
};
MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,41 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message CombinedPredictionCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional CombinedPredictionCalculatorOptions ext = 483738635;
}
message Class {
optional string label = 1;
optional float score_threshold = 2;
}
// List of classes with score thresholds.
repeated Class class = 1;
// Default score threshold applied to a label.
optional float default_global_threshold = 2 [default = 0];
// Name of the background class whose input scores will be ignored while
// thresholding.
optional string background_label = 3;
}

View File

@ -0,0 +1,315 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
constexpr char kPredictionTag[] = "PREDICTION";
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
float drama_thresh, float llama_thresh, float bazinga_thresh,
float joy_thresh, float peace_thresh) {
constexpr absl::string_view kCalculatorProto = R"pb(
calculator: "CombinedPredictionCalculator"
input_stream: "custom_softmax_scores"
input_stream: "canned_softmax_scores"
output_stream: "PREDICTION:prediction"
options {
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
class { label: "CustomDrama" score_threshold: $0 }
class { label: "CustomLlama" score_threshold: $1 }
class { label: "CannedBazinga" score_threshold: $2 }
class { label: "CannedJoy" score_threshold: $3 }
class { label: "CannedPeace" score_threshold: $4 }
background_label: "Negative"
}
}
)pb";
auto runner = std::make_unique<CalculatorRunner>(
absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
bazinga_thresh, joy_thresh, peace_thresh));
return runner;
}
std::unique_ptr<ClassificationList> BuildCustomScoreInput(
const float negative_score, const float drama_score,
const float llama_score) {
auto custom_scores = std::make_unique<ClassificationList>();
auto custom_negative = custom_scores->add_classification();
custom_negative->set_label("Negative");
custom_negative->set_score(negative_score);
auto drama = custom_scores->add_classification();
drama->set_label("CustomDrama");
drama->set_score(drama_score);
auto llama = custom_scores->add_classification();
llama->set_label("CustomLlama");
llama->set_score(llama_score);
return custom_scores;
}
std::unique_ptr<ClassificationList> BuildCannedScoreInput(
const float negative_score, const float bazinga_score,
const float joy_score, const float peace_score) {
auto canned_scores = std::make_unique<ClassificationList>();
auto canned_negative = canned_scores->add_classification();
canned_negative->set_label("Negative");
canned_negative->set_score(negative_score);
auto bazinga = canned_scores->add_classification();
bazinga->set_label("CannedBazinga");
bazinga->set_score(bazinga_score);
auto joy = canned_scores->add_classification();
joy->set_label("CannedJoy");
joy->set_score(joy_score);
auto peace = canned_scores->add_classification();
peace->set_label("CannedPeace");
peace->set_score(peace_score);
return canned_scores;
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomEmpty_CannedEmpty_ResultIsEmpty) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomEmpty_CannedNotEmpty_ResultIsCanned) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
/*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
auto canned_scores = BuildCannedScoreInput(
/*negative_score=*/0.1,
/*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
runner->MutableInputs()->Index(1).packets.push_back(
Adopt(canned_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), "CannedJoy");
EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
}
TEST(CombinedPredictionCalculatorPacketTest,
CustomNotEmpty_CannedEmpty_ResultIsCustom) {
auto runner = BuildNodeRunnerWithOptions(
/*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
/*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
auto custom_scores =
BuildCustomScoreInput(/*negative_score=*/0.1,
/*drama_score=*/0.2, /*llama_score=*/0.7);
runner->MutableInputs()->Index(0).packets.push_back(
Adopt(custom_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), "CustomLlama");
EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
}
struct CombinedPredictionCalculatorTestCase {
std::string test_name;
float custom_negative_score;
float drama_score;
float llama_score;
float drama_thresh;
float llama_thresh;
float canned_negative_score;
float bazinga_score;
float joy_score;
float peace_score;
float bazinga_thresh;
float joy_thresh;
float peace_thresh;
std::string max_scoring_label;
float max_score;
};
using CombinedPredictionCalculatorTest =
testing::TestWithParam<CombinedPredictionCalculatorTestCase>;
TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
const CombinedPredictionCalculatorTestCase& test_case = GetParam();
auto runner = BuildNodeRunnerWithOptions(
test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
test_case.joy_thresh, test_case.peace_thresh);
auto custom_scores =
BuildCustomScoreInput(test_case.custom_negative_score,
test_case.drama_score, test_case.llama_score);
runner->MutableInputs()->Index(0).packets.push_back(
Adopt(custom_scores.release()).At(Timestamp(1)));
auto canned_scores = BuildCannedScoreInput(
test_case.canned_negative_score, test_case.bazinga_score,
test_case.joy_score, test_case.peace_score);
runner->MutableInputs()->Index(1).packets.push_back(
Adopt(canned_scores.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
auto output_prediction_packets =
runner->Outputs().Tag(kPredictionTag).packets;
ASSERT_EQ(output_prediction_packets.size(), 1);
Classification output_prediction =
output_prediction_packets[0].Get<ClassificationList>().classification(0);
EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
}
INSTANTIATE_TEST_CASE_P(
CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
{
.test_name = "TestCustomDramaWinnnerWith_HighCanned_Thresh",
.custom_negative_score = 0.1,
.drama_score = 0.5,
.llama_score = 0.3,
.drama_thresh = 0.25,
.llama_thresh = 0.7,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "CustomDrama",
.max_score = 0.5,
},
{
.test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
.custom_negative_score = 0.1,
.drama_score = 0.3,
.llama_score = 0.6,
.drama_thresh = 0.4,
.llama_thresh = 0.8,
.canned_negative_score = 0.1,
.bazinga_score = 0.4,
.joy_score = 0.3,
.peace_score = 0.2,
.bazinga_thresh = 0.0,
.joy_thresh = 0.0,
.peace_thresh = 0.0,
.max_scoring_label = "CannedBazinga",
.max_score = 0.4,
},
{
.test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
.custom_negative_score = 0.5,
.drama_score = 0.1,
.llama_score = 0.4,
.drama_thresh = 0.1,
.llama_thresh = 0.05,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.5,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
.custom_negative_score = 0.8,
.drama_score = 0.1,
.llama_score = 0.1,
.drama_thresh = 0.25,
.llama_thresh = 0.7,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.8,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
.custom_negative_score = 0.1,
.drama_score = 0.2,
.llama_score = 0.7,
.drama_thresh = 1.1,
.llama_thresh = 1.1,
.canned_negative_score = 0.1,
.bazinga_score = 0.3,
.joy_score = 0.3,
.peace_score = 0.3,
.bazinga_thresh = 0.7,
.joy_thresh = 0.7,
.peace_thresh = 0.7,
.max_scoring_label = "Negative",
.max_score = 0.1,
},
{
.test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
.custom_negative_score = 0.1,
.drama_score = 0.3,
.llama_score = 0.6,
.drama_thresh = 0.4,
.llama_thresh = 0.8,
.canned_negative_score = 0.3,
.bazinga_score = 0.2,
.joy_score = 0.3,
.peace_score = 0.2,
.bazinga_thresh = 0.5,
.joy_thresh = 0.5,
.peace_thresh = 0.5,
.max_scoring_label = "Negative",
.max_score = 0.1,
},
}),
[](const testing::TestParamInfo<
CombinedPredictionCalculatorTest::ParamType>& info) {
return info.param.test_name;
});
} // namespace
} // namespace mediapipe

View File

@ -39,7 +39,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
@ -76,31 +78,6 @@ constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks"; constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
// Returns a NormalizedRect filling the whole image. If input is present, its
// rotation is set in the returned NormalizedRect and a check is performed to
// make sure no region-of-interest was provided. Otherwise, rotation is set to
// 0.
absl::StatusOr<NormalizedRect> FillNormalizedRect(
std::optional<NormalizedRect> normalized_rect) {
NormalizedRect result;
if (normalized_rect.has_value()) {
result = *normalized_rect;
}
bool has_coordinates = result.has_x_center() || result.has_y_center() ||
result.has_width() || result.has_height();
if (has_coordinates) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GestureRecognizer does not support region-of-interest.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
result.set_x_center(0.5);
result.set_y_center(0.5);
result.set_width(1);
result.set_height(1);
return result;
}
// Creates a MediaPipe graph config that contains a subgraph node of // Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running // "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running
// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the // in the live stream mode, a "FlowLimiterCalculator" will be added to limit the
@ -136,57 +113,38 @@ CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<GestureRecognizerGraphOptionsProto> std::unique_ptr<GestureRecognizerGraphOptionsProto>
ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
auto options_proto = std::make_unique<GestureRecognizerGraphOptionsProto>(); auto options_proto = std::make_unique<GestureRecognizerGraphOptionsProto>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE);
bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE;
// TODO remove these workarounds for base options of subgraphs.
// Configure hand detector options. // Configure hand detector options.
auto base_options_proto_for_hand_detector =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_hand_detector)));
base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode);
auto* hand_detector_graph_options = auto* hand_detector_graph_options =
options_proto->mutable_hand_landmarker_graph_options() options_proto->mutable_hand_landmarker_graph_options()
->mutable_hand_detector_graph_options(); ->mutable_hand_detector_graph_options();
hand_detector_graph_options->mutable_base_options()->Swap(
base_options_proto_for_hand_detector.get());
hand_detector_graph_options->set_num_hands(options->num_hands); hand_detector_graph_options->set_num_hands(options->num_hands);
hand_detector_graph_options->set_min_detection_confidence( hand_detector_graph_options->set_min_detection_confidence(
options->min_hand_detection_confidence); options->min_hand_detection_confidence);
// Configure hand landmark detector options. // Configure hand landmark detector options.
auto base_options_proto_for_hand_landmarker =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_hand_landmarker)));
base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode);
auto* hand_landmarks_detector_graph_options =
options_proto->mutable_hand_landmarker_graph_options()
->mutable_hand_landmarks_detector_graph_options();
hand_landmarks_detector_graph_options->mutable_base_options()->Swap(
base_options_proto_for_hand_landmarker.get());
hand_landmarks_detector_graph_options->set_min_detection_confidence(
options->min_hand_presence_confidence);
auto* hand_landmarker_graph_options = auto* hand_landmarker_graph_options =
options_proto->mutable_hand_landmarker_graph_options(); options_proto->mutable_hand_landmarker_graph_options();
hand_landmarker_graph_options->set_min_tracking_confidence( hand_landmarker_graph_options->set_min_tracking_confidence(
options->min_tracking_confidence); options->min_tracking_confidence);
auto* hand_landmarks_detector_graph_options =
hand_landmarker_graph_options
->mutable_hand_landmarks_detector_graph_options();
hand_landmarks_detector_graph_options->set_min_detection_confidence(
options->min_hand_presence_confidence);
// Configure hand gesture recognizer options. // Configure hand gesture recognizer options.
auto base_options_proto_for_gesture_recognizer =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_gesture_recognizer)));
base_options_proto_for_gesture_recognizer->set_use_stream_mode(
use_stream_mode);
auto* hand_gesture_recognizer_graph_options = auto* hand_gesture_recognizer_graph_options =
options_proto->mutable_hand_gesture_recognizer_graph_options(); options_proto->mutable_hand_gesture_recognizer_graph_options();
hand_gesture_recognizer_graph_options->mutable_base_options()->Swap(
base_options_proto_for_gesture_recognizer.get());
if (options->min_gesture_confidence >= 0) { if (options->min_gesture_confidence >= 0) {
hand_gesture_recognizer_graph_options->mutable_classifier_options() hand_gesture_recognizer_graph_options
->mutable_canned_gesture_classifier_graph_options()
->mutable_classifier_options()
->set_score_threshold(options->min_gesture_confidence); ->set_score_threshold(options->min_gesture_confidence);
} }
return options_proto; return options_proto;
@ -248,15 +206,16 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize( absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(
FillNormalizedRect(image_processing_options)); NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessImageData( ProcessImageData(
@ -283,15 +242,16 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo( absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(
FillNormalizedRect(image_processing_options)); NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessVideoData( ProcessVideoData(
@ -321,15 +281,16 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
absl::Status GestureRecognizer::RecognizeAsync( absl::Status GestureRecognizer::RecognizeAsync(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(
FillNormalizedRect(image_processing_options)); NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
return SendLiveStreamData( return SendLiveStreamData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))

View File

@ -23,10 +23,10 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h" #include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe { namespace mediapipe {
@ -39,12 +39,6 @@ struct GestureRecognizerOptions {
// model file with metadata, accelerator options, op resolver, etc. // model file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options; tasks::core::BaseOptions base_options;
// TODO: remove these. Temporary solutions before bundle asset is
// ready.
tasks::core::BaseOptions base_options_for_hand_landmarker;
tasks::core::BaseOptions base_options_for_hand_detector;
tasks::core::BaseOptions base_options_for_gesture_recognizer;
// The running mode of the task. Default to the image mode. // The running mode of the task. Default to the image mode.
// GestureRecognizer has three running modes: // GestureRecognizer has three running modes:
// 1) The image mode for recognizing hand gestures on single image inputs. // 1) The image mode for recognizing hand gestures on single image inputs.
@ -129,36 +123,36 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
// Only use this method when the GestureRecognizer is created with the image // Only use this method when the GestureRecognizer is created with the image
// running mode. // running mode.
// //
// image - mediapipe::Image // The optional 'image_processing_options' parameter can be used to specify
// Image to perform hand gesture recognition on. // the rotation to apply to the image before performing recognition, by
// imageProcessingOptions - std::optional<NormalizedRect> // setting its 'rotation_degrees' field. Note that specifying a
// If provided, can be used to specify the rotation to apply to the image // region-of-interest using the 'region_of_interest' field is NOT supported
// before performing classification, by setting its 'rotation' field in // and will result in an invalid argument error being returned.
// radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note that
// specifying a region-of-interest using the 'x_center', 'y_center', 'width'
// and 'height' fields is NOT supported and will result in an invalid
// argument error being returned.
// //
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed // TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented. // after the yuv support is implemented.
// TODO: use an ImageProcessingOptions struct instead of
// NormalizedRect.
absl::StatusOr<components::containers::GestureRecognitionResult> Recognize( absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
Image image, Image image,
std::optional<mediapipe::NormalizedRect> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
// Performs gesture recognition on the provided video frame. // Performs gesture recognition on the provided video frame.
// Only use this method when the GestureRecognizer is created with the video // Only use this method when the GestureRecognizer is created with the video
// running mode. // running mode.
// //
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing recognition, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::GestureRecognitionResult> absl::StatusOr<components::containers::GestureRecognitionResult>
RecognizeForVideo(Image image, int64 timestamp_ms, RecognizeForVideo(Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt); image_processing_options = std::nullopt);
// Sends live image data to perform gesture recognition, and the results will // Sends live image data to perform gesture recognition, and the results will
@ -171,6 +165,12 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
// sent to the gesture recognizer. The input timestamps must be monotonically // sent to the gesture recognizer. The input timestamps must be monotonically
// increasing. // increasing.
// //
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing recognition, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// The "result_callback" provides // The "result_callback" provides
// - A vector of GestureRecognitionResult, each is the recognized results // - A vector of GestureRecognitionResult, each is the recognized results
// for a input frame. // for a input frame.
@ -180,7 +180,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status RecognizeAsync(Image image, int64 timestamp_ms, absl::Status RecognizeAsync(Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt); image_processing_options = std::nullopt);
// Shuts down the GestureRecognizer when all works are done. // Shuts down the GestureRecognizer when all works are done.

View File

@ -25,9 +25,13 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
@ -46,6 +50,8 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::gesture_recognizer::proto:: using ::mediapipe::tasks::vision::gesture_recognizer::proto::
GestureRecognizerGraphOptions; GestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::proto:: using ::mediapipe::tasks::vision::gesture_recognizer::proto::
@ -61,6 +67,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task";
constexpr char kHandGestureRecognizerBundleAssetName[] =
"hand_gesture_recognizer.task";
struct GestureRecognizerOutputs { struct GestureRecognizerOutputs {
Source<std::vector<ClassificationList>> gesture; Source<std::vector<ClassificationList>> gesture;
@ -70,6 +79,53 @@ struct GestureRecognizerOutputs {
Source<Image> image; Source<Image> image;
}; };
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
GestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_landmarker_file,
resources.GetModelFile(kHandLandmarkerBundleAssetName));
auto* hand_landmarker_graph_options =
options->mutable_hand_landmarker_graph_options();
SetExternalFile(hand_landmarker_file,
hand_landmarker_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
hand_landmarker_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(
const auto hand_gesture_recognizer_file,
resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
auto* hand_gesture_recognizer_graph_options =
options->mutable_hand_gesture_recognizer_graph_options();
SetExternalFile(hand_gesture_recognizer_file,
hand_gesture_recognizer_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
hand_gesture_recognizer_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
if (!hand_gesture_recognizer_graph_options->base_options()
.acceleration()
.has_xnnpack() &&
!hand_gesture_recognizer_graph_options->base_options()
.acceleration()
.has_tflite()) {
hand_gesture_recognizer_graph_options->mutable_base_options()
->mutable_acceleration()
->mutable_xnnpack();
LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets "
<< "HandGestureRecognizerGraph acceleartion to Xnnpack.";
}
hand_gesture_recognizer_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
return absl::OkStatus();
}
} // namespace } // namespace
// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs // A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs
@ -136,6 +192,21 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
Graph graph; Graph graph;
if (sc->Options<GestureRecognizerGraphOptions>()
.base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN(
const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<GestureRecognizerGraphOptions>(sc));
// When the model resources cache service is available, filling in
// the file pointer meta in the subtasks' base options. Otherwise,
// providing the file contents instead.
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<GestureRecognizerGraphOptions>(),
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
BuildGestureRecognizerGraph( BuildGestureRecognizerGraph(
*sc->MutableOptions<GestureRecognizerGraphOptions>(), *sc->MutableOptions<GestureRecognizerGraphOptions>(),

View File

@ -30,11 +30,17 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
@ -51,6 +57,8 @@ using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::processors:: using ::mediapipe::tasks::components::processors::
ConfigureTensorsToClassificationCalculator; ConfigureTensorsToClassificationCalculator;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::gesture_recognizer::proto:: using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions; HandGestureRecognizerGraphOptions;
@ -70,6 +78,14 @@ constexpr char kVectorTag[] = "VECTOR";
constexpr char kIndexTag[] = "INDEX"; constexpr char kIndexTag[] = "INDEX";
constexpr char kIterableTag[] = "ITERABLE"; constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END"; constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite";
constexpr char kCannedGestureClassifierTFLiteName[] =
"canned_gesture_classifier.tflite";
struct SubTaskModelResources {
const core::ModelResources* gesture_embedder_model_resource;
const core::ModelResources* canned_gesture_classifier_model_resource;
};
Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix, Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
Graph& graph) { Graph& graph) {
@ -78,6 +94,41 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
return node[Output<std::vector<Tensor>>{"TENSORS"}]; return node[Output<std::vector<Tensor>>{"TENSORS"}];
} }
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
HandGestureRecognizerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto gesture_embedder_file,
resources.GetModelFile(kGestureEmbedderTFLiteName));
auto* gesture_embedder_graph_options =
options->mutable_gesture_embedder_graph_options();
SetExternalFile(gesture_embedder_file,
gesture_embedder_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
gesture_embedder_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
resources.GetModelFile(kCannedGestureClassifierTFLiteName));
auto* canned_gesture_classifier_graph_options =
options->mutable_canned_gesture_classifier_graph_options();
SetExternalFile(
canned_gesture_classifier_file,
canned_gesture_classifier_graph_options->mutable_base_options()
->mutable_model_asset(),
is_copy);
canned_gesture_classifier_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
canned_gesture_classifier_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
return absl::OkStatus();
}
} // namespace } // namespace
// A // A
@ -128,27 +179,70 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
ASSIGN_OR_RETURN( if (sc->Options<HandGestureRecognizerGraphOptions>()
const auto* model_resources, .base_options()
CreateModelResources<HandGestureRecognizerGraphOptions>(sc)); .has_model_asset()) {
ASSIGN_OR_RETURN(
const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<HandGestureRecognizerGraphOptions>(
sc));
// When the model resources cache service is available, filling in
// the file pointer meta in the subtasks' base options. Otherwise,
// providing the file contents instead.
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<HandGestureRecognizerGraphOptions>(),
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(const auto sub_task_model_resources,
CreateSubTaskModelResources(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(auto hand_gestures,
auto hand_gestures, BuildGestureRecognizerGraph(
BuildGestureRecognizerGraph( sc->Options<HandGestureRecognizerGraphOptions>(),
sc->Options<HandGestureRecognizerGraphOptions>(), *model_resources, sub_task_model_resources,
graph[Input<ClassificationList>(kHandednessTag)], graph[Input<ClassificationList>(kHandednessTag)],
graph[Input<NormalizedLandmarkList>(kLandmarksTag)], graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
graph[Input<LandmarkList>(kWorldLandmarksTag)], graph[Input<LandmarkList>(kWorldLandmarksTag)],
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph[Input<std::pair<int, int>>(kImageSizeTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph)); graph[Input<NormalizedRect>(kNormRectTag)], graph));
hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)]; hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
private: private:
absl::StatusOr<SubTaskModelResources> CreateSubTaskModelResources(
SubgraphContext* sc) {
auto* options = sc->MutableOptions<HandGestureRecognizerGraphOptions>();
SubTaskModelResources sub_task_model_resources;
auto& gesture_embedder_model_asset =
*options->mutable_gesture_embedder_graph_options()
->mutable_base_options()
->mutable_model_asset();
ASSIGN_OR_RETURN(
sub_task_model_resources.gesture_embedder_model_resource,
CreateModelResources(sc,
std::make_unique<core::proto::ExternalFile>(
std::move(gesture_embedder_model_asset)),
"_gesture_embedder"));
auto& canned_gesture_classifier_model_asset =
*options->mutable_canned_gesture_classifier_graph_options()
->mutable_base_options()
->mutable_model_asset();
ASSIGN_OR_RETURN(
sub_task_model_resources.canned_gesture_classifier_model_resource,
CreateModelResources(
sc,
std::make_unique<core::proto::ExternalFile>(
std::move(canned_gesture_classifier_model_asset)),
"_canned_gesture_classifier"));
return sub_task_model_resources;
}
absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph( absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
const HandGestureRecognizerGraphOptions& graph_options, const HandGestureRecognizerGraphOptions& graph_options,
const core::ModelResources& model_resources, const SubTaskModelResources& sub_task_model_resources,
Source<ClassificationList> handedness, Source<ClassificationList> handedness,
Source<NormalizedLandmarkList> hand_landmarks, Source<NormalizedLandmarkList> hand_landmarks,
Source<LandmarkList> hand_world_landmarks, Source<LandmarkList> hand_world_landmarks,
@ -209,17 +303,33 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
auto concatenated_tensors = concatenate_tensor_vector.Out(""); auto concatenated_tensors = concatenate_tensor_vector.Out("");
// Inference for static hand gesture recognition. // Inference for static hand gesture recognition.
// TODO add embedding step. auto& gesture_embedder_inference =
auto& inference = AddInference( AddInference(*sub_task_model_resources.gesture_embedder_model_resource,
model_resources, graph_options.base_options().acceleration(), graph); graph_options.gesture_embedder_graph_options()
concatenated_tensors >> inference.In(kTensorsTag); .base_options()
auto inference_output_tensors = inference.Out(kTensorsTag); .acceleration(),
graph);
concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag);
auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag);
auto& canned_gesture_classifier_inference = AddInference(
*sub_task_model_resources.canned_gesture_classifier_model_resource,
graph_options.canned_gesture_classifier_graph_options()
.base_options()
.acceleration(),
graph);
embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag);
auto inference_output_tensors =
canned_gesture_classifier_inference.Out(kTensorsTag);
auto& tensors_to_classification = auto& tensors_to_classification =
graph.AddNode("TensorsToClassificationCalculator"); graph.AddNode("TensorsToClassificationCalculator");
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
graph_options.classifier_options(), graph_options.canned_gesture_classifier_graph_options()
*model_resources.GetMetadataExtractor(), 0, .classifier_options(),
*sub_task_model_resources.canned_gesture_classifier_model_resource
->GetMetadataExtractor(),
0,
&tensors_to_classification.GetOptions< &tensors_to_classification.GetOptions<
mediapipe::TensorsToClassificationCalculatorOptions>())); mediapipe::TensorsToClassificationCalculatorOptions>()));
inference_output_tensors >> tensors_to_classification.In(kTensorsTag); inference_output_tensors >> tensors_to_classification.In(kTensorsTag);

View File

@ -49,7 +49,6 @@ mediapipe_proto_library(
":gesture_embedder_graph_options_proto", ":gesture_embedder_graph_options_proto",
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,6 @@ syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto; package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";
@ -37,15 +36,11 @@ message HandGestureRecognizerGraphOptions {
// Options for GestureEmbedder. // Options for GestureEmbedder.
optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2; optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2;
// Options for GestureClassifier of default gestures. // Options for GestureClassifier of canned gestures.
optional GestureClassifierGraphOptions optional GestureClassifierGraphOptions
canned_gesture_classifier_graph_options = 3; canned_gesture_classifier_graph_options = 3;
// Options for GestureClassifier of custom gestures. // Options for GestureClassifier of custom gestures.
optional GestureClassifierGraphOptions optional GestureClassifierGraphOptions
custom_gesture_classifier_graph_options = 4; custom_gesture_classifier_graph_options = 4;
// TODO: remove these. Temporary solutions before bundle asset is
// ready.
optional components.processors.proto.ClassifierOptions classifier_options = 5;
} }

View File

@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph {
image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_keep_aspect_ratio(true);
image_to_tensor_options.set_border_mode( image_to_tensor_options.set_border_mode(
mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In("IMAGE"); image_in >> preprocessing.In("IMAGE");

View File

@ -92,18 +92,30 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
bool is_copy) { bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_detector_file, ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName)); resources.GetModelFile(kHandDetectorTFLiteName));
auto* hand_detector_graph_options =
options->mutable_hand_detector_graph_options();
SetExternalFile(hand_detector_file, SetExternalFile(hand_detector_file,
options->mutable_hand_detector_graph_options() hand_detector_graph_options->mutable_base_options()
->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
is_copy); is_copy);
hand_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
hand_detector_graph_options->mutable_base_options()->set_use_stream_mode(
options->base_options().use_stream_mode());
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
auto* hand_landmarks_detector_graph_options =
options->mutable_hand_landmarks_detector_graph_options();
SetExternalFile(hand_landmarks_detector_file, SetExternalFile(hand_landmarks_detector_file,
options->mutable_hand_landmarks_detector_graph_options() hand_landmarks_detector_graph_options->mutable_base_options()
->mutable_base_options()
->mutable_model_asset(), ->mutable_model_asset(),
is_copy); is_copy);
hand_landmarks_detector_graph_options->mutable_base_options()
->mutable_acceleration()
->CopyFrom(options->base_options().acceleration());
hand_landmarks_detector_graph_options->mutable_base_options()
->set_use_stream_mode(options->base_options().use_stream_mode());
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -67,7 +67,7 @@ using ::testing::proto::Approximately;
using ::testing::proto::Partially; using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task"; constexpr char kHandLandmarkerModelBundle[] = "hand_landmarker.task";
constexpr char kLeftHandsImage[] = "left_hands.jpg"; constexpr char kLeftHandsImage[] = "left_hands.jpg";
constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg"; constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg";

View File

@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
subgraph_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In("IMAGE"); image_in >> preprocessing.In("IMAGE");

View File

@ -59,6 +59,7 @@ cc_library(
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
@ -59,26 +60,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
// Returns a NormalizedRect covering the full image if input is not present.
// Otherwise, makes sure the x_center, y_center, width and height are set in
// case only a rotation was provided in the input.
NormalizedRect FillNormalizedRect(
std::optional<NormalizedRect> normalized_rect) {
NormalizedRect result;
if (normalized_rect.has_value()) {
result = *normalized_rect;
}
bool has_coordinates = result.has_x_center() || result.has_y_center() ||
result.has_width() || result.has_height();
if (!has_coordinates) {
result.set_x_center(0.5);
result.set_y_center(0.5);
result.set_width(1);
result.set_height(1);
}
return result;
}
// Creates a MediaPipe graph config that contains a subgraph node of // Creates a MediaPipe graph config that contains a subgraph node of
// type "ImageClassifierGraph". If the task is running in the live stream mode, // type "ImageClassifierGraph". If the task is running in the live stream mode,
// a "FlowLimiterCalculator" will be added to limit the number of frames in // a "FlowLimiterCalculator" will be added to limit the number of frames in
@ -164,14 +145,16 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
} }
absl::StatusOr<ClassificationResult> ImageClassifier::Classify( absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
Image image, std::optional<NormalizedRect> image_processing_options) { Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessImageData( ProcessImageData(
@ -183,14 +166,15 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo( absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessVideoData( ProcessVideoData(
@ -206,14 +190,15 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
absl::Status ImageClassifier::ClassifyAsync( absl::Status ImageClassifier::ClassifyAsync(
Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = FillNormalizedRect(image_processing_options); ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
return SendLiveStreamData( return SendLiveStreamData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))

View File

@ -22,11 +22,11 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe { namespace mediapipe {
@ -109,12 +109,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// //
// The optional 'image_processing_options' parameter can be used to specify: // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by // - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° // setting its 'rotation_degrees' field.
// anti-clockwise rotation).
// and/or // and/or
// - the region-of-interest on which to perform classification, by setting its // - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is // 'region_of_interest' field. If not specified, the full image is used.
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted // If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop. // first, then the specified rotation is applied to the crop.
// //
@ -126,19 +124,17 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// YUVToImageCalculator is integrated. // YUVToImageCalculator is integrated.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify( absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
// Performs image classification on the provided video frame. // Performs image classification on the provided video frame.
// //
// The optional 'image_processing_options' parameter can be used to specify: // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by // - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° // setting its 'rotation_degrees' field.
// anti-clockwise rotation).
// and/or // and/or
// - the region-of-interest on which to perform classification, by setting its // - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is // 'region_of_interest' field. If not specified, the full image is used.
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted // If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop. // first, then the specified rotation is applied to the crop.
// //
@ -150,7 +146,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::proto::ClassificationResult> absl::StatusOr<components::containers::proto::ClassificationResult>
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt); image_processing_options = std::nullopt);
// Sends live image data to image classification, and the results will be // Sends live image data to image classification, and the results will be
@ -158,12 +154,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// //
// The optional 'image_processing_options' parameter can be used to specify: // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing classification, by // - the rotation to apply to the image before performing classification, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° // setting its 'rotation_degrees' field.
// anti-clockwise rotation).
// and/or // and/or
// - the region-of-interest on which to perform classification, by setting its // - the region-of-interest on which to perform classification, by setting its
// 'x_center', 'y_center', 'width' and 'height' fields. If none of these is // 'region_of_interest' field. If not specified, the full image is used.
// set, they will automatically be set to cover the full image.
// If both are specified, the crop around the region-of-interest is extracted // If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop. // first, then the specified rotation is applied to the crop.
// //
@ -175,7 +169,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// sent to the object detector. The input timestamps must be monotonically // sent to the object detector. The input timestamps must be monotonically
// increasing. // increasing.
// //
// The "result_callback" prvoides // The "result_callback" provides:
// - The classification results as a ClassificationResult object. // - The classification results as a ClassificationResult object.
// - The const reference to the corresponding input image that the image // - The const reference to the corresponding input image that the image
// classifier runs on. Note that the const reference to the image will no // classifier runs on. Note that the const reference to the image will no
@ -183,12 +177,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms, absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt); image_processing_options = std::nullopt);
// TODO: add Classify() variants taking a region of interest as
// additional argument.
// Shuts down the ImageClassifier when all works are done. // Shuts down the ImageClassifier when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
}; };

View File

@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
@ -35,6 +34,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -49,9 +50,11 @@ namespace image_classifier {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications; using ::mediapipe::tasks::components::containers::proto::Classifications;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -547,12 +550,9 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Crop around the soccer ball. // Region-of-interest around the soccer ball.
NormalizedRect image_processing_options; Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
image_processing_options.set_x_center(0.532); ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
image_processing_options.set_y_center(0.521);
image_processing_options.set_width(0.164);
image_processing_options.set_height(0.427);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options)); image, image_processing_options));
@ -572,8 +572,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Specify a 90° anti-clockwise rotation. // Specify a 90° anti-clockwise rotation.
NormalizedRect image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.set_rotation(M_PI / 2.0); image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options)); image, image_processing_options));
@ -616,13 +616,10 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
options->classifier_options.max_results = 1; options->classifier_options.max_results = 1;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Crop around the chair, with 90° anti-clockwise rotation. // Region-of-interest around the chair, with 90° anti-clockwise rotation.
NormalizedRect image_processing_options; Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049};
image_processing_options.set_x_center(0.2821); ImageProcessingOptions image_processing_options{roi,
image_processing_options.set_y_center(0.2406); /*rotation_degrees=*/-90};
image_processing_options.set_width(0.5642);
image_processing_options.set_height(0.1286);
image_processing_options.set_rotation(M_PI / 2.0);
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options)); image, image_processing_options));
@ -633,7 +630,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
entries { entries {
categories { categories {
index: 560 index: 560
score: 0.6800408 score: 0.6522213
category_name: "folding chair" category_name: "folding chair"
} }
timestamp_ms: 0 timestamp_ms: 0
@ -643,6 +640,69 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
})pb")); })pb"));
} }
// Testing all these once with ImageClassifier.
TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
MP_ASSERT_OK_AND_ASSIGN(Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"multi_objects.jpg")));
auto options = std::make_unique<ImageClassifierOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
// Invalid: left > right.
Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/0};
auto results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect with left < right and top < bottom"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
// Invalid: top > bottom.
roi = {/*left=*/0, /*top=*/0.9, /*right=*/1, /*bottom=*/0.1};
image_processing_options = {roi,
/*rotation_degrees=*/0};
results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect with left < right and top < bottom"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
// Invalid: coordinates out of [0,1] range.
roi = {/*left=*/-0.1, /*top=*/0, /*right=*/1, /*bottom=*/1};
image_processing_options = {roi,
/*rotation_degrees=*/0};
results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("Expected Rect values to be in [0,1]"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
// Invalid: rotation not a multiple of 90°.
image_processing_options = {/*region_of_interest=*/std::nullopt,
/*rotation_degrees=*/1};
results = image_classifier->Classify(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("Expected rotation to be a multiple of 90°"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
class VideoModeTest : public tflite_shims::testing::Test {}; class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@ -732,11 +792,9 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Crop around the soccer ball. // Crop around the soccer ball.
NormalizedRect image_processing_options; // Region-of-interest around the soccer ball.
image_processing_options.set_x_center(0.532); Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
image_processing_options.set_y_center(0.521); ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
image_processing_options.set_width(0.164);
image_processing_options.set_height(0.427);
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
@ -877,11 +935,8 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options))); ImageClassifier::Create(std::move(options)));
// Crop around the soccer ball. // Crop around the soccer ball.
NormalizedRect image_processing_options; Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
image_processing_options.set_x_center(0.532); ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
image_processing_options.set_y_center(0.521);
image_processing_options.set_width(0.164);
image_processing_options.set_height(0.427);
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK( MP_ASSERT_OK(

View File

@ -58,6 +58,7 @@ cc_library(
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto:: using ::mediapipe::tasks::vision::image_embedder::proto::
ImageEmbedderGraphOptions; ImageEmbedderGraphOptions;
// Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() {
NormalizedRect norm_rect;
norm_rect.set_x_center(0.5);
norm_rect.set_y_center(0.5);
norm_rect.set_width(1);
norm_rect.set_height(1);
return norm_rect;
}
// Creates a MediaPipe graph config that contains a single node of type // Creates a MediaPipe graph config that contains a single node of type
// "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is // "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to // running in the live stream mode, a "FlowLimiterCalculator" will be added to
@ -148,15 +139,16 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed( absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
Image image, std::optional<NormalizedRect> roi) { Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = ASSIGN_OR_RETURN(NormalizedRect norm_rect,
roi.has_value() ? roi.value() : BuildFullImageNormRect(); ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessImageData( ProcessImageData(
@ -167,15 +159,16 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
} }
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo( absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) { Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = ASSIGN_OR_RETURN(NormalizedRect norm_rect,
roi.has_value() ? roi.value() : BuildFullImageNormRect(); ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessVideoData( ProcessVideoData(
@ -188,16 +181,17 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>(); return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
} }
absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms, absl::Status ImageEmbedder::EmbedAsync(
std::optional<NormalizedRect> roi) { Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.", "GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
NormalizedRect norm_rect = ASSIGN_OR_RETURN(NormalizedRect norm_rect,
roi.has_value() ? roi.value() : BuildFullImageNormRect(); ConvertToNormalizedRect(image_processing_options));
return SendLiveStreamData( return SendLiveStreamData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))

View File

@ -21,11 +21,11 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/components/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe { namespace mediapipe {
@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi {
static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create( static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create(
std::unique_ptr<ImageEmbedderOptions> options); std::unique_ptr<ImageEmbedderOptions> options);
// Performs embedding extraction on the provided single image. Extraction // Performs embedding extraction on the provided single image.
// is performed on the region of interest specified by the `roi` argument if //
// provided, or on the entire image otherwise. // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
// //
// Only use this method when the ImageEmbedder is created with the image // Only use this method when the ImageEmbedder is created with the image
// running mode. // running mode.
@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed( absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs embedding extraction on the provided video frame. Extraction // Performs embedding extraction on the provided video frame.
// is performed on the region of interested specified by the `roi` argument if //
// provided, or on the entire image otherwise. // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
// //
// Only use this method when the ImageEmbedder is created with the video // Only use this method when the ImageEmbedder is created with the video
// running mode. // running mode.
@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo( absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to embedder, and the results will be available via // Sends live image data to embedder, and the results will be available via
// the "result_callback" provided in the ImageEmbedderOptions. Embedding // the "result_callback" provided in the ImageEmbedderOptions.
// extraction is performed on the region of interested specified by the `roi` //
// argument if provided, or on the entire image otherwise. // The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
// //
// Only use this method when the ImageEmbedder is created with the live // Only use this method when the ImageEmbedder is created with the live
// stream running mode. // stream running mode.
@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// longer be valid when the callback returns. To access the image data // longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status EmbedAsync( absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms,
mediapipe::Image image, int64 timestamp_ms, std::optional<core::ImageProcessingOptions>
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); image_processing_options = std::nullopt);
// Shuts down the ImageEmbedder when all works are done. // Shuts down the ImageEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }

View File

@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
@ -42,7 +41,9 @@ namespace image_embedder {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
Image crop, DecodeImageFromFile( Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Bounding box in "burger.jpg" corresponding to "burger_crop.jpg". // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
NormalizedRect roi; Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
roi.set_x_center(200.0 / 480); ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
roi.set_y_center(0.5);
roi.set_width(400.0 / 480);
roi.set_height(1.0f);
// Extract both embeddings. // Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, MP_ASSERT_OK_AND_ASSIGN(
image_embedder->Embed(image, roi)); const EmbeddingResult& image_result,
image_embedder->Embed(image, image_processing_options));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
image_embedder->Embed(crop)); image_embedder->Embed(crop));
@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
} }
TEST_F(ImageModeTest, SucceedsWithRotation) {
auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
// Load images: one is a rotated version of the other.
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.572265;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
// Region-of-interest corresponding to burger_crop.jpg.
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/-90};
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
image_embedder->Embed(crop));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(crop_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.62838;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
class VideoModeTest : public tflite_shims::testing::Test {}; class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {

View File

@ -24,10 +24,12 @@ cc_library(
":image_segmenter_graph", ":image_segmenter_graph",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
@ -48,6 +50,7 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",

View File

@ -17,8 +17,10 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageSegmenterGraph"; "mediapipe.tasks.vision.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig(
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get()); task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
graph.Out(kGroupedSegmentationTag); graph.Out(kGroupedSegmentationTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag); graph.Out(kImageTag);
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator( return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag); {kImageTag, kNormRectTag},
kGroupedSegmentationTag);
} }
graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
return graph.GetConfig(); return graph.GetConfig();
} }
@ -139,47 +146,68 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
} }
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment( absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
mediapipe::Image image) { mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessImageData({{kImageInStreamName, ProcessImageData(
mediapipe::MakePacket<Image>(std::move(image))}})); {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>(); return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
} }
absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo( absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms) { mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessVideoData( ProcessVideoData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>(); return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
} }
absl::Status ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) { absl::Status ImageSegmenter::SegmentAsync(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
return SendLiveStreamData( return SendLiveStreamData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
} }

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// running mode. // running mode.
// //
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed //
// after the yuv support is implemented. // The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
// //
// If the output_type is CATEGORY_MASK, the returned vector of images is // If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask. // per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images // If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask. // contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image); absl::StatusOr<std::vector<mediapipe::Image>> Segment(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs image segmentation on the provided video frame. // Performs image segmentation on the provided video frame.
// Only use this method when the ImageSegmenter is created with the video // Only use this method when the ImageSegmenter is created with the video
@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
// //
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is // If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask. // per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images // If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask. // contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo( absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms); mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to perform image segmentation, and the results will // Sends live image data to perform image segmentation, and the results will
// be available via the "result_callback" provided in the // be available via the "result_callback" provided in the
@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// sent to the image segmenter. The input timestamps must be monotonically // sent to the image segmenter. The input timestamps must be monotonically
// increasing. // increasing.
// //
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// The "result_callback" prvoides // The "result_callback" prvoides
// - A vector of segmented image masks. // - A vector of segmented image masks.
// If the output_type is CATEGORY_MASK, the returned vector of images is // If the output_type is CATEGORY_MASK, the returned vector of images is
@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// no longer be valid when the callback returns. To access the image data // no longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms); absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the ImageSegmenter when all works are done. // Shuts down the ImageSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
@ -159,6 +161,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
// Inputs: // Inputs:
// IMAGE - Image // IMAGE - Image
// Image to perform segmentation on. // Image to perform segmentation on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection
// on.
// @Optional: rect covering the whole image is used if not specified.
// //
// Outputs: // Outputs:
// SEGMENTATION - mediapipe::Image @Multiple // SEGMENTATION - mediapipe::Image @Multiple
@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageSegmenterOptions>(sc)); CreateModelResources<ImageSegmenterOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN(auto output_streams, ASSIGN_OR_RETURN(
BuildSegmentationTask( auto output_streams,
sc->Options<ImageSegmenterOptions>(), *model_resources, BuildSegmentationTask(
graph[Input<Image>(kImageTag)], graph)); sc->Options<ImageSegmenterOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
auto& merge_images_to_vector = auto& merge_images_to_vector =
graph.AddNode("MergeImagesToVectorCalculator"); graph.AddNode("MergeImagesToVectorCalculator");
@ -228,18 +236,21 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask( absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterOptions& task_options, const ImageSegmenterOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Graph& graph) { Source<NormalizedRect> norm_rect_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
// Adds inference subgraph and connects its input stream to the output // Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator. // tensors produced by the ImageToTensorCalculator.

View File

@ -29,8 +29,10 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -44,6 +46,8 @@ namespace {
using ::mediapipe::Image; using ::mediapipe::Image;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options))); ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21); EXPECT_EQ(confidence_masks.size(), 21);
@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
} }
TEST_F(ImageModeTest, SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
cv::IMREAD_GRAYSCALE);
cv::Mat expected_mask_float;
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
// Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView(
confidence_masks[8].GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = segmenter->Segment(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("This task doesn't support region-of-interest"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
Image image = Image image =
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg")); GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));

View File

@ -75,6 +75,7 @@ cc_library(
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h" #include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h"
@ -58,31 +59,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ObjectDetectorOptionsProto = using ObjectDetectorOptionsProto =
object_detector::proto::ObjectDetectorOptions; object_detector::proto::ObjectDetectorOptions;
// Returns a NormalizedRect filling the whole image. If input is present, its
// rotation is set in the returned NormalizedRect and a check is performed to
// make sure no region-of-interest was provided. Otherwise, rotation is set to
// 0.
absl::StatusOr<NormalizedRect> FillNormalizedRect(
std::optional<NormalizedRect> normalized_rect) {
NormalizedRect result;
if (normalized_rect.has_value()) {
result = *normalized_rect;
}
bool has_coordinates = result.has_x_center() || result.has_y_center() ||
result.has_width() || result.has_height();
if (has_coordinates) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ObjectDetector does not support region-of-interest.",
MediaPipeTasksStatus::kInvalidArgumentError);
}
result.set_x_center(0.5);
result.set_y_center(0.5);
result.set_width(1);
result.set_height(1);
return result;
}
// Creates a MediaPipe graph config that contains a subgraph node of // Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.ObjectDetectorGraph". If the task is running in the // "mediapipe.tasks.vision.ObjectDetectorGraph". If the task is running in the
// live stream mode, a "FlowLimiterCalculator" will be added to limit the // live stream mode, a "FlowLimiterCalculator" will be added to limit the
@ -170,15 +146,16 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(
absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect( absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(
FillNormalizedRect(image_processing_options)); NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessImageData( ProcessImageData(
@ -189,15 +166,16 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo( absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(
FillNormalizedRect(image_processing_options)); NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
ProcessVideoData( ProcessVideoData(
@ -212,15 +190,16 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
absl::Status ObjectDetector::DetectAsync( absl::Status ObjectDetector::DetectAsync(
Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."), absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError); MediaPipeTasksStatus::kRunnerUnexpectedInputError);
} }
ASSIGN_OR_RETURN(NormalizedRect norm_rect, ASSIGN_OR_RETURN(
FillNormalizedRect(image_processing_options)); NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
return SendLiveStreamData( return SendLiveStreamData(
{{kImageInStreamName, {{kImageInStreamName,
MakePacket<Image>(std::move(image)) MakePacket<Image>(std::move(image))

View File

@ -27,9 +27,9 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe { namespace mediapipe {
@ -154,10 +154,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// after the yuv support is implemented. // after the yuv support is implemented.
// //
// The optional 'image_processing_options' parameter can be used to specify // The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing classification, by // the rotation to apply to the image before performing detection, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° // setting its 'rotation_degrees' field. Note that specifying a
// anti-clockwise rotation). Note that specifying a region-of-interest using // region-of-interest using the 'region_of_interest' field is NOT supported
// the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
// //
// For CPU images, the returned bounding boxes are expressed in the // For CPU images, the returned bounding boxes are expressed in the
@ -168,7 +167,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// images after enabling the gpu support in MediaPipe Tasks. // images after enabling the gpu support in MediaPipe Tasks.
absl::StatusOr<std::vector<mediapipe::Detection>> Detect( absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
// Performs object detection on the provided video frame. // Performs object detection on the provided video frame.
@ -180,10 +179,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// must be monotonically increasing. // must be monotonically increasing.
// //
// The optional 'image_processing_options' parameter can be used to specify // The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing classification, by // the rotation to apply to the image before performing detection, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° // setting its 'rotation_degrees' field. Note that specifying a
// anti-clockwise rotation). Note that specifying a region-of-interest using // region-of-interest using the 'region_of_interest' field is NOT supported
// the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
// //
// For CPU images, the returned bounding boxes are expressed in the // For CPU images, the returned bounding boxes are expressed in the
@ -192,7 +190,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// underlying image data. // underlying image data.
absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo( absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
// Sends live image data to perform object detection, and the results will be // Sends live image data to perform object detection, and the results will be
@ -206,10 +204,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// increasing. // increasing.
// //
// The optional 'image_processing_options' parameter can be used to specify // The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing classification, by // the rotation to apply to the image before performing detection, by
// setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90° // setting its 'rotation_degrees' field. Note that specifying a
// anti-clockwise rotation). Note that specifying a region-of-interest using // region-of-interest using the 'region_of_interest' field is NOT supported
// the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
// and will result in an invalid argument error being returned. // and will result in an invalid argument error being returned.
// //
// The "result_callback" provides // The "result_callback" provides
@ -223,7 +220,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
// outside of the callback, callers need to make a copy of the image. // outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms, absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt); image_processing_options = std::nullopt);
// Shuts down the ObjectDetector when all works are done. // Shuts down the ObjectDetector when all works are done.

View File

@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
// stream. // stream.
auto& preprocessing = auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
model_resources, model_resources, use_gpu,
&preprocessing &preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>())); .GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);

View File

@ -31,11 +31,12 @@ limitations under the License.
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/location_data.pb.h" #include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
@ -64,6 +65,8 @@ namespace vision {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -532,8 +535,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
NormalizedRect image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.set_rotation(M_PI / 2.0); image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto results, object_detector->Detect(image, image_processing_options)); auto results, object_detector->Detect(image, image_processing_options));
MP_ASSERT_OK(object_detector->Close()); MP_ASSERT_OK(object_detector->Close());
@ -557,16 +560,17 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
ObjectDetector::Create(std::move(options))); ObjectDetector::Create(std::move(options)));
NormalizedRect image_processing_options; Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
image_processing_options.set_x_center(0.5); ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
image_processing_options.set_y_center(0.5);
image_processing_options.set_width(1.0);
image_processing_options.set_height(1.0);
auto results = object_detector->Detect(image, image_processing_options); auto results = object_detector->Detect(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(), EXPECT_THAT(results.status().message(),
HasSubstr("ObjectDetector does not support region-of-interest")); HasSubstr("This task doesn't support region-of-interest"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
} }
class VideoModeTest : public tflite_shims::testing::Test {}; class VideoModeTest : public tflite_shims::testing::Test {};

View File

@ -31,6 +31,7 @@ android_binary(
multidex = "native", multidex = "native",
resource_files = ["//mediapipe/tasks/examples/android:resource_files"], resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
deps = [ deps = [
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",

View File

@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector;
import android.content.Intent; import android.content.Intent;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.MediaMetadataRetriever; import android.media.MediaMetadataRetriever;
import android.os.Bundle; import android.os.Bundle;
import android.provider.MediaStore; import android.provider.MediaStore;
@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts; import androidx.activity.result.contract.ActivityResultContracts;
import androidx.exifinterface.media.ExifInterface; import androidx.exifinterface.media.ExifInterface;
// ContentResolver dependency // ContentResolver dependency
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity {
if (resultIntent != null) { if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) { if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null; Bitmap bitmap = null;
int rotation = 0;
try { try {
bitmap = bitmap =
downscaleBitmap( downscaleBitmap(
@ -93,13 +95,16 @@ public class MainActivity extends AppCompatActivity {
try { try {
InputStream imageData = InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData()); this.getContentResolver().openInputStream(resultIntent.getData());
bitmap = rotateBitmap(bitmap, imageData); rotation = getImageRotation(imageData);
} catch (IOException e) { } catch (IOException | MediaPipeException e) {
Log.e(TAG, "Bitmap rotation error:" + e); Log.e(TAG, "Bitmap rotation error:" + e);
} }
if (bitmap != null) { if (bitmap != null) {
Image image = new BitmapImageBuilder(bitmap).build(); MPImage image = new BitmapImageBuilder(bitmap).build();
ObjectDetectionResult detectionResult = objectDetector.detect(image); ObjectDetectionResult detectionResult =
objectDetector.detect(
image,
ImageProcessingOptions.builder().setRotationDegrees(rotation).build());
imageView.setData(image, detectionResult); imageView.setData(image, detectionResult);
runOnUiThread(() -> imageView.update()); runOnUiThread(() -> imageView.update());
} }
@ -144,7 +149,8 @@ public class MainActivity extends AppCompatActivity {
MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT)); MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT));
long frameIntervalMs = duration / numFrames; long frameIntervalMs = duration / numFrames;
for (int i = 0; i < numFrames; ++i) { for (int i = 0; i < numFrames; ++i) {
Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); MPImage image =
new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build();
ObjectDetectionResult detectionResult = ObjectDetectionResult detectionResult =
objectDetector.detectForVideo(image, frameIntervalMs * i); objectDetector.detectForVideo(image, frameIntervalMs * i);
// Currently only annotates the detection result on the first video frame and // Currently only annotates the detection result on the first video frame and
@ -209,28 +215,25 @@ public class MainActivity extends AppCompatActivity {
return Bitmap.createScaledBitmap(originalBitmap, width, height, false); return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
} }
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException { private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException {
int orientation = int orientation =
new ExifInterface(imageData) new ExifInterface(imageData)
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) { switch (orientation) {
case ExifInterface.ORIENTATION_NORMAL:
return 0;
case ExifInterface.ORIENTATION_ROTATE_90: case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90); return 90;
break;
case ExifInterface.ORIENTATION_ROTATE_180: case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180); return 180;
break;
case ExifInterface.ORIENTATION_ROTATE_270: case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270); return 270;
break;
default: default:
matrix.postRotate(0); // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip
// is supported.
throw new MediaPipeException(
MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(),
"Flipped images are not supported yet.");
} }
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
} }
} }

View File

@ -22,7 +22,7 @@ import android.graphics.Matrix;
import android.graphics.Paint; import android.graphics.Paint;
import androidx.appcompat.widget.AppCompatImageView; import androidx.appcompat.widget.AppCompatImageView;
import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Detection; import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
@ -40,12 +40,12 @@ public class ObjectDetectionResultImageView extends AppCompatImageView {
} }
/** /**
* Sets an {@link Image} and an {@link ObjectDetectionResult} to render. * Sets a {@link MPImage} and an {@link ObjectDetectionResult} to render.
* *
* @param image an {@link Image} object for annotation. * @param image a {@link MPImage} object for annotation.
* @param result an {@link ObjectDetectionResult} object that contains the detection result. * @param result an {@link ObjectDetectionResult} object that contains the detection result.
*/ */
public void setData(Image image, ObjectDetectionResult result) { public void setData(MPImage image, ObjectDetectionResult result) {
if (image == null || result == null) { if (image == null || result == null) {
return; return;
} }

View File

@ -1 +1,15 @@
# dummy file for tap test to find the pattern # Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])

View File

@ -0,0 +1,15 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])

View File

@ -36,3 +36,15 @@ android_library(
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar")
mediapipe_tasks_core_aar(
name = "tasks_core",
srcs = glob(["*.java"]) + [
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src",
"//mediapipe/java/com/google/mediapipe/framework/image:java_src",
],
manifest = "AndroidManifest.xml",
)

View File

@ -0,0 +1,256 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Building MediaPipe Tasks AARs."""
load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build_aar_with_jni", "mediapipe_java_proto_src_extractor", "mediapipe_java_proto_srcs")
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
_CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
]
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
]
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
]
def mediapipe_tasks_core_aar(name, srcs, manifest):
"""Builds medaipipe tasks core AAR.
Args:
name: The bazel target name.
srcs: MediaPipe Tasks' core layer source files.
manifest: The Android manifest.
"""
mediapipe_tasks_java_proto_srcs = []
for target in _CORE_TASKS_JAVA_PROTO_LITE_TARGETS:
mediapipe_tasks_java_proto_srcs.append(
_mediapipe_tasks_java_proto_src_extractor(target = target),
)
for target in _VISION_TASKS_JAVA_PROTO_LITE_TARGETS:
mediapipe_tasks_java_proto_srcs.append(
_mediapipe_tasks_java_proto_src_extractor(target = target),
)
for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS:
mediapipe_tasks_java_proto_srcs.append(
_mediapipe_tasks_java_proto_src_extractor(target = target),
)
mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java",
))
mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
src_out = "com/google/mediapipe/calculator/proto/InferenceCalculatorProto.java",
))
android_library(
name = name,
srcs = srcs + [
"//mediapipe/java/com/google/mediapipe/framework:java_src",
] + mediapipe_java_proto_srcs() +
mediapipe_tasks_java_proto_srcs,
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = manifest,
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/framework:calculator_profile_java_proto_lite",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework:mediapipe_options_java_proto_lite",
"//mediapipe/framework:packet_factory_java_proto_lite",
"//mediapipe/framework:packet_generator_java_proto_lite",
"//mediapipe/framework:status_handler_java_proto_lite",
"//mediapipe/framework:stream_handler_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
"//third_party:androidx_annotation",
"//third_party:autovalue",
"@com_google_protobuf//:protobuf_javalite",
"@maven//:com_google_guava_guava",
"@maven//:com_google_flogger_flogger",
"@maven//:com_google_flogger_flogger_system_backend",
"@maven//:com_google_code_findbugs_jsr305",
] +
_CORE_TASKS_JAVA_PROTO_LITE_TARGETS +
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS +
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
)
def mediapipe_tasks_vision_aar(name, srcs, native_library):
"""Builds medaipipe tasks vision AAR.
Args:
name: The bazel target name.
srcs: MediaPipe Vision Tasks' source files.
native_library: The native library that contains vision tasks' graph and calculators.
"""
native.genrule(
name = name + "tasks_manifest_generator",
outs = ["AndroidManifest.xml"],
cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision">
<uses-sdk
android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>
EOF
""",
)
_mediapipe_tasks_aar(
name = name,
srcs = srcs,
manifest = "AndroidManifest.xml",
java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS,
native_library = native_library,
)
def mediapipe_tasks_text_aar(name, srcs, native_library):
"""Builds medaipipe tasks text AAR.
Args:
name: The bazel target name.
srcs: MediaPipe Text Tasks' source files.
native_library: The native library that contains text tasks' graph and calculators.
"""
native.genrule(
name = name + "tasks_manifest_generator",
outs = ["AndroidManifest.xml"],
cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.text">
<uses-sdk
android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>
EOF
""",
)
_mediapipe_tasks_aar(
name = name,
srcs = srcs,
manifest = "AndroidManifest.xml",
java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
native_library = native_library,
)
def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library):
"""Builds medaipipe tasks AAR."""
# When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command,
# the OpenCV so libraries will be excluded from the AAR package to
# save the package size.
native.config_setting(
name = "exclude_opencv_so_lib",
define_values = {
"EXCLUDE_OPENCV_SO_LIB": "1",
},
visibility = ["//visibility:public"],
)
native.cc_library(
name = name + "_jni_opencv_cc_lib",
srcs = select({
"//mediapipe:android_arm64": ["@android_opencv//:libopencv_java3_so_arm64-v8a"],
"//mediapipe:android_armeabi": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"],
"//mediapipe:android_arm": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"],
"//mediapipe:android_x86": ["@android_opencv//:libopencv_java3_so_x86"],
"//mediapipe:android_x86_64": ["@android_opencv//:libopencv_java3_so_x86_64"],
"//conditions:default": [],
}),
alwayslink = 1,
)
android_library(
name = name + "_android_lib",
srcs = srcs,
manifest = manifest,
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
deps = java_proto_lite_targets + [native_library] + [
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
] + select({
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
"//mediapipe/framework/port:disable_opencv": [],
"exclude_opencv_so_lib": [],
}),
)
mediapipe_build_aar_with_jni(name, name + "_android_lib")
def _mediapipe_tasks_java_proto_src_extractor(target):
proto_path = "com/google/" + target.split(":")[0].replace("cc/", "").replace("//", "").replace("_", "") + "/"
proto_name = target.split(":")[-1].replace("_java_proto_lite", "").replace("_", " ").title().replace(" ", "") + "Proto.java"
return mediapipe_java_proto_src_extractor(
target = target,
src_out = proto_path + proto_name,
)

View File

@ -61,3 +61,11 @@ android_library(
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
mediapipe_tasks_text_aar(
name = "tasks_text",
srcs = glob(["**/*.java"]),
native_library = ":libmediapipe_tasks_text_jni_lib",
)

View File

@ -15,11 +15,11 @@
package com.google.mediapipe.tasks.text.textclassifier; package com.google.mediapipe.tasks.text.textclassifier;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category; import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry; import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications; import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;

View File

@ -22,7 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.OutputHandler;

View File

@ -28,6 +28,7 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
@ -128,6 +129,7 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
@ -140,3 +142,11 @@ android_library(
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")
mediapipe_tasks_vision_aar(
name = "tasks_vision",
srcs = glob(["**/*.java"]),
native_library = ":libmediapipe_tasks_vision_jni_lib",
)

View File

@ -19,12 +19,11 @@ import com.google.mediapipe.formats.proto.RectProto.NormalizedRect;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Optional;
/** The base class of MediaPipe vision tasks. */ /** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable { public class BaseVisionTaskApi implements AutoCloseable {
@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
private final TaskRunner runner; private final TaskRunner runner;
private final RunningMode runningMode; private final RunningMode runningMode;
private final String imageStreamName; private final String imageStreamName;
private final Optional<String> normRectStreamName; private final String normRectStreamName;
static { static {
System.loadLibrary("mediapipe_tasks_vision_jni"); System.loadLibrary("mediapipe_tasks_vision_jni");
@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
} }
/** /**
* Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input. * Constructor to initialize a {@link BaseVisionTaskApi}.
* *
* @param runner a {@link TaskRunner}. * @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}. * @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream. * @param imageStreamName the name of the input image stream.
*/ * @param normRectStreamName the name of the input normalized rect image stream used to provide
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) { * (mandatory) rotation and (optional) region-of-interest.
this.runner = runner;
this.runningMode = runningMode;
this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.empty();
}
/**
* Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
* input.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream.
* @param normRectStreamName the name of the input normalized rect image stream.
*/ */
public BaseVisionTaskApi( public BaseVisionTaskApi(
TaskRunner runner, TaskRunner runner,
@ -70,61 +55,31 @@ public class BaseVisionTaskApi implements AutoCloseable {
this.runner = runner; this.runner = runner;
this.runningMode = runningMode; this.runningMode = runningMode;
this.imageStreamName = imageStreamName; this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.of(normRectStreamName); this.normRectStreamName = normRectStreamName;
} }
/** /**
* A synchronous method to process single image inputs. The call blocks the current thread until a * A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned. * failure status or a successful result is returned.
* *
* @param image a MediaPipe {@link Image} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if the task is not in the image mode or requires a normalized rect * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input. * input image before running inference.
* @throws MediaPipeException if the task is not in the image mode.
*/ */
protected TaskResult processImageData(Image image) { protected TaskResult processImageData(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
if (runningMode != RunningMode.IMAGE) { if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:" "Task is not initialized with the image mode. Current running mode:"
+ runningMode.name()); + runningMode.name());
} }
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets);
}
/**
* A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned.
*
* @param image a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
* rect.
*/
protected TaskResult processImageData(Image image, RectF roi) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put( inputPackets.put(
normRectStreamName.get(), normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
return runner.process(inputPackets); return runner.process(inputPackets);
} }
@ -132,56 +87,25 @@ public class BaseVisionTaskApi implements AutoCloseable {
* A synchronous method to process continuous video frames. The call blocks the current thread * A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned. * until a failure status or a successful result is returned.
* *
* @param image a MediaPipe {@link Image} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @param timestampMs the corresponding timestamp of the input image in milliseconds. * @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect * @throws MediaPipeException if the task is not in the video mode.
* input.
*/ */
protected TaskResult processVideoData(Image image, long timestampMs) { protected TaskResult processVideoData(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
if (runningMode != RunningMode.VIDEO) { if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode:" "Task is not initialized with the video mode. Current running mode:"
+ runningMode.name()); + runningMode.name());
} }
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
* A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned.
*
* @param image a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
* rect.
*/
protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) {
if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put( inputPackets.put(
normRectStreamName.get(), normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
} }
@ -189,56 +113,25 @@ public class BaseVisionTaskApi implements AutoCloseable {
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener. * available in the user-defined result listener.
* *
* @param image a MediaPipe {@link Image} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @param timestampMs the corresponding timestamp of the input image in milliseconds. * @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect * @throws MediaPipeException if the task is not in the stream mode.
* input.
*/ */
protected void sendLiveStreamData(Image image, long timestampMs) { protected void sendLiveStreamData(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) { if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:" "Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name()); + runningMode.name());
} }
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param image a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
* rect.
*/
protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put( inputPackets.put(
normRectStreamName.get(), normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
} }
@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
runner.close(); runner.close();
} }
/** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */ /**
private static NormalizedRect convertToNormalizedRect(RectF rect) { * Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf
* message.
*/
private static NormalizedRect convertToNormalizedRect(
ImageProcessingOptions imageProcessingOptions) {
RectF regionOfInterest =
imageProcessingOptions.regionOfInterest().isPresent()
? imageProcessingOptions.regionOfInterest().get()
: new RectF(0, 0, 1, 1);
return NormalizedRect.newBuilder() return NormalizedRect.newBuilder()
.setXCenter(rect.centerX()) .setXCenter(regionOfInterest.centerX())
.setYCenter(rect.centerY()) .setYCenter(regionOfInterest.centerY())
.setWidth(rect.width()) .setWidth(regionOfInterest.width())
.setHeight(rect.height()) .setHeight(regionOfInterest.height())
// Convert to radians anti-clockwise.
.setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f)
.build(); .build();
} }
} }

View File

@ -0,0 +1,92 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.core;
import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import java.util.Optional;
// TODO: add support for image flipping.
/** Options for image processing. */
@AutoValue
public abstract class ImageProcessingOptions {
/**
* Builder for {@link ImageProcessingOptions}.
*
* <p>If both region-of-interest and rotation are specified, the crop around the
* region-of-interest is extracted first, then the specified rotation is applied to the crop.
*/
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets the optional region-of-interest to crop from the image. If not specified, the full image
* is used.
*
* <p>Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must be
* < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link #build()}
* is called.
*/
public abstract Builder setRegionOfInterest(RectF value);
/**
* Sets the rotation to apply to the image (or cropped region-of-interest), in degrees
* clockwise. Defaults to 0.
*
* <p>The rotation must be a multiple (positive or negative) of 90°, otherwise an
* IllegalArgumentException will be thrown when {@link #build()} is called.
*/
public abstract Builder setRotationDegrees(int value);
abstract ImageProcessingOptions autoBuild();
/**
* Validates and builds the {@link ImageProcessingOptions} instance.
*
* @throws IllegalArgumentException if some of the provided values do not meet their
* requirements.
*/
public final ImageProcessingOptions build() {
ImageProcessingOptions options = autoBuild();
if (options.regionOfInterest().isPresent()) {
RectF roi = options.regionOfInterest().get();
if (roi.left >= roi.right || roi.top >= roi.bottom) {
throw new IllegalArgumentException(
String.format(
"Expected left < right and top < bottom, found: %s.", roi.toShortString()));
}
if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) {
throw new IllegalArgumentException(
String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString()));
}
}
if (options.rotationDegrees() % 90 != 0) {
throw new IllegalArgumentException(
String.format(
"Expected rotation to be a multiple of 90°, found: %d.",
options.rotationDegrees()));
}
return options;
}
}
public abstract Optional<RectF> regionOfInterest();
public abstract int rotationDegrees();
public static Builder builder() {
return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0);
}
}

Some files were not shown because too many files have changed in this diff Show More