Merge branch 'master' into image-segmenter-python-impl
commit 334f641463
@@ -172,6 +172,10 @@ http_archive(
     urls = [
         "https://github.com/google/sentencepiece/archive/1.0.0.zip",
     ],
+    patches = [
+        "//third_party:com_google_sentencepiece_no_gflag_no_gtest.diff",
+    ],
+    patch_args = ["-p1"],
     repo_mapping = {"@com_google_glog" : "@com_github_glog_glog"},
 )
docs/BUILD (new file, 14 lines)
@@ -0,0 +1,14 @@
# Placeholder for internal Python strict binary compatibility macro.

py_binary(
    name = "build_py_api_docs",
    srcs = ["build_py_api_docs.py"],
    deps = [
        "//mediapipe",
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/tensorflow_docs",
        "//third_party/py/tensorflow_docs/api_generator:generate_lib",
        "//third_party/py/tensorflow_docs/api_generator:public_api",
    ],
)
docs/build_py_api_docs.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""MediaPipe reference docs generation script.

This script generates API reference docs for the `mediapipe` PIP package.

$> pip install -U git+https://github.com/tensorflow/docs mediapipe
$> python build_py_api_docs.py
"""

import os

from absl import app
from absl import flags

from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api

try:
  # mediapipe has not been set up to work with bazel yet, so catch & report.
  import mediapipe  # pytype: disable=import-error
except ImportError as e:
  raise ImportError('Please `pip install mediapipe`.') from e


PROJECT_SHORT_NAME = 'mp'
PROJECT_FULL_NAME = 'MediaPipe'

_OUTPUT_DIR = flags.DEFINE_string(
    'output_dir',
    default='/tmp/generated_docs',
    help='Where to write the resulting docs.')

_URL_PREFIX = flags.DEFINE_string(
    'code_url_prefix',
    'https://github.com/google/mediapipe/tree/master/mediapipe',
    'The url prefix for links to code.')

_SEARCH_HINTS = flags.DEFINE_bool(
    'search_hints', True,
    'Include metadata search hints in the generated files')

_SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',
                                 'Path prefix in the _toc.yaml')


def gen_api_docs():
  """Generates API docs for the mediapipe package."""

  doc_generator = generate_lib.DocGenerator(
      root_title=PROJECT_FULL_NAME,
      py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
      base_dir=os.path.dirname(mediapipe.__file__),
      code_url_prefix=_URL_PREFIX.value,
      search_hints=_SEARCH_HINTS.value,
      site_path=_SITE_PATH.value,
      # This callback ensures that docs are only generated for objects that
      # are explicitly imported in your __init__.py files. There are other
      # options but this is a good starting point.
      callbacks=[public_api.explicit_package_contents_filter],
  )

  doc_generator.build(_OUTPUT_DIR.value)

  print('Docs output to:', _OUTPUT_DIR.value)


def main(_):
  gen_api_docs()


if __name__ == '__main__':
  app.run(main)
@@ -253,6 +253,26 @@ cc_library(
     alwayslink = 1,
 )

+cc_test(
+    name = "regex_preprocessor_calculator_test",
+    srcs = ["regex_preprocessor_calculator_test.cc"],
+    data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
+    linkopts = ["-ldl"],
+    deps = [
+        ":regex_preprocessor_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/tool:sink",
+        "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "text_to_tensor_calculator",
     srcs = ["text_to_tensor_calculator.cc"],
@@ -307,6 +327,27 @@ cc_library(
     alwayslink = 1,
 )

+cc_test(
+    name = "universal_sentence_encoder_preprocessor_calculator_test",
+    srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
+    data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
+    deps = [
+        ":universal_sentence_encoder_preprocessor_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/framework/tool:options_map",
+        "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 mediapipe_proto_library(
     name = "inference_calculator_proto",
     srcs = ["inference_calculator.proto"],
@@ -26,6 +26,8 @@
 #include "mediapipe/gpu/gl_calculator_helper.h"
 #include "tensorflow/lite/delegates/gpu/gl_delegate.h"

+#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
+
 namespace mediapipe {
 namespace api2 {

@@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
     CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
     std::vector<Tensor>& output_tensors) {
   return gpu_helper_.RunInGlContext(
-      [this, &input_tensors, &output_tensors]() -> absl::Status {
+      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
         // Explicitly copy input.
         for (int i = 0; i < input_tensors.size(); ++i) {
           glBindBuffer(GL_COPY_READ_BUFFER,

@@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
         }

         // Run inference.
+        {
+          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
           RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
+        }

         output_tensors.reserve(output_size_);
         for (int i = 0; i < output_size_; ++i) {
@@ -32,6 +32,8 @@
 #include "mediapipe/util/android/file/base/helpers.h"
 #endif  // MEDIAPIPE_ANDROID

+#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
+
 namespace mediapipe {
 namespace api2 {

@@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
       const mediapipe::InferenceCalculatorOptions::Delegate& delegate);

   absl::StatusOr<std::vector<Tensor>> Process(
-      const std::vector<Tensor>& input_tensors);
+      CalculatorContext* cc, const std::vector<Tensor>& input_tensors);

   absl::Status Close();

@@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(

 absl::StatusOr<std::vector<Tensor>>
 InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
-    const std::vector<Tensor>& input_tensors) {
+    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
   std::vector<Tensor> output_tensors;

   MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
-      [this, &input_tensors, &output_tensors]() -> absl::Status {
+      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
        for (int i = 0; i < input_tensors.size(); ++i) {
          MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
              input_tensors[i].GetOpenGlBufferReadView().name(), i));

@@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
               output_tensors.back().GetOpenGlBufferWriteView().name(), i));
         }
         // Run inference.
+        {
+          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
           return tflite_gpu_runner_->Invoke();
+        }
       }));

   return output_tensors;

@@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
   auto output_tensors = absl::make_unique<std::vector<Tensor>>();

   ASSIGN_OR_RETURN(*output_tensors,
-                   gpu_inference_runner_->Process(input_tensors));
+                   gpu_inference_runner_->Process(cc, input_tensors));

   kOutTensors(cc).Send(std::move(output_tensors));
   return absl::OkStatus();
@@ -289,8 +289,15 @@ class NodeBase {

   template <typename T>
   T& GetOptions() {
+    return GetOptions(T::ext);
+  }
+
+  // Use this API when the proto extension does not follow the "ext" naming
+  // convention.
+  template <typename E>
+  auto& GetOptions(const E& extension) {
     options_used_ = true;
-    return *options_.MutableExtension(T::ext);
+    return *options_.MutableExtension(extension);
   }

  protected:

@@ -386,8 +393,15 @@ class PacketGenerator {

   template <typename T>
   T& GetOptions() {
+    return GetOptions(T::ext);
+  }
+
+  // Use this API when the proto extension does not follow the "ext" naming
+  // convention.
+  template <typename E>
+  auto& GetOptions(const E& extension) {
     options_used_ = true;
-    return *options_.MutableExtension(T::ext);
+    return *options_.MutableExtension(extension);
   }

   template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
@@ -185,7 +185,7 @@ class CalculatorBaseFactory {
 // Functions for checking that the calculator has the required GetContract.
 template <class T>
 constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
-  typedef absl::Status (*GetContractType)(CalculatorContract * cc);
+  typedef absl::Status (*GetContractType)(CalculatorContract* cc);
   return std::is_same<decltype(&T::GetContract), GetContractType>::value;
 }
 template <class T>
@@ -133,7 +133,12 @@ message GraphTrace {
     TPU_TASK = 13;
     GPU_CALIBRATION = 14;
     PACKET_QUEUED = 15;
+    GPU_TASK_INVOKE = 16;
+    TPU_TASK_INVOKE = 17;
   }
   // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
   // //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
   // )

   // The timing for one packet set being processed at one caclulator node.
   message CalculatorTrace {
@@ -293,7 +293,6 @@ mediapipe_proto_library(
     name = "rect_proto",
     srcs = ["rect.proto"],
     visibility = ["//visibility:public"],
-    deps = ["//mediapipe/framework/formats:location_data_proto"],
 )

 mediapipe_register_type(
@@ -109,6 +109,11 @@ struct TraceEvent {
   static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
   static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
   static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
+  static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
+  static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
   // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
   // //depot/mediapipe/framework/calculator_profile.proto:event_type,
   // )
 };

 // Packet trace log buffer.
@@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
 static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;

 static void EglThreadExitCallback(void* key_value) {
+#if defined(__ANDROID__)
+  eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
+                 EGL_NO_CONTEXT);
+#else
   // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
   // parameter for eglMakeCurrent. This behavior is not portable to all EGL
   // implementations, and should be considered as an undocumented vendor
   // extension.
   // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
+  //
+  // NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
   eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
                  EGL_NO_SURFACE, EGL_NO_CONTEXT);
+#endif
   eglReleaseThread();
 }
@@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
 import android.graphics.Bitmap;
 import com.google.mediapipe.framework.image.BitmapExtractor;
 import com.google.mediapipe.framework.image.ByteBufferExtractor;
-import com.google.mediapipe.framework.image.Image;
-import com.google.mediapipe.framework.image.ImageProperties;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.framework.image.MPImageProperties;
 import java.nio.ByteBuffer;

 // TODO: use Preconditions in this file.

@@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
   }

   /**
-   * Creates an Image packet from an {@link Image}.
+   * Creates a MediaPipe Image packet from a {@link MPImage}.
    *
    * <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
    */
-  public Packet createImage(Image image) {
+  public Packet createImage(MPImage image) {
     // TODO: Choose the best storage from multiple containers.
-    ImageProperties properties = image.getContainedImageProperties().get(0);
-    if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
+    MPImageProperties properties = image.getContainedImageProperties().get(0);
+    if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
       ByteBuffer buffer = ByteBufferExtractor.extract(image);
       int numChannels = 0;
       switch (properties.getImageFormat()) {
-        case Image.IMAGE_FORMAT_RGBA:
+        case MPImage.IMAGE_FORMAT_RGBA:
           numChannels = 4;
           break;
-        case Image.IMAGE_FORMAT_RGB:
+        case MPImage.IMAGE_FORMAT_RGB:
           numChannels = 3;
           break;
-        case Image.IMAGE_FORMAT_ALPHA:
+        case MPImage.IMAGE_FORMAT_ALPHA:
           numChannels = 1;
           break;
         default: // fall out

@@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
       int height = image.getHeight();
       return createImage(buffer, width, height, numChannels);
     }
-    if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
+    if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) {
       Bitmap bitmap = BitmapExtractor.extract(image);
       if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
         throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
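
The two hunks above retarget AndroidPacketCreator.createImage from Image to MPImage. A minimal usage sketch, not part of this diff: the creator instance and the buffer contents are assumed for illustration; only the MPImage and builder names come from the hunks.

// Sketch (not from the commit): how client code adapts to the Image -> MPImage rename.
// Assumes an AndroidPacketCreator from an initialized graph; buffer contents are dummy data.
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import java.nio.ByteBuffer;

final class MPImagePacketExample {
  static Packet makeRgbPacket(AndroidPacketCreator creator, int width, int height) {
    ByteBuffer pixels = ByteBuffer.allocateDirect(width * height * 3);  // packed RGB bytes
    MPImage image =
        new ByteBufferImageBuilder(pixels, width, height, MPImage.IMAGE_FORMAT_RGB).build();
    return creator.createImage(image);  // takes MPImage after this rename
  }
}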
@@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
 import android.graphics.Bitmap;

 /**
- * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
+ * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
  *
- * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
+ * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
  * {@link IllegalArgumentException} will be thrown.
  */
 public final class BitmapExtractor {

   /**
-   * Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
+   * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
    *
    * @param image the image to extract {@link android.graphics.Bitmap} from.
-   * @return the {@link android.graphics.Bitmap} stored in {@link Image}
+   * @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
    * @throws IllegalArgumentException when the extraction requires unsupported format or data type
    *     conversions.
    */
-  public static Bitmap extract(Image image) {
-    ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
+  public static Bitmap extract(MPImage image) {
+    MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
     if (imageContainer != null) {
       return ((BitmapImageContainer) imageContainer).getBitmap();
     } else {
       // TODO: Support ByteBuffer -> Bitmap conversion.
       throw new IllegalArgumentException(
-          "Extracting Bitmap from an Image created by objects other than Bitmap is not"
+          "Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
               + " supported");
     }
   }
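
A usage sketch for the renamed extractor, not part of this diff: per the javadoc above, extraction only succeeds for a MPImage backed by a Bitmap container. The single-Bitmap constructor of BitmapImageBuilder is an assumption here, implied by the delegating URI constructor in the next file.

// Sketch (not from the commit): round-tripping a Bitmap through MPImage.
// BitmapExtractor.extract throws IllegalArgumentException for non-Bitmap-backed images.
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;

final class BitmapRoundTripExample {
  static Bitmap roundTrip(Bitmap source) {
    MPImage image = new BitmapImageBuilder(source).build();  // STORAGE_TYPE_BITMAP container
    return BitmapExtractor.extract(image);                   // returns the stored Bitmap
  }
}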
@@ -22,7 +22,7 @@ import android.provider.MediaStore;
 import java.io.IOException;

 /**
- * Builds {@link Image} from {@link android.graphics.Bitmap}.
+ * Builds {@link MPImage} from {@link android.graphics.Bitmap}.
  *
  * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
  * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content

@@ -49,7 +49,7 @@ public class BitmapImageBuilder {
   }

   /**
-   * Creates the builder to build {@link Image} from a file.
+   * Creates the builder to build {@link MPImage} from a file.
    *
    * @param context the application context.
    * @param uri the path to the resource file.

@@ -58,15 +58,15 @@ public class BitmapImageBuilder {
     this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
   }

-  /** Sets value for {@link Image#getTimestamp()}. */
+  /** Sets value for {@link MPImage#getTimestamp()}. */
   BitmapImageBuilder setTimestamp(long timestamp) {
     this.timestamp = timestamp;
     return this;
   }

-  /** Builds an {@link Image} instance. */
-  public Image build() {
-    return new Image(
+  /** Builds a {@link MPImage} instance. */
+  public MPImage build() {
+    return new MPImage(
         new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
   }
 }
@@ -16,19 +16,19 @@ limitations under the License.
 package com.google.mediapipe.framework.image;

 import android.graphics.Bitmap;
-import com.google.mediapipe.framework.image.Image.ImageFormat;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;

-class BitmapImageContainer implements ImageContainer {
+class BitmapImageContainer implements MPImageContainer {

   private final Bitmap bitmap;
-  private final ImageProperties properties;
+  private final MPImageProperties properties;

   public BitmapImageContainer(Bitmap bitmap) {
     this.bitmap = bitmap;
     this.properties =
-        ImageProperties.builder()
+        MPImageProperties.builder()
             .setImageFormat(convertFormatCode(bitmap.getConfig()))
-            .setStorageType(Image.STORAGE_TYPE_BITMAP)
+            .setStorageType(MPImage.STORAGE_TYPE_BITMAP)
             .build();
   }

@@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
   }

   @Override
-  public ImageProperties getImageProperties() {
+  public MPImageProperties getImageProperties() {
     return properties;
   }

@@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
     bitmap.recycle();
   }

-  @ImageFormat
+  @MPImageFormat
   static int convertFormatCode(Bitmap.Config config) {
     switch (config) {
       case ALPHA_8:
-        return Image.IMAGE_FORMAT_ALPHA;
+        return MPImage.IMAGE_FORMAT_ALPHA;
       case ARGB_8888:
-        return Image.IMAGE_FORMAT_RGBA;
+        return MPImage.IMAGE_FORMAT_RGBA;
       default:
-        return Image.IMAGE_FORMAT_UNKNOWN;
+        return MPImage.IMAGE_FORMAT_UNKNOWN;
     }
   }
 }
@@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
 import android.os.Build.VERSION;
 import android.os.Build.VERSION_CODES;
 import com.google.auto.value.AutoValue;
-import com.google.mediapipe.framework.image.Image.ImageFormat;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.Locale;

 /**
- * Utility for extracting {@link ByteBuffer} from {@link Image}.
+ * Utility for extracting {@link ByteBuffer} from {@link MPImage}.
  *
- * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
- * {@link IllegalArgumentException} will be thrown.
+ * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
+ * otherwise {@link IllegalArgumentException} will be thrown.
  */
 public class ByteBufferExtractor {

   /**
-   * Extracts a {@link ByteBuffer} from an {@link Image}.
+   * Extracts a {@link ByteBuffer} from a {@link MPImage}.
    *
    * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
-   * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
+   * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
    *
-   * @see Image#getContainedImageProperties()
+   * @see MPImage#getContainedImageProperties()
    * @return A read-only {@link ByteBuffer}.
    * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
    */
   @SuppressLint("SwitchIntDef")
-  public static ByteBuffer extract(Image image) {
-    ImageContainer container = image.getContainer();
+  public static ByteBuffer extract(MPImage image) {
+    MPImageContainer container = image.getContainer();
     switch (container.getImageProperties().getStorageType()) {
-      case Image.STORAGE_TYPE_BYTEBUFFER:
+      case MPImage.STORAGE_TYPE_BYTEBUFFER:
         ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
         return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
       default:
         throw new IllegalArgumentException(
-            "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
+            "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
                 + " supported");
     }
   }

   /**
-   * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
+   * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
    *
    * <p>Format conversion spec:
    *

@@ -70,26 +70,26 @@ public class ByteBufferExtractor {
    *
    * @param image the image to extract buffer from.
    * @param targetFormat the image format of the result bytebuffer.
-   * @return the readonly {@link ByteBuffer} stored in {@link Image}
+   * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
    * @throws IllegalArgumentException when the extraction requires unsupported format or data type
    *     conversions.
    */
-  static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
-    ImageContainer container;
-    ImageProperties byteBufferProperties =
-        ImageProperties.builder()
-            .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
+  static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
+    MPImageContainer container;
+    MPImageProperties byteBufferProperties =
+        MPImageProperties.builder()
+            .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
             .setImageFormat(targetFormat)
             .build();
     if ((container = image.getContainer(byteBufferProperties)) != null) {
       ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
       return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
-    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
+    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
       ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
-      @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
+      @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
       return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
           .asReadOnlyBuffer();
-    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
+    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
       BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
       ByteBuffer byteBuffer =
           extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)

@@ -98,85 +98,89 @@ public class ByteBufferExtractor {
       return byteBuffer;
     } else {
       throw new IllegalArgumentException(
-          "Extracting ByteBuffer from an Image created by objects other than Bitmap or"
+          "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
              + " Bytebuffer is not supported");
     }
   }

-  /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
+  /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
   @AutoValue
   abstract static class Result {
-    /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
+    /**
+     * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
+     */
     public abstract ByteBuffer buffer();

-    /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
-    @ImageFormat
+    /**
+     * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
+     */
+    @MPImageFormat
    public abstract int format();

-    static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
+    static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
       return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
     }
   }

   /**
-   * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
+   * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
    *
    * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
    *
-   * @return the readonly {@link ByteBuffer} stored in {@link Image}
+   * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
    * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
    *     given {@code imageFormat}
    */
-  static Result extractInRecommendedFormat(Image image) {
-    ImageContainer container;
-    if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
+  static Result extractInRecommendedFormat(MPImage image) {
+    MPImageContainer container;
+    if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
       Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
-      @ImageFormat int format = adviseImageFormat(bitmap);
+      @MPImageFormat int format = adviseImageFormat(bitmap);
       Result result =
           Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);

       boolean unused =
           image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
       return result;
-    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
+    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
       ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
       return Result.create(
           byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
           byteBufferImageContainer.getImageFormat());
     } else {
       throw new IllegalArgumentException(
-          "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
+          "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
              + " is not supported");
     }
   }

-  @ImageFormat
+  @MPImageFormat
   private static int adviseImageFormat(Bitmap bitmap) {
     if (bitmap.getConfig() == Config.ARGB_8888) {
-      return Image.IMAGE_FORMAT_RGBA;
+      return MPImage.IMAGE_FORMAT_RGBA;
     } else {
       throw new IllegalArgumentException(
           String.format(
-              "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
+              "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
                  + " supported",
              bitmap.getConfig()));
     }
   }

   private static ByteBuffer extractByteBufferFromBitmap(
-      Bitmap bitmap, @ImageFormat int imageFormat) {
+      Bitmap bitmap, @MPImageFormat int imageFormat) {
     if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
       throw new IllegalArgumentException(
-          "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
+          "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
              + " supported");
     }
     if (bitmap.getConfig() == Config.ARGB_8888) {
-      if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
+      if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
         ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
         bitmap.copyPixelsToBuffer(buffer);
         buffer.rewind();
         return buffer;
-      } else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
+      } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
         // TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
         int w = bitmap.getWidth();
         int h = bitmap.getHeight();

@@ -196,14 +200,14 @@ public class ByteBufferExtractor {
     }
     throw new IllegalArgumentException(
         String.format(
-            "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
+            "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
                + " %d is not supported",
            bitmap.getConfig(), imageFormat));
   }

   private static ByteBuffer convertByteBuffer(
-      ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
-    if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
+      ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
+    if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
       ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
       // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
       // array reversely to convert in-place.

@@ -221,7 +225,8 @@ public class ByteBufferExtractor {
       target.put(array, 0, target.capacity());
       target.rewind();
       return target;
-    } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
+    } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
+        && targetFormat == MPImage.IMAGE_FORMAT_RGB) {
       ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
       // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
       // array to convert in-place.
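
A sketch of the public extraction path, not part of this diff: per the javadoc above, the returned buffer is a read-only view of the ByteBuffer storage; the RGBA input buffer here is assumed for illustration.

// Sketch (not from the commit): reading pixel bytes back from a ByteBuffer-backed MPImage.
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import java.nio.ByteBuffer;

final class ByteBufferViewExample {
  static ByteBuffer readBack(ByteBuffer rgba, int width, int height) {
    MPImage image =
        new ByteBufferImageBuilder(rgba, width, height, MPImage.IMAGE_FORMAT_RGBA).build();
    ByteBuffer view = ByteBufferExtractor.extract(image);  // read-only view of 'rgba'
    assert view.isReadOnly();
    return view;
  }
}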
@@ -15,11 +15,11 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

-import com.google.mediapipe.framework.image.Image.ImageFormat;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
 import java.nio.ByteBuffer;

 /**
- * Builds a {@link Image} from a {@link ByteBuffer}.
+ * Builds a {@link MPImage} from a {@link ByteBuffer}.
  *
  * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
  * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.

@@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
   private final ByteBuffer buffer;
   private final int width;
   private final int height;
-  @ImageFormat private final int imageFormat;
+  @MPImageFormat private final int imageFormat;

   // Optional fields.
   private long timestamp;

@@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
    * @param imageFormat how the data encode the image.
    */
   public ByteBufferImageBuilder(
-      ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
+      ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
     this.buffer = byteBuffer;
     this.width = width;
     this.height = height;

@@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
     this.timestamp = 0;
   }

-  /** Sets value for {@link Image#getTimestamp()}. */
+  /** Sets value for {@link MPImage#getTimestamp()}. */
   ByteBufferImageBuilder setTimestamp(long timestamp) {
     this.timestamp = timestamp;
     return this;
   }

-  /** Builds an {@link Image} instance. */
-  public Image build() {
-    return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
+  /** Builds a {@link MPImage} instance. */
+  public MPImage build() {
+    return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
   }
 }
@@ -15,21 +15,19 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

-import com.google.mediapipe.framework.image.Image.ImageFormat;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
 import java.nio.ByteBuffer;

-class ByteBufferImageContainer implements ImageContainer {
+class ByteBufferImageContainer implements MPImageContainer {

   private final ByteBuffer buffer;
-  private final ImageProperties properties;
+  private final MPImageProperties properties;

-  public ByteBufferImageContainer(
-      ByteBuffer buffer,
-      @ImageFormat int imageFormat) {
+  public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
     this.buffer = buffer;
     this.properties =
-        ImageProperties.builder()
-            .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
+        MPImageProperties.builder()
+            .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
             .setImageFormat(imageFormat)
             .build();
   }

@@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
   }

   @Override
-  public ImageProperties getImageProperties() {
+  public MPImageProperties getImageProperties() {
     return properties;
   }

-  /**
-   * Returns the image format.
-   */
-  @ImageFormat
+  /** Returns the image format. */
+  @MPImageFormat
   public int getImageFormat() {
     return properties.getImageFormat();
   }
@@ -29,10 +29,10 @@ import java.util.Map.Entry;
 /**
  * The wrapper class for image objects.
  *
- * <p>{@link Image} is designed to be an immutable image container, which could be shared
+ * <p>{@link MPImage} is designed to be an immutable image container, which could be shared
  * cross-platforms.
  *
- * <p>To construct an {@link Image}, use the provided builders:
+ * <p>To construct a {@link MPImage}, use the provided builders:
  *
  * <ul>
  *   <li>{@link ByteBufferImageBuilder}

@@ -40,7 +40,7 @@ import java.util.Map.Entry;
  *   <li>{@link MediaImageBuilder}
  * </ul>
  *
- * <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
+ * <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
  * reference count is 1. Developer can call {@link #close()} to reduce reference count to release
  * internal storage earlier, otherwise Java garbage collection will release the storage eventually.
 *

@@ -53,7 +53,7 @@ import java.util.Map.Entry;
  *   <li>{@link MediaImageExtractor}
  * </ul>
  */
-public class Image implements Closeable {
+public class MPImage implements Closeable {

   /** Specifies the image format of an image. */
   @IntDef({

@@ -69,7 +69,7 @@ public class Image implements Closeable {
     IMAGE_FORMAT_JPEG,
   })
   @Retention(RetentionPolicy.SOURCE)
-  public @interface ImageFormat {}
+  public @interface MPImageFormat {}

   public static final int IMAGE_FORMAT_UNKNOWN = 0;
   public static final int IMAGE_FORMAT_RGBA = 1;

@@ -98,14 +98,14 @@ public class Image implements Closeable {
   public static final int STORAGE_TYPE_IMAGE_PROXY = 4;

   /**
-   * Returns a list of supported image properties for this {@link Image}.
+   * Returns a list of supported image properties for this {@link MPImage}.
    *
-   * <p>Currently {@link Image} only support single storage type so the size of return list will
+   * <p>Currently {@link MPImage} only support single storage type so the size of return list will
    * always be 1.
    *
-   * @see ImageProperties
+   * @see MPImageProperties
    */
-  public List<ImageProperties> getContainedImageProperties() {
+  public List<MPImageProperties> getContainedImageProperties() {
     return Collections.singletonList(getContainer().getImageProperties());
   }

@@ -124,7 +124,7 @@ public class Image implements Closeable {
     return height;
   }

-  /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
+  /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
   private synchronized void acquire() {
     referenceCount += 1;
   }

@@ -132,7 +132,7 @@ public class Image implements Closeable {
   /**
    * Removes a reference that was previously acquired or init.
    *
-   * <p>When {@link Image} is created, it has 1 reference count.
+   * <p>When {@link MPImage} is created, it has 1 reference count.
    *
    * <p>When the reference count becomes 0, it will release the resource under the hood.
    */

@@ -141,24 +141,24 @@ public class Image implements Closeable {
   public synchronized void close() {
     referenceCount -= 1;
     if (referenceCount == 0) {
-      for (ImageContainer imageContainer : containerMap.values()) {
+      for (MPImageContainer imageContainer : containerMap.values()) {
         imageContainer.close();
       }
     }
   }

-  /** Advanced API access for {@link Image}. */
+  /** Advanced API access for {@link MPImage}. */
   static final class Internal {

     /**
-     * Acquires a reference on this {@link Image}. This will increase the reference count by 1.
+     * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
      *
      * <p>This method is more useful for image consumer to acquire a reference so image resource
      * will not be closed accidentally. As image creator, normal developer doesn't need to call this
      * method.
      *
-     * <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
-     * #close()} to indicate it doesn't need this {@link Image} anymore.
+     * <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
+     * #close()} to indicate it doesn't need this {@link MPImage} anymore.
      *
      * @see #close()
      */

@@ -166,10 +166,10 @@ public class Image implements Closeable {
       image.acquire();
     }

-    private final Image image;
+    private final MPImage image;

-    // Only Image creates the internal helper.
-    private Internal(Image image) {
+    // Only MPImage creates the internal helper.
+    private Internal(MPImage image) {
       this.image = image;
     }
   }

@@ -179,15 +179,15 @@ public class Image implements Closeable {
     return new Internal(this);
   }

-  private final Map<ImageProperties, ImageContainer> containerMap;
+  private final Map<MPImageProperties, MPImageContainer> containerMap;
   private final long timestamp;
   private final int width;
   private final int height;

   private int referenceCount;

-  /** Constructs an {@link Image} with a built container. */
-  Image(ImageContainer container, long timestamp, int width, int height) {
+  /** Constructs a {@link MPImage} with a built container. */
+  MPImage(MPImageContainer container, long timestamp, int width, int height) {
     this.containerMap = new HashMap<>();
     containerMap.put(container.getImageProperties(), container);
     this.timestamp = timestamp;

@@ -201,10 +201,10 @@ public class Image implements Closeable {
    *
    * @return the current container.
    */
-  ImageContainer getContainer() {
+  MPImageContainer getContainer() {
     // According to the design, in the future we will support multiple containers in one image.
     // Currently just return the original container.
-    // TODO: Cache multiple containers in Image.
+    // TODO: Cache multiple containers in MPImage.
     return containerMap.values().iterator().next();
   }

@@ -214,8 +214,8 @@ public class Image implements Closeable {
    * <p>If there are multiple containers with required {@code storageType}, returns the first one.
    */
   @Nullable
-  ImageContainer getContainer(@StorageType int storageType) {
-    for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
+  MPImageContainer getContainer(@StorageType int storageType) {
+    for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
       if (entry.getKey().getStorageType() == storageType) {
         return entry.getValue();
       }

@@ -225,13 +225,13 @@ public class Image implements Closeable {

   /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
   @Nullable
-  ImageContainer getContainer(ImageProperties imageProperties) {
+  MPImageContainer getContainer(MPImageProperties imageProperties) {
     return containerMap.get(imageProperties);
   }

   /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
-  boolean addContainer(ImageContainer container) {
-    ImageProperties imageProperties = container.getImageProperties();
+  boolean addContainer(MPImageContainer container) {
+    MPImageProperties imageProperties = container.getImageProperties();
     if (containerMap.containsKey(imageProperties)) {
       return false;
     }
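
Since MPImage implements Closeable and starts with a reference count of 1 (per the javadoc above), deterministic release fits try-with-resources. A hedged sketch, not part of this diff; the RGBA buffer is assumed for illustration:

// Sketch (not from the commit): close() decrements the reference count; at zero the
// underlying container is released, so try-with-resources releases storage eagerly.
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer;

final class MPImageLifecycleExample {
  static void describe(ByteBuffer rgba, int width, int height) {
    try (MPImage image =
        new ByteBufferImageBuilder(rgba, width, height, MPImage.IMAGE_FORMAT_RGBA).build()) {
      // Exactly one properties entry today, per the getContainedImageProperties() javadoc.
      MPImageProperties props = image.getContainedImageProperties().get(0);
      System.out.println("storage=" + props.getStorageType() + " format=" + props.getImageFormat());
    }  // reference count drops from 1 to 0 here
  }
}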
@@ -14,14 +14,14 @@ limitations under the License.
 ==============================================================================*/
 package com.google.mediapipe.framework.image;

-/** Lightweight abstraction for an object that can receive {@link Image} */
-public interface ImageConsumer {
+/** Lightweight abstraction for an object that can receive {@link MPImage} */
+public interface MPImageConsumer {

   /**
-   * Called when an {@link Image} is available.
+   * Called when a {@link MPImage} is available.
    *
    * <p>The argument is only guaranteed to be available until this method returns. if you need to
    * extend its life time, acquire it, then release it when done.
    */
-  void onNewImage(Image image);
+  void onNewMPImage(MPImage image);
 }
@@ -16,9 +16,9 @@ limitations under the License.
 package com.google.mediapipe.framework.image;

 /** Manages internal image data storage. The interface is package-private. */
-interface ImageContainer {
+interface MPImageContainer {
   /** Returns the properties of the contained image. */
-  ImageProperties getImageProperties();
+  MPImageProperties getImageProperties();

   /** Close the image container and releases the image resource inside. */
   void close();
@@ -14,9 +14,9 @@ limitations under the License.
 ==============================================================================*/
 package com.google.mediapipe.framework.image;

-/** Lightweight abstraction for an object that produce {@link Image} */
-public interface ImageProducer {
+/** Lightweight abstraction for an object that produce {@link MPImage} */
+public interface MPImageProducer {

-  /** Sets the consumer that receives the {@link Image}. */
-  void setImageConsumer(ImageConsumer imageConsumer);
+  /** Sets the consumer that receives the {@link MPImage}. */
+  void setMPImageConsumer(MPImageConsumer imageConsumer);
 }
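
A sketch implementing the renamed consumer interface, not part of this diff. The javadoc above only guarantees the argument until onNewMPImage returns, so this example reads what it needs synchronously:

// Sketch (not from the commit): keeping 'image' past this call would require acquiring
// a reference, which is an internal API in this commit, so we copy out what we need.
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageConsumer;

final class SizeLoggingConsumer implements MPImageConsumer {
  @Override
  public void onNewMPImage(MPImage image) {
    System.out.println("got frame " + image.getWidth() + "x" + image.getHeight());
  }
}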
@@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;

 import com.google.auto.value.AutoValue;
 import com.google.auto.value.extension.memoized.Memoized;
-import com.google.mediapipe.framework.image.Image.ImageFormat;
-import com.google.mediapipe.framework.image.Image.StorageType;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
+import com.google.mediapipe.framework.image.MPImage.StorageType;

 /** Groups a set of properties to describe how an image is stored. */
 @AutoValue
-public abstract class ImageProperties {
+public abstract class MPImageProperties {

   /**
    * Gets the pixel format of the image.
    *
-   * @see Image.ImageFormat
+   * @see MPImage.MPImageFormat
    */
-  @ImageFormat
+  @MPImageFormat
   public abstract int getImageFormat();

   /**
    * Gets the storage type of the image.
    *
-   * @see Image.StorageType
+   * @see MPImage.StorageType
    */
   @StorageType
   public abstract int getStorageType();

@@ -45,36 +45,36 @@ public abstract class ImageProperties {
   public abstract int hashCode();

   /**
-   * Creates a builder of {@link ImageProperties}.
+   * Creates a builder of {@link MPImageProperties}.
    *
-   * @see ImageProperties.Builder
+   * @see MPImageProperties.Builder
    */
   static Builder builder() {
-    return new AutoValue_ImageProperties.Builder();
+    return new AutoValue_MPImageProperties.Builder();
   }

-  /** Builds a {@link ImageProperties}. */
+  /** Builds a {@link MPImageProperties}. */
   @AutoValue.Builder
   abstract static class Builder {

     /**
-     * Sets the {@link Image.ImageFormat}.
+     * Sets the {@link MPImage.MPImageFormat}.
      *
-     * @see ImageProperties#getImageFormat
+     * @see MPImageProperties#getImageFormat
      */
-    abstract Builder setImageFormat(@ImageFormat int value);
+    abstract Builder setImageFormat(@MPImageFormat int value);

     /**
-     * Sets the {@link Image.StorageType}.
+     * Sets the {@link MPImage.StorageType}.
      *
-     * @see ImageProperties#getStorageType
+     * @see MPImageProperties#getStorageType
      */
     abstract Builder setStorageType(@StorageType int value);

-    /** Builds the {@link ImageProperties}. */
-    abstract ImageProperties build();
+    /** Builds the {@link MPImageProperties}. */
+    abstract MPImageProperties build();
   }

   // Hide the constructor.
-  ImageProperties() {}
+  MPImageProperties() {}
 }
@@ -15,11 +15,12 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

+import android.media.Image;
 import android.os.Build.VERSION_CODES;
 import androidx.annotation.RequiresApi;

 /**
- * Builds {@link Image} from {@link android.media.Image}.
+ * Builds {@link MPImage} from {@link android.media.Image}.
  *
  * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
  * content in it.

@@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
 public class MediaImageBuilder {

   // Mandatory fields.
-  private final android.media.Image mediaImage;
+  private final Image mediaImage;

   // Optional fields.
   private long timestamp;

@@ -40,20 +41,20 @@ public class MediaImageBuilder {
    *
    * @param mediaImage image data object.
    */
-  public MediaImageBuilder(android.media.Image mediaImage) {
+  public MediaImageBuilder(Image mediaImage) {
     this.mediaImage = mediaImage;
     this.timestamp = 0;
   }

-  /** Sets value for {@link Image#getTimestamp()}. */
+  /** Sets value for {@link MPImage#getTimestamp()}. */
   MediaImageBuilder setTimestamp(long timestamp) {
     this.timestamp = timestamp;
     return this;
   }

-  /** Builds an {@link Image} instance. */
-  public Image build() {
-    return new Image(
+  /** Builds a {@link MPImage} instance. */
+  public MPImage build() {
+    return new MPImage(
         new MediaImageContainer(mediaImage),
         timestamp,
         mediaImage.getWidth(),
@@ -15,33 +15,34 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

+import android.media.Image;
 import android.os.Build;
 import android.os.Build.VERSION;
 import android.os.Build.VERSION_CODES;
 import androidx.annotation.RequiresApi;
-import com.google.mediapipe.framework.image.Image.ImageFormat;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;

 @RequiresApi(VERSION_CODES.KITKAT)
-class MediaImageContainer implements ImageContainer {
+class MediaImageContainer implements MPImageContainer {

-  private final android.media.Image mediaImage;
-  private final ImageProperties properties;
+  private final Image mediaImage;
+  private final MPImageProperties properties;

-  public MediaImageContainer(android.media.Image mediaImage) {
+  public MediaImageContainer(Image mediaImage) {
     this.mediaImage = mediaImage;
     this.properties =
-        ImageProperties.builder()
-            .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
+        MPImageProperties.builder()
+            .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
             .setImageFormat(convertFormatCode(mediaImage.getFormat()))
             .build();
   }

-  public android.media.Image getImage() {
+  public Image getImage() {
     return mediaImage;
   }

   @Override
-  public ImageProperties getImageProperties() {
+  public MPImageProperties getImageProperties() {
     return properties;
   }

@@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
     mediaImage.close();
   }

-  @ImageFormat
+  @MPImageFormat
   static int convertFormatCode(int graphicsFormat) {
     // We only cover the format mentioned in
     // https://developer.android.com/reference/android/media/Image#getFormat()
     if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
       if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
-        return Image.IMAGE_FORMAT_RGBA;
+        return MPImage.IMAGE_FORMAT_RGBA;
       } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
-        return Image.IMAGE_FORMAT_RGB;
+        return MPImage.IMAGE_FORMAT_RGB;
       }
     }
     switch (graphicsFormat) {
       case android.graphics.ImageFormat.JPEG:
-        return Image.IMAGE_FORMAT_JPEG;
+        return MPImage.IMAGE_FORMAT_JPEG;
       case android.graphics.ImageFormat.YUV_420_888:
-        return Image.IMAGE_FORMAT_YUV_420_888;
+        return MPImage.IMAGE_FORMAT_YUV_420_888;
       default:
-        return Image.IMAGE_FORMAT_UNKNOWN;
+        return MPImage.IMAGE_FORMAT_UNKNOWN;
     }
   }
 }
@@ -15,13 +15,14 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

+import android.media.Image;
 import android.os.Build.VERSION_CODES;
 import androidx.annotation.RequiresApi;

 /**
- * Utility for extracting {@link android.media.Image} from {@link Image}.
+ * Utility for extracting {@link android.media.Image} from {@link MPImage}.
  *
- * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
+ * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
  * otherwise {@link IllegalArgumentException} will be thrown.
  */
 @RequiresApi(VERSION_CODES.KITKAT)

@@ -30,20 +31,20 @@ public class MediaImageExtractor {
   private MediaImageExtractor() {}

   /**
-   * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
-   * {@link Image} that built from {@link MediaImageBuilder}.
+   * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
+   * {@link MPImage} that built from {@link MediaImageBuilder}.
    *
    * @param image the image to extract {@link android.media.Image} from.
-   * @return {@link android.media.Image} that stored in {@link Image}.
+   * @return {@link android.media.Image} that stored in {@link MPImage}.
    * @throws IllegalArgumentException if the extraction failed.
    */
-  public static android.media.Image extract(Image image) {
-    ImageContainer container;
-    if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
+  public static Image extract(MPImage image) {
+    MPImageContainer container;
+    if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
       return ((MediaImageContainer) container).getImage();
     }
     throw new IllegalArgumentException(
-        "Extract Media Image from an Image created by objects other than Media Image"
+        "Extract Media Image from a MPImage created by objects other than Media Image"
           + " is not supported");
   }
 }
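
A round-trip sketch for the media-image path, not part of this diff: extraction only works for a MPImage built via MediaImageBuilder, as the javadoc above states. The camera-frame source is assumed for illustration.

// Sketch (not from the commit): wrapping a camera frame and getting it back.
import android.media.Image;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MediaImageBuilder;
import com.google.mediapipe.framework.image.MediaImageExtractor;

final class MediaImageRoundTripExample {
  static Image roundTrip(Image cameraFrame) {  // e.g. from an ImageReader (assumed)
    MPImage wrapped = new MediaImageBuilder(cameraFrame).build();
    return MediaImageExtractor.extract(wrapped);  // same underlying android.media.Image
  }
}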
@ -1,4 +1,4 @@
|
|||
# Copyright 2019-2020 The MediaPipe Authors.
|
||||
# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
@@ -328,19 +328,14 @@ def mediapipe_java_proto_srcs(name = ""):
         src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
     ))

-    proto_src_list.append(mediapipe_java_proto_src_extractor(
-        target = "//mediapipe/framework/formats:landmark_java_proto_lite",
-        src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
-    ))
-
     proto_src_list.append(mediapipe_java_proto_src_extractor(
         target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
         src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
     ))

     proto_src_list.append(mediapipe_java_proto_src_extractor(
-        target = "//mediapipe/framework/formats:location_data_java_proto_lite",
-        src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
+        target = "//mediapipe/framework/formats:classification_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
     ))

     proto_src_list.append(mediapipe_java_proto_src_extractor(

@@ -349,8 +344,18 @@ def mediapipe_java_proto_srcs(name = ""):
     ))

     proto_src_list.append(mediapipe_java_proto_src_extractor(
-        target = "//mediapipe/framework/formats:classification_java_proto_lite",
-        src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
+        target = "//mediapipe/framework/formats:landmark_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
     ))

+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats:location_data_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
+    ))
+
     proto_src_list.append(mediapipe_java_proto_src_extractor(
         target = "//mediapipe/framework/formats:rect_java_proto_lite",
         src_out = "com/google/mediapipe/formats/proto/RectProto.java",
     ))
     return proto_src_list
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# Placeholder for internal Python strict library compatibility macro.
+# Placeholder for internal Python strict library and test compatibility macro.

 package(
     default_visibility = ["//mediapipe:__subpackages__"],
@@ -13,6 +13,7 @@
 # limitations under the License.

 # Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.

 licenses(["notice"])
@@ -23,15 +24,12 @@ package(
 py_library(
     name = "data_util",
     srcs = ["data_util.py"],
-    srcs_version = "PY3",
 )

 py_test(
     name = "data_util_test",
     srcs = ["data_util_test.py"],
     data = ["//mediapipe/model_maker/python/core/data/testdata"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [":data_util"],
 )

@@ -44,8 +42,6 @@ py_library(
 py_test(
     name = "dataset_test",
     srcs = ["dataset_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":dataset",
         "//mediapipe/model_maker/python/core/utils:test_util",

@@ -55,14 +51,11 @@ py_test(
 py_library(
     name = "classification_dataset",
     srcs = ["classification_dataset.py"],
-    srcs_version = "PY3",
     deps = [":dataset"],
 )

 py_test(
     name = "classification_dataset_test",
     srcs = ["classification_dataset_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [":classification_dataset"],
 )
@@ -13,6 +13,7 @@
 # limitations under the License.

 # Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.

 package(
     default_visibility = ["//mediapipe:__subpackages__"],

@@ -23,7 +24,6 @@ licenses(["notice"])
 py_library(
     name = "custom_model",
     srcs = ["custom_model.py"],
-    srcs_version = "PY3",
     deps = [
         "//mediapipe/model_maker/python/core/data:dataset",
         "//mediapipe/model_maker/python/core/utils:model_util",

@@ -34,8 +34,6 @@ py_library(
 py_test(
     name = "custom_model_test",
     srcs = ["custom_model_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":custom_model",
         "//mediapipe/model_maker/python/core/utils:test_util",

@@ -45,7 +43,6 @@ py_test(
 py_library(
     name = "classifier",
     srcs = ["classifier.py"],
-    srcs_version = "PY3",
     deps = [
         ":custom_model",
         "//mediapipe/model_maker/python/core/data:dataset",

@@ -55,8 +52,6 @@ py_library(
 py_test(
     name = "classifier_test",
     srcs = ["classifier_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":classifier",
         "//mediapipe/model_maker/python/core/utils:test_util",
@@ -13,6 +13,7 @@
 # limitations under the License.

 # Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.

 licenses(["notice"])

@@ -24,7 +25,6 @@ py_library(
     name = "test_util",
     testonly = 1,
     srcs = ["test_util.py"],
-    srcs_version = "PY3",
     deps = [
         ":model_util",
         "//mediapipe/model_maker/python/core/data:dataset",

@@ -34,7 +34,6 @@ py_library(
 py_library(
     name = "model_util",
     srcs = ["model_util.py"],
-    srcs_version = "PY3",
     deps = [
         ":quantization",
         "//mediapipe/model_maker/python/core/data:dataset",

@@ -44,8 +43,6 @@ py_library(
 py_test(
     name = "model_util_test",
     srcs = ["model_util_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":model_util",
         ":quantization",

@@ -62,8 +59,6 @@ py_library(
 py_test(
     name = "loss_functions_test",
     srcs = ["loss_functions_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [":loss_functions"],
 )

@@ -77,8 +72,6 @@ py_library(
 py_test(
     name = "quantization_test",
     srcs = ["quantization_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
     deps = [
         ":quantization",
         ":test_util",
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# Placeholder for internal Python strict library compatibility macro.
+# Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python strict test compatibility macro.

 licenses(["notice"])

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# Placeholder for internal Python library rule.
+# Placeholder for internal Python strict library and test compatibility macro.
+# Placeholder for internal Python library rule.

 licenses(["notice"])
@@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
   return model.fit(
       x=train_ds,
       epochs=hparams.train_epochs,
-      steps_per_epoch=hparams.steps_per_epoch,
      validation_data=validation_ds,
      callbacks=callbacks)
@@ -161,7 +161,7 @@ class Texture {

   ~Texture() {
     if (is_owned_) {
-      glDeleteProgram(handle_);
+      glDeleteTextures(1, &handle_);
     }
   }
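The one-line fix above corrects a mismatched GL object deletion: the old code released a texture name with the program deleter, which leaks the texture and targets the wrong GL object namespace. A minimal sketch of the ownership pattern, assuming GLES 2.0; the OwnedGlTexture wrapper below is hypothetical and is not MediaPipe's actual Texture class:

    #include <GLES2/gl2.h>

    // Owning wrapper: the destructor must use the deleter that matches the
    // generator (glGenTextures pairs with glDeleteTextures).
    class OwnedGlTexture {
     public:
      OwnedGlTexture() { glGenTextures(1, &handle_); }
      ~OwnedGlTexture() { glDeleteTextures(1, &handle_); }  // not glDeleteProgram
      OwnedGlTexture(const OwnedGlTexture&) = delete;
      OwnedGlTexture& operator=(const OwnedGlTexture&) = delete;
      GLuint handle() const { return handle_; }

     private:
      GLuint handle_ = 0;
    };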
@@ -87,6 +87,7 @@ cc_library(
 cc_library(
     name = "builtin_task_graphs",
     deps = [
         "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
         "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
+        "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
     ],
@@ -14,7 +14,7 @@

 """The public facing packet getter APIs."""

-from typing import List, Type
+from typing import List

 from google.protobuf import message
 from google.protobuf import symbol_database

@@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
 get_matrix = _packet_getter.get_matrix


-def get_proto(packet: mp_packet.Packet) -> Type[message.Message]:
+def get_proto(packet: mp_packet.Packet) -> message.Message:
   """Get the content of a MediaPipe proto Packet as a proto message.

   Args:
@@ -46,8 +46,10 @@ cc_library(
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/formats:tensor",
+        "//mediapipe/gpu:gpu_origin_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
         "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
@@ -17,7 +17,7 @@ syntax = "proto2";

 package mediapipe.tasks.components.containers.proto;

-option java_package = "com.google.mediapipe.tasks.components.container.proto";
+option java_package = "com.google.mediapipe.tasks.components.containers.proto";
 option java_outer_classname = "CategoryProto";

 // A single classification result.

@@ -19,7 +19,7 @@ package mediapipe.tasks.components.containers.proto;

 import "mediapipe/tasks/cc/components/containers/proto/category.proto";

-option java_package = "com.google.mediapipe.tasks.components.container.proto";
+option java_package = "com.google.mediapipe.tasks.components.containers.proto";
 option java_outer_classname = "ClassificationsProto";

 // List of predicted categories with an optional timestamp.
@@ -17,6 +17,9 @@ syntax = "proto2";

 package mediapipe.tasks.components.containers.proto;

+option java_package = "com.google.mediapipe.tasks.components.containers.proto";
+option java_outer_classname = "EmbeddingsProto";
+
 // Defines a dense floating-point embedding.
 message FloatEmbedding {
   repeated float values = 1 [packed = true];
@@ -30,9 +30,11 @@ limitations under the License.
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/gpu/gpu_origin.pb.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
 #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 #include "tensorflow/lite/schema/schema_generated.h"

@@ -128,12 +130,21 @@ absl::Status ConfigureImageToTensorCalculator(
     options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
                                                           std);
   }
+  // TODO: need to support different GPU origin on different
+  // platforms or applications.
+  options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
   return absl::OkStatus();
 }

 }  // namespace

+bool DetermineImagePreprocessingGpuBackend(
+    const core::proto::Acceleration& acceleration) {
+  return acceleration.has_gpu();
+}
+
 absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
+                                         bool use_gpu,
                                          ImagePreprocessingOptions* options) {
   ASSIGN_OR_RETURN(auto image_tensor_specs,
                    BuildImageTensorSpecs(model_resources));

@@ -141,7 +152,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
       image_tensor_specs, options->mutable_image_to_tensor_options()));
   // The GPU backend isn't able to process int data. If the input tensor is
   // quantized, force the image preprocessing graph to use the CPU backend.
-  if (image_tensor_specs.tensor_type == tflite::TensorType_UINT8) {
+  if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
+    options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
+  } else {
     options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
   }
   return absl::OkStatus();
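Taken together, the hunks above make the preprocessing backend a function of both the acceleration setting and the model's input tensor type. A standalone restatement of the resulting decision rule, for illustration only (the enum and function below are hypothetical, not MediaPipe code):

    // Hypothetical restatement of the backend selection implemented above.
    enum class Backend { kCpu, kGpu };

    // use_gpu mirrors DetermineImagePreprocessingGpuBackend(), i.e. whether the
    // task was configured with GPU acceleration. Quantized (uint8) input
    // tensors always fall back to CPU because the GPU path cannot process int
    // data.
    Backend SelectPreprocessingBackend(bool use_gpu, bool tensor_is_uint8) {
      if (use_gpu && !tensor_is_uint8) {
        return Backend::kGpu;
      }
      return Backend::kCpu;
    }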
@@ -19,20 +19,26 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"

 namespace mediapipe {
 namespace tasks {
 namespace components {

-// Configures an ImagePreprocessing subgraph using the provided model resources.
-// - Accepts CPU input images and outputs CPU tensors.
+// Configures an ImagePreprocessing subgraph using the provided model resources.
+// When use_gpu is true, the GPU is used as the backend to convert the image to
+// a tensor.
 //
 // Example usage:
 //
 //   auto& preprocessing =
 //       graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
+//   core::proto::Acceleration acceleration;
+//   acceleration.mutable_xnnpack();
+//   bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
 //   MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
 //       model_resources,
+//       use_gpu,
 //       &preprocessing.GetOptions<ImagePreprocessingOptions>()));
 //
 // The resulting ImagePreprocessing subgraph has the following I/O:

@@ -56,9 +62,14 @@ namespace components {
 //   The image that has the pixel data stored on the target storage (CPU vs
 //   GPU).
 absl::Status ConfigureImagePreprocessing(
-    const core::ModelResources& model_resources,
+    const core::ModelResources& model_resources, bool use_gpu,
     ImagePreprocessingOptions* options);

+// Determines whether the image preprocessing subgraph should use GPU as the
+// backend according to the given acceleration setting.
+bool DetermineImagePreprocessingGpuBackend(
+    const core::proto::Acceleration& acceleration);
+
 }  // namespace components
 }  // namespace tasks
 }  // namespace mediapipe
@@ -156,21 +156,24 @@ absl::StatusOr<CalculatorGraphConfig> ModelTaskGraph::GetConfig(
 }

 absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
-    SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
+    SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
+    const std::string tag_suffix) {
   auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
   if (!model_resources_cache_service.IsAvailable()) {
-    ASSIGN_OR_RETURN(local_model_resources_,
+    ASSIGN_OR_RETURN(auto local_model_resource,
                      ModelResources::Create("", std::move(external_file)));
     LOG(WARNING)
         << "A local ModelResources object is created. Please consider using "
            "ModelResourcesCacheService to cache the created ModelResources "
            "object in the CalculatorGraph.";
-    return local_model_resources_.get();
+    local_model_resources_.push_back(std::move(local_model_resource));
+    return local_model_resources_.back().get();
   }
   ASSIGN_OR_RETURN(
       auto op_resolver_packet,
       model_resources_cache_service.GetObject().GetGraphOpResolverPacket());
-  const std::string tag = CreateModelResourcesTag(sc->OriginalNode());
+  const std::string tag =
+      absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix);
   ASSIGN_OR_RETURN(auto model_resources,
                    ModelResources::Create(tag, std::move(external_file),
                                           op_resolver_packet));

@@ -182,7 +185,8 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(

 absl::StatusOr<const ModelAssetBundleResources*>
 ModelTaskGraph::CreateModelAssetBundleResources(
-    SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
+    SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
+    const std::string tag_suffix) {
   auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
   bool has_file_pointer_meta = external_file->has_file_pointer_meta();
   // if external file is set by file pointer, no need to add the model asset

@@ -190,7 +194,7 @@ ModelTaskGraph::CreateModelAssetBundleResources(
   // not owned by this model asset bundle resources.
   if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
     ASSIGN_OR_RETURN(
-        local_model_asset_bundle_resources_,
+        auto local_model_asset_bundle_resource,
         ModelAssetBundleResources::Create("", std::move(external_file)));
     if (!has_file_pointer_meta) {
       LOG(WARNING)

@@ -198,10 +202,12 @@ ModelTaskGraph::CreateModelAssetBundleResources(
            "ModelResourcesCacheService to cache the created ModelResources "
            "object in the CalculatorGraph.";
     }
-    return local_model_asset_bundle_resources_.get();
+    local_model_asset_bundle_resources_.push_back(
+        std::move(local_model_asset_bundle_resource));
+    return local_model_asset_bundle_resources_.back().get();
   }
-  const std::string tag =
-      CreateModelAssetBundleResourcesTag(sc->OriginalNode());
+  const std::string tag = absl::StrCat(
+      CreateModelAssetBundleResourcesTag(sc->OriginalNode()), tag_suffix);
   ASSIGN_OR_RETURN(
       auto model_bundle_resources,
       ModelAssetBundleResources::Create(tag, std::move(external_file)));
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 #include <string>
 #include <utility>
+#include <vector>

 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"

@@ -75,9 +76,14 @@ class ModelTaskGraph : public Subgraph {
   // construction stage. Note that the external file contents will be moved
   // into the model resources object on creation. The returned model resources
   // pointer will provide graph authors with the access to the metadata
-  // extractor and the tflite model.
+  // extractor and the tflite model. When the model resources graph service is
+  // available, a tag associated with the created model resource is generated
+  // internally. If more than one model resource is created in a graph, the
+  // model resources graph service adds the tag_suffix to support multiple
+  // resources.
   absl::StatusOr<const ModelResources*> CreateModelResources(
-      SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
+      SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
+      const std::string tag_suffix = "");

   // If the model resources graph service is available, creates a model asset
   // bundle resources object from the subgraph context, and caches the created

@@ -103,10 +109,15 @@ class ModelTaskGraph : public Subgraph {
   // that can only be used in the graph construction stage. Note that the
   // external file contents will be moved into the model asset bundle resources
   // object on creation. The returned model asset bundle resources pointer will
-  // provide graph authors with the access to extracted model files.
+  // provide graph authors with the access to extracted model files. When the
+  // model resources graph service is available, a tag associated with the
+  // created model asset bundle resource is generated internally. If more than
+  // one model asset bundle resource is created in a graph, the model resources
+  // graph service adds the tag_suffix to support multiple resources.
   absl::StatusOr<const ModelAssetBundleResources*>
   CreateModelAssetBundleResources(
-      SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
+      SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file,
+      const std::string tag_suffix = "");

   // Inserts a mediapipe task inference subgraph into the provided
   // GraphBuilder. The returned node provides the following interfaces to the

@@ -124,9 +135,9 @@ class ModelTaskGraph : public Subgraph {
       api2::builder::Graph& graph) const;

  private:
-  std::unique_ptr<ModelResources> local_model_resources_;
+  std::vector<std::unique_ptr<ModelResources>> local_model_resources_;

-  std::unique_ptr<ModelAssetBundleResources>
+  std::vector<std::unique_ptr<ModelAssetBundleResources>>
      local_model_asset_bundle_resources_;
 };
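Since the cached members are now vectors, a single task graph can own several model resources side by side, distinguished by tag suffix. A hedged sketch of a call site follows; the two-model subgraph and the detector_file/classifier_file variables are hypothetical, and only CreateModelResources and its new tag_suffix parameter come from this change:

    // Inside a hypothetical ModelTaskGraph subclass's GetConfig():
    //
    //   ASSIGN_OR_RETURN(
    //       const ModelResources* detector_resources,
    //       CreateModelResources(sc, std::move(detector_file),
    //                            /*tag_suffix=*/"_detector"));
    //   ASSIGN_OR_RETURN(
    //       const ModelResources* classifier_resources,
    //       CreateModelResources(sc, std::move(classifier_file),
    //                            /*tag_suffix=*/"_classifier"));
    //
    // Distinct suffixes keep the generated cache tags unique, so both resources
    // can be cached in the ModelResourcesCacheService at the same time.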
@@ -63,6 +63,29 @@ cc_library(
     ],
 )

+cc_test(
+    name = "text_classifier_test",
+    srcs = ["text_classifier_test.cc"],
+    data = [
+        "//mediapipe/tasks/testdata/text:bert_text_classifier_models",
+        "//mediapipe/tasks/testdata/text:text_classifier_models",
+    ],
+    deps = [
+        ":text_classifier",
+        ":text_classifier_test_utils",
+        "//mediapipe/framework/deps:file_path",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:cord",
+        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
+    ],
+)
+
 cc_library(
     name = "text_classifier_test_utils",
     srcs = ["text_classifier_test_utils.cc"],
@@ -49,9 +49,6 @@ using ::mediapipe::tasks::kMediaPipeTasksPayload;
 using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 using ::testing::HasSubstr;
 using ::testing::Optional;
-using ::testing::proto::Approximately;
-using ::testing::proto::IgnoringRepeatedFieldOrdering;
-using ::testing::proto::Partially;

 constexpr float kEpsilon = 0.001;
 constexpr int kMaxSeqLen = 128;

@@ -110,127 +107,6 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
   MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
 }

-TEST_F(TextClassifierTest, TextClassifierWithBert) {
-  auto options = std::make_unique<TextClassifierOptions>();
-  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
-  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
-                          TextClassifier::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(
-      ClassificationResult negative_result,
-      classifier->Classify("unflinchingly bleak and desperate"));
-  ASSERT_THAT(negative_result,
-              Partially(IgnoringRepeatedFieldOrdering(Approximately(
-                  EqualsProto(R"pb(
-                    classifications {
-                      entries {
-                        categories { category_name: "negative" score: 0.956 }
-                        categories { category_name: "positive" score: 0.044 }
-                      }
-                    }
-                  )pb"),
-                  kEpsilon))));
-
-  MP_ASSERT_OK_AND_ASSIGN(
-      ClassificationResult positive_result,
-      classifier->Classify("it's a charming and often affecting journey"));
-  ASSERT_THAT(positive_result,
-              Partially(IgnoringRepeatedFieldOrdering(Approximately(
-                  EqualsProto(R"pb(
-                    classifications {
-                      entries {
-                        categories { category_name: "negative" score: 0.0 }
-                        categories { category_name: "positive" score: 1.0 }
-                      }
-                    }
-                  )pb"),
-                  kEpsilon))));
-  MP_ASSERT_OK(classifier->Close());
-}
-
-TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
-  auto options = std::make_unique<TextClassifierOptions>();
-  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
-  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
-                          TextClassifier::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
-                          classifier->Classify("What a waste of my time."));
-  ASSERT_THAT(negative_result,
-              Partially(IgnoringRepeatedFieldOrdering(Approximately(
-                  EqualsProto(R"pb(
-                    classifications {
-                      entries {
-                        categories { category_name: "Negative" score: 0.813 }
-                        categories { category_name: "Positive" score: 0.187 }
-                      }
-                    }
-                  )pb"),
-                  kEpsilon))));
-
-  MP_ASSERT_OK_AND_ASSIGN(
-      ClassificationResult positive_result,
-      classifier->Classify("This is the best movie I’ve seen in recent years. "
-                           "Strongly recommend it!"));
-  ASSERT_THAT(positive_result,
-              Partially(IgnoringRepeatedFieldOrdering(Approximately(
-                  EqualsProto(R"pb(
-                    classifications {
-                      entries {
-                        categories { category_name: "Negative" score: 0.487 }
-                        categories { category_name: "Positive" score: 0.513 }
-                      }
-                    }
-                  )pb"),
-                  kEpsilon))));
-  MP_ASSERT_OK(classifier->Close());
-}
-
-TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
-  auto options = std::make_unique<TextClassifierOptions>();
-  options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
-  options->base_options.op_resolver = CreateCustomResolver();
-  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
-                          TextClassifier::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
-                          classifier->Classify("hello"));
-  ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
-                        classifications {
-                          entries {
-                            categories { index: 1 score: 1 }
-                            categories { index: 0 score: 1 }
-                            categories { index: 2 score: 0 }
-                          }
-                        }
-                      )pb"))));
-}
-
-TEST_F(TextClassifierTest, BertLongPositive) {
-  std::stringstream ss_for_positive_review;
-  ss_for_positive_review
-      << "it's a charming and often affecting journey and this is a long";
-  for (int i = 0; i < kMaxSeqLen; ++i) {
-    ss_for_positive_review << " long";
-  }
-  ss_for_positive_review << " movie review";
-  auto options = std::make_unique<TextClassifierOptions>();
-  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
-  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
-                          TextClassifier::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
-                          classifier->Classify(ss_for_positive_review.str()));
-  ASSERT_THAT(result,
-              Partially(IgnoringRepeatedFieldOrdering(Approximately(
-                  EqualsProto(R"pb(
-                    classifications {
-                      entries {
-                        categories { category_name: "negative" score: 0.014 }
-                        categories { category_name: "positive" score: 0.986 }
-                      }
-                    }
-                  )pb"),
-                  kEpsilon))));
-  MP_ASSERT_OK(classifier->Close());
-}
-
 }  // namespace
 }  // namespace text_classifier
 }  // namespace text
@@ -73,7 +73,18 @@ cc_library(
     ],
 )

+# TODO: This test fails in OSS
+cc_test(
+    name = "sentencepiece_tokenizer_test",
+    srcs = ["sentencepiece_tokenizer_test.cc"],
+    data = [
+        "//mediapipe/tasks/testdata/text:albert_model",
+    ],
+    deps = [
+        ":sentencepiece_tokenizer",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/tasks/cc/core:utils",
+    ],
+)
+
 cc_library(
     name = "tokenizer_utils",

@@ -97,7 +108,32 @@ cc_library(
     ],
 )

+# TODO: This test fails in OSS
+cc_test(
+    name = "tokenizer_utils_test",
+    srcs = ["tokenizer_utils_test.cc"],
+    data = [
+        "//mediapipe/tasks/testdata/text:albert_model",
+        "//mediapipe/tasks/testdata/text:mobile_bert_model",
+        "//mediapipe/tasks/testdata/text:text_classifier_models",
+    ],
+    linkopts = ["-ldl"],
+    deps = [
+        ":bert_tokenizer",
+        ":regex_tokenizer",
+        ":sentencepiece_tokenizer",
+        ":tokenizer_utils",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/core:utils",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:cord",
+    ],
+)
+
 cc_library(
     name = "regex_tokenizer",
@@ -21,12 +21,23 @@ cc_library(
     hdrs = ["running_mode.h"],
 )

+cc_library(
+    name = "image_processing_options",
+    hdrs = ["image_processing_options.h"],
+    deps = [
+        "//mediapipe/tasks/cc/components/containers:rect",
+    ],
+)
+
 cc_library(
     name = "base_vision_task_api",
     hdrs = ["base_vision_task_api.h"],
     deps = [
+        ":image_processing_options",
         ":running_mode",
         "//mediapipe/calculators/core:flow_limiter_calculator",
         "//mediapipe/framework/formats:rect_cc_proto",
+        "//mediapipe/tasks/cc/components/containers:rect",
         "//mediapipe/tasks/cc/core:base_task_api",
         "//mediapipe/tasks/cc/core:task_runner",
         "@com_google_absl//absl/status",
@@ -16,15 +16,20 @@ limitations under the License.
 #ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_
 #define MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_

+#include <cmath>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>

 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "mediapipe/framework/formats/rect.pb.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/core/base_task_api.h"
 #include "mediapipe/tasks/cc/core/task_runner.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"

 namespace mediapipe {

@@ -87,6 +92,60 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi {
     return runner_->Send(std::move(inputs));
   }

+  // Converts from ImageProcessingOptions to NormalizedRect, performing sanity
+  // checks on the fly. If the input ImageProcessingOptions is not present,
+  // returns a default NormalizedRect covering the whole image with rotation
+  // set to 0. If 'roi_allowed' is false, an error is returned if the input
+  // ImageProcessingOptions has its 'region_of_interest' field set.
+  static absl::StatusOr<mediapipe::NormalizedRect> ConvertToNormalizedRect(
+      std::optional<ImageProcessingOptions> options, bool roi_allowed = true) {
+    mediapipe::NormalizedRect normalized_rect;
+    normalized_rect.set_rotation(0);
+    normalized_rect.set_x_center(0.5);
+    normalized_rect.set_y_center(0.5);
+    normalized_rect.set_width(1.0);
+    normalized_rect.set_height(1.0);
+    if (!options.has_value()) {
+      return normalized_rect;
+    }
+
+    if (options->rotation_degrees % 90 != 0) {
+      return CreateStatusWithPayload(
+          absl::StatusCode::kInvalidArgument,
+          "Expected rotation to be a multiple of 90°.",
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+    }
+    // Convert to radians counter-clockwise.
+    normalized_rect.set_rotation(-options->rotation_degrees * M_PI / 180.0);
+
+    if (options->region_of_interest.has_value()) {
+      if (!roi_allowed) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            "This task doesn't support region-of-interest.",
+            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+      }
+      auto& roi = *options->region_of_interest;
+      if (roi.left >= roi.right || roi.top >= roi.bottom) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            "Expected Rect with left < right and top < bottom.",
+            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+      }
+      if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) {
+        return CreateStatusWithPayload(
+            absl::StatusCode::kInvalidArgument,
+            "Expected Rect values to be in [0,1].",
+            MediaPipeTasksStatus::kImageProcessingInvalidArgumentError);
+      }
+      normalized_rect.set_x_center((roi.left + roi.right) / 2.0);
+      normalized_rect.set_y_center((roi.top + roi.bottom) / 2.0);
+      normalized_rect.set_width(roi.right - roi.left);
+      normalized_rect.set_height(roi.bottom - roi.top);
+    }
+    return normalized_rect;
+  }
+
  private:
   RunningMode running_mode_;
 };
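The ROI and rotation math above is easy to check in isolation. A self-contained restatement follows, with no MediaPipe dependency; the two structs are local stand-ins whose field names mirror the diff, not the real MediaPipe types:

    #include <cmath>
    #include <cstdio>

    struct Rect { float left, top, right, bottom; };
    struct NormalizedRect { float x_center, y_center, width, height, rotation; };

    // Same conversion as ConvertToNormalizedRect: center/size from the ROI,
    // degrees clockwise in, radians counter-clockwise out.
    NormalizedRect ToNormalizedRect(const Rect& roi, int rotation_degrees) {
      NormalizedRect r;
      r.x_center = (roi.left + roi.right) / 2.0f;
      r.y_center = (roi.top + roi.bottom) / 2.0f;
      r.width = roi.right - roi.left;
      r.height = roi.bottom - roi.top;
      r.rotation = -rotation_degrees * M_PI / 180.0f;
      return r;
    }

    int main() {
      // Center crop rotated 90° clockwise -> rotation = -pi/2 ~ -1.5708.
      NormalizedRect r = ToNormalizedRect({0.25f, 0.25f, 0.75f, 0.75f}, 90);
      std::printf("center=(%.2f, %.2f) size=(%.2f, %.2f) rotation=%.4f\n",
                  r.x_center, r.y_center, r.width, r.height, r.rotation);
      return 0;
    }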
mediapipe/tasks/cc/vision/core/image_processing_options.h (new file, 52 lines)

@@ -0,0 +1,52 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
+#define MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
+
+#include <optional>
+
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace core {
+
+// Options for image processing.
+//
+// If both region-of-interest and rotation are specified, the crop around the
+// region-of-interest is extracted first, then the specified rotation is
+// applied to the crop.
+struct ImageProcessingOptions {
+  // The optional region-of-interest to crop from the image. If not specified,
+  // the full image is used.
+  //
+  // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
+  std::optional<components::containers::Rect> region_of_interest = std::nullopt;
+
+  // The rotation to apply to the image (or cropped region-of-interest), in
+  // degrees clockwise.
+  //
+  // The rotation must be a multiple (positive or negative) of 90°.
+  int rotation_degrees = 0;
+};
+
+}  // namespace core
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
@@ -62,13 +62,19 @@ cc_library(
         "//mediapipe/tasks/cc/components:image_preprocessing",
         "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
         "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_resources_cache",
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
         "//mediapipe/tasks/cc/metadata/utils:zip_utils",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
         "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",

@@ -93,10 +99,14 @@ cc_library(
         "//mediapipe/framework/formats:image",
         "//mediapipe/framework/formats:landmark_cc_proto",
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:status",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
         "//mediapipe/tasks/cc/core:model_resources_cache",
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/metadata/utils:zip_utils",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",

@@ -137,8 +147,10 @@ cc_library(
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
         "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/cc/vision/core:image_processing_options",
         "//mediapipe/tasks/cc/vision/core:running_mode",
         "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
@@ -93,3 +93,46 @@ cc_test(
         "@com_google_absl//absl/strings",
     ],
 )
+
+mediapipe_proto_library(
+    name = "combined_prediction_calculator_proto",
+    srcs = ["combined_prediction_calculator.proto"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+    ],
+)
+
+cc_library(
+    name = "combined_prediction_calculator",
+    srcs = ["combined_prediction_calculator.cc"],
+    deps = [
+        ":combined_prediction_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:collection",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "@com_google_absl//absl/container:btree",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "combined_prediction_calculator_test",
+    srcs = ["combined_prediction_calculator_test.cc"],
+    deps = [
+        ":combined_prediction_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:calculator_runner",
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/port:gtest",
+        "//mediapipe/framework/port:gtest_main",
+        "@com_google_absl//absl/strings",
+    ],
+)
@@ -0,0 +1,187 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/btree_map.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/packet.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection.h"
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/combined_prediction_calculator.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+namespace {
+
+constexpr char kPredictionTag[] = "PREDICTION";
+
+Classification GetMaxScoringClassification(
+    const ClassificationList& classifications) {
+  Classification max_classification;
+  max_classification.set_score(0);
+  for (const auto& input : classifications.classification()) {
+    if (max_classification.score() < input.score()) {
+      max_classification = input;
+    }
+  }
+  return max_classification;
+}
+
+float GetScoreThreshold(
+    const std::string& input_label,
+    const absl::btree_map<std::string, float>& classwise_thresholds,
+    const std::string& background_label, const float default_threshold) {
+  float threshold = default_threshold;
+  auto it = classwise_thresholds.find(input_label);
+  if (it != classwise_thresholds.end()) {
+    threshold = it->second;
+  }
+  return threshold;
+}
+
+std::unique_ptr<ClassificationList> GetWinningPrediction(
+    const ClassificationList& classification_list,
+    const absl::btree_map<std::string, float>& classwise_thresholds,
+    const std::string& background_label, const float default_threshold) {
+  auto prediction_list = std::make_unique<ClassificationList>();
+  if (classification_list.classification().empty()) {
+    return prediction_list;
+  }
+  Classification& prediction = *prediction_list->add_classification();
+  auto argmax_prediction = GetMaxScoringClassification(classification_list);
+  float argmax_prediction_thresh =
+      GetScoreThreshold(argmax_prediction.label(), classwise_thresholds,
+                        background_label, default_threshold);
+  if (argmax_prediction.score() >= argmax_prediction_thresh) {
+    prediction.set_label(argmax_prediction.label());
+    prediction.set_score(argmax_prediction.score());
+  } else {
+    for (const auto& input : classification_list.classification()) {
+      if (input.label() == background_label) {
+        prediction.set_label(input.label());
+        prediction.set_score(input.score());
+        break;
+      }
+    }
+  }
+  return prediction_list;
+}
+
+}  // namespace
+
+// This calculator accepts multiple ClassificationList input streams. Each
+// ClassificationList should contain classifications with labels and
+// corresponding softmax scores. The calculator computes the best prediction
+// for each ClassificationList input stream via argmax and thresholding.
+// Thresholds for all classes can be specified in the
+// `CombinedPredictionCalculatorOptions`, along with a default global
+// threshold.
+// Please note that for this calculator to work as designed, the class names
+// other than the background class in the ClassificationList objects must be
+// different, but the background class name has to be the same. This
+// background label name can be set via `background_label` in
+// `CombinedPredictionCalculatorOptions`.
+// The ClassificationList in the PREDICTION output stream contains the label
+// of the winning class and the corresponding softmax score. If none of the
+// ClassificationList objects has a non-background winning class, the output
+// contains the background class and score of the background class in the
+// first ClassificationList. If multiple ClassificationList objects have a
+// non-background winning class, the output contains the winning prediction
+// from the ClassificationList with the highest priority. Priority is in
+// decreasing order of input streams to the graph node using this calculator.
+// Input:
+//   At least one stream with ClassificationList.
+// Output:
+//   PREDICTION - A ClassificationList with the winning label as the only item.
+//
+// Usage example:
+// node {
+//   calculator: "CombinedPredictionCalculator"
+//   input_stream: "classification_list_0"
+//   input_stream: "classification_list_1"
+//   output_stream: "PREDICTION:prediction"
+//   options {
+//     [mediapipe.CombinedPredictionCalculatorOptions.ext] {
+//       class {
+//         label: "A"
+//         score_threshold: 0.7
+//       }
+//       default_global_threshold: 0.1
+//       background_label: "B"
+//     }
+//   }
+// }
+
+class CombinedPredictionCalculator : public Node {
+ public:
+  static constexpr Input<ClassificationList>::Multiple kClassificationListIn{
+      ""};
+  static constexpr Output<ClassificationList> kPredictionOut{"PREDICTION"};
+  MEDIAPIPE_NODE_CONTRACT(kClassificationListIn, kPredictionOut);
+
+  absl::Status Open(CalculatorContext* cc) override {
+    options_ = cc->Options<CombinedPredictionCalculatorOptions>();
+    for (const auto& input : options_.class_()) {
+      classwise_thresholds_[input.label()] = input.score_threshold();
+    }
+    classwise_thresholds_[options_.background_label()] = 0;
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(CalculatorContext* cc) override {
+    // Sends the first non-background winning prediction as soon as one is
+    // found. Otherwise, falls back to the first (background) prediction, or
+    // sends nothing if all inputs are empty.
+    std::unique_ptr<ClassificationList> first_winning_prediction = nullptr;
+    auto collection = kClassificationListIn(cc);
+    for (int idx = 0; idx < collection.Count(); ++idx) {
+      const auto& packet = collection[idx];
+      if (packet.IsEmpty()) {
+        continue;
+      }
+      auto prediction = GetWinningPrediction(
+          packet.Get(), classwise_thresholds_, options_.background_label(),
+          options_.default_global_threshold());
+      if (prediction->classification(0).label() !=
+          options_.background_label()) {
+        kPredictionOut(cc).Send(std::move(prediction));
+        return absl::OkStatus();
+      }
+      if (first_winning_prediction == nullptr) {
+        first_winning_prediction = std::move(prediction);
+      }
+    }
+    if (first_winning_prediction != nullptr) {
+      kPredictionOut(cc).Send(std::move(first_winning_prediction));
+    }
+    return absl::OkStatus();
+  }
+
+ private:
+  CombinedPredictionCalculatorOptions options_;
+  absl::btree_map<std::string, float> classwise_thresholds_;
+};
+
+MEDIAPIPE_REGISTER_NODE(CombinedPredictionCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
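To make the priority and fallback behavior concrete, here is a self-contained restatement of the selection rule implemented above, with no MediaPipe dependency; the Label struct and CombinePredictions function are illustrative stand-ins, not the calculator's actual API:

    #include <map>
    #include <optional>
    #include <string>
    #include <vector>

    struct Label { std::string name; float score; };

    // Per-list argmax plus per-class thresholding, with background fallback
    // and first-stream-wins priority across lists (decreasing priority order).
    std::optional<Label> CombinePredictions(
        const std::vector<std::vector<Label>>& lists,
        const std::map<std::string, float>& thresholds,
        const std::string& background, float default_threshold) {
      std::optional<Label> fallback;
      for (const auto& list : lists) {
        if (list.empty()) continue;
        // Argmax over this list.
        const Label* best = &list[0];
        for (const auto& l : list) {
          if (l.score > best->score) best = &l;
        }
        auto it = thresholds.find(best->name);
        float thresh = it != thresholds.end() ? it->second : default_threshold;
        if (best->name != background && best->score >= thresh) {
          return *best;  // Highest-priority non-background winner.
        }
        // Remember the first list's background score as the fallback.
        if (!fallback.has_value()) {
          for (const auto& l : list) {
            if (l.name == background) { fallback = l; break; }
          }
        }
      }
      return fallback;
    }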
@@ -0,0 +1,41 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message CombinedPredictionCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional CombinedPredictionCalculatorOptions ext = 483738635;
+  }
+
+  message Class {
+    optional string label = 1;
+    optional float score_threshold = 2;
+  }
+
+  // List of classes with score thresholds.
+  repeated Class class = 1;
+
+  // Default score threshold applied to a label.
+  optional float default_global_threshold = 2 [default = 0];
+
+  // Name of the background class whose input scores will be ignored while
+  // thresholding.
+  optional string background_label = 3;
+}
|
@ -0,0 +1,315 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kPredictionTag[] = "PREDICTION";
|
||||
|
||||
std::unique_ptr<CalculatorRunner> BuildNodeRunnerWithOptions(
|
||||
float drama_thresh, float llama_thresh, float bazinga_thresh,
|
||||
float joy_thresh, float peace_thresh) {
|
||||
constexpr absl::string_view kCalculatorProto = R"pb(
|
||||
calculator: "CombinedPredictionCalculator"
|
||||
input_stream: "custom_softmax_scores"
|
||||
input_stream: "canned_softmax_scores"
|
||||
output_stream: "PREDICTION:prediction"
|
||||
options {
|
||||
[mediapipe.CombinedPredictionCalculatorOptions.ext] {
|
||||
class { label: "CustomDrama" score_threshold: $0 }
        class { label: "CustomLlama" score_threshold: $1 }
        class { label: "CannedBazinga" score_threshold: $2 }
        class { label: "CannedJoy" score_threshold: $3 }
        class { label: "CannedPeace" score_threshold: $4 }
        background_label: "Negative"
      }
    }
  )pb";
  auto runner = std::make_unique<CalculatorRunner>(
      absl::Substitute(kCalculatorProto, drama_thresh, llama_thresh,
                       bazinga_thresh, joy_thresh, peace_thresh));
  return runner;
}

std::unique_ptr<ClassificationList> BuildCustomScoreInput(
    const float negative_score, const float drama_score,
    const float llama_score) {
  auto custom_scores = std::make_unique<ClassificationList>();
  auto custom_negative = custom_scores->add_classification();
  custom_negative->set_label("Negative");
  custom_negative->set_score(negative_score);
  auto drama = custom_scores->add_classification();
  drama->set_label("CustomDrama");
  drama->set_score(drama_score);
  auto llama = custom_scores->add_classification();
  llama->set_label("CustomLlama");
  llama->set_score(llama_score);
  return custom_scores;
}

std::unique_ptr<ClassificationList> BuildCannedScoreInput(
    const float negative_score, const float bazinga_score,
    const float joy_score, const float peace_score) {
  auto canned_scores = std::make_unique<ClassificationList>();
  auto canned_negative = canned_scores->add_classification();
  canned_negative->set_label("Negative");
  canned_negative->set_score(negative_score);
  auto bazinga = canned_scores->add_classification();
  bazinga->set_label("CannedBazinga");
  bazinga->set_score(bazinga_score);
  auto joy = canned_scores->add_classification();
  joy->set_label("CannedJoy");
  joy->set_score(joy_score);
  auto peace = canned_scores->add_classification();
  peace->set_label("CannedPeace");
  peace->set_score(peace_score);
  return canned_scores;
}

TEST(CombinedPredictionCalculatorPacketTest,
     CustomEmpty_CannedEmpty_ResultIsEmpty) {
  auto runner = BuildNodeRunnerWithOptions(
      /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.0,
      /*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";
  EXPECT_THAT(runner->Outputs().Tag("PREDICTION").packets, testing::IsEmpty());
}

TEST(CombinedPredictionCalculatorPacketTest,
     CustomEmpty_CannedNotEmpty_ResultIsCanned) {
  auto runner = BuildNodeRunnerWithOptions(
      /*drama_thresh=*/0.0, /*llama_thresh=*/0.0, /*bazinga_thresh=*/0.9,
      /*joy_thresh=*/0.5, /*peace_thresh=*/0.8);
  auto canned_scores = BuildCannedScoreInput(
      /*negative_score=*/0.1,
      /*bazinga_score=*/0.1, /*joy_score=*/0.6, /*peace_score=*/0.2);
  runner->MutableInputs()->Index(1).packets.push_back(
      Adopt(canned_scores.release()).At(Timestamp(1)));
  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";

  auto output_prediction_packets =
      runner->Outputs().Tag(kPredictionTag).packets;
  ASSERT_EQ(output_prediction_packets.size(), 1);
  Classification output_prediction =
      output_prediction_packets[0].Get<ClassificationList>().classification(0);

  EXPECT_EQ(output_prediction.label(), "CannedJoy");
  EXPECT_NEAR(output_prediction.score(), 0.6, 1e-4);
}

TEST(CombinedPredictionCalculatorPacketTest,
     CustomNotEmpty_CannedEmpty_ResultIsCustom) {
  auto runner = BuildNodeRunnerWithOptions(
      /*drama_thresh=*/0.3, /*llama_thresh=*/0.5, /*bazinga_thresh=*/0.0,
      /*joy_thresh=*/0.0, /*peace_thresh=*/0.0);
  auto custom_scores =
      BuildCustomScoreInput(/*negative_score=*/0.1,
                            /*drama_score=*/0.2, /*llama_score=*/0.7);
  runner->MutableInputs()->Index(0).packets.push_back(
      Adopt(custom_scores.release()).At(Timestamp(1)));
  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";

  auto output_prediction_packets =
      runner->Outputs().Tag(kPredictionTag).packets;
  ASSERT_EQ(output_prediction_packets.size(), 1);
  Classification output_prediction =
      output_prediction_packets[0].Get<ClassificationList>().classification(0);

  EXPECT_EQ(output_prediction.label(), "CustomLlama");
  EXPECT_NEAR(output_prediction.score(), 0.7, 1e-4);
}

struct CombinedPredictionCalculatorTestCase {
  std::string test_name;
  float custom_negative_score;
  float drama_score;
  float llama_score;
  float drama_thresh;
  float llama_thresh;
  float canned_negative_score;
  float bazinga_score;
  float joy_score;
  float peace_score;
  float bazinga_thresh;
  float joy_thresh;
  float peace_thresh;
  std::string max_scoring_label;
  float max_score;
};

using CombinedPredictionCalculatorTest =
    testing::TestWithParam<CombinedPredictionCalculatorTestCase>;

TEST_P(CombinedPredictionCalculatorTest, OutputsCorrectResult) {
  const CombinedPredictionCalculatorTestCase& test_case = GetParam();

  auto runner = BuildNodeRunnerWithOptions(
      test_case.drama_thresh, test_case.llama_thresh, test_case.bazinga_thresh,
      test_case.joy_thresh, test_case.peace_thresh);

  auto custom_scores =
      BuildCustomScoreInput(test_case.custom_negative_score,
                            test_case.drama_score, test_case.llama_score);

  runner->MutableInputs()->Index(0).packets.push_back(
      Adopt(custom_scores.release()).At(Timestamp(1)));

  auto canned_scores = BuildCannedScoreInput(
      test_case.canned_negative_score, test_case.bazinga_score,
      test_case.joy_score, test_case.peace_score);
  runner->MutableInputs()->Index(1).packets.push_back(
      Adopt(canned_scores.release()).At(Timestamp(1)));

  MP_ASSERT_OK(runner->Run()) << "Calculator execution failed.";

  auto output_prediction_packets =
      runner->Outputs().Tag(kPredictionTag).packets;
  ASSERT_EQ(output_prediction_packets.size(), 1);
  Classification output_prediction =
      output_prediction_packets[0].Get<ClassificationList>().classification(0);

  EXPECT_EQ(output_prediction.label(), test_case.max_scoring_label);
  EXPECT_NEAR(output_prediction.score(), test_case.max_score, 1e-4);
}

INSTANTIATE_TEST_CASE_P(
    CombinedPredictionCalculatorTests, CombinedPredictionCalculatorTest,
    testing::ValuesIn<CombinedPredictionCalculatorTestCase>({
        {
            .test_name = "TestCustomDramaWinnerWith_HighCanned_Thresh",
            .custom_negative_score = 0.1,
            .drama_score = 0.5,
            .llama_score = 0.3,
            .drama_thresh = 0.25,
            .llama_thresh = 0.7,
            .canned_negative_score = 0.1,
            .bazinga_score = 0.3,
            .joy_score = 0.3,
            .peace_score = 0.3,
            .bazinga_thresh = 0.7,
            .joy_thresh = 0.7,
            .peace_thresh = 0.7,
            .max_scoring_label = "CustomDrama",
            .max_score = 0.5,
        },
        {
            .test_name = "TestCannedWinnerWith_HighCustom_ZeroCanned_Thresh",
            .custom_negative_score = 0.1,
            .drama_score = 0.3,
            .llama_score = 0.6,
            .drama_thresh = 0.4,
            .llama_thresh = 0.8,
            .canned_negative_score = 0.1,
            .bazinga_score = 0.4,
            .joy_score = 0.3,
            .peace_score = 0.2,
            .bazinga_thresh = 0.0,
            .joy_thresh = 0.0,
            .peace_thresh = 0.0,
            .max_scoring_label = "CannedBazinga",
            .max_score = 0.4,
        },
        {
            .test_name = "TestNegativeWinnerWith_LowCustom_HighCanned_Thresh",
            .custom_negative_score = 0.5,
            .drama_score = 0.1,
            .llama_score = 0.4,
            .drama_thresh = 0.1,
            .llama_thresh = 0.05,
            .canned_negative_score = 0.1,
            .bazinga_score = 0.3,
            .joy_score = 0.3,
            .peace_score = 0.3,
            .bazinga_thresh = 0.7,
            .joy_thresh = 0.7,
            .peace_thresh = 0.7,
            .max_scoring_label = "Negative",
            .max_score = 0.5,
        },
        {
            .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh",
            .custom_negative_score = 0.8,
            .drama_score = 0.1,
            .llama_score = 0.1,
            .drama_thresh = 0.25,
            .llama_thresh = 0.7,
            .canned_negative_score = 0.1,
            .bazinga_score = 0.3,
            .joy_score = 0.3,
            .peace_score = 0.3,
            .bazinga_thresh = 0.7,
            .joy_thresh = 0.7,
            .peace_thresh = 0.7,
            .max_scoring_label = "Negative",
            .max_score = 0.8,
        },
        {
            .test_name = "TestNegativeWinnerWith_HighCustom_HighCannedThresh2",
            .custom_negative_score = 0.1,
            .drama_score = 0.2,
            .llama_score = 0.7,
            .drama_thresh = 1.1,
            .llama_thresh = 1.1,
            .canned_negative_score = 0.1,
            .bazinga_score = 0.3,
            .joy_score = 0.3,
            .peace_score = 0.3,
            .bazinga_thresh = 0.7,
            .joy_thresh = 0.7,
            .peace_thresh = 0.7,
            .max_scoring_label = "Negative",
            .max_score = 0.1,
        },
        {
            .test_name = "TestNegativeWinnerWith_HighCustom_HighCanned_Thresh3",
            .custom_negative_score = 0.1,
            .drama_score = 0.3,
            .llama_score = 0.6,
            .drama_thresh = 0.4,
            .llama_thresh = 0.8,
            .canned_negative_score = 0.3,
            .bazinga_score = 0.2,
            .joy_score = 0.3,
            .peace_score = 0.2,
            .bazinga_thresh = 0.5,
            .joy_thresh = 0.5,
            .peace_thresh = 0.5,
            .max_scoring_label = "Negative",
            .max_score = 0.1,
        },
    }),
    [](const testing::TestParamInfo<
        CombinedPredictionCalculatorTest::ParamType>& info) {
      return info.param.test_name;
    });

}  // namespace

}  // namespace mediapipe
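
// A minimal sketch of the selection rule these cases pin down, assuming the
// calculator walks its input ClassificationLists in priority order (custom
// first, then canned); PickCombinedPrediction and threshold_for_label are
// hypothetical names for illustration, not the calculator's actual code, and
// the sketch assumes the test file's includes (<functional>, <optional>).
std::optional<Classification> PickCombinedPrediction(
    const std::vector<ClassificationList>& lists_in_priority_order,
    const std::function<float(const std::string&)>& threshold_for_label,
    const std::string& background_label) {
  std::optional<Classification> background;
  for (const ClassificationList& list : lists_in_priority_order) {
    if (list.classification_size() == 0) continue;
    // Find the top-scoring entry in this list.
    const Classification* top = &list.classification(0);
    for (const Classification& c : list.classification()) {
      if (c.score() > top->score()) top = &c;
    }
    // The first confident non-background winner is emitted as-is.
    if (top->label() != background_label &&
        top->score() > threshold_for_label(top->label())) {
      return *top;
    }
    // Remember the highest-priority background score as the fallback; this
    // matches the cases above where "Negative" wins with the custom list's
    // negative score even when the canned list scores it higher.
    if (!background.has_value()) {
      for (const Classification& c : list.classification()) {
        if (c.label() == background_label) background = c;
      }
    }
  }
  return background;  // nullopt when every input list was empty.
}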

@ -39,7 +9,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"

@ -76,31 +78,6 @@ constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000;

// Returns a NormalizedRect filling the whole image. If input is present, its
// rotation is set in the returned NormalizedRect and a check is performed to
// make sure no region-of-interest was provided. Otherwise, rotation is set to
// 0.
absl::StatusOr<NormalizedRect> FillNormalizedRect(
    std::optional<NormalizedRect> normalized_rect) {
  NormalizedRect result;
  if (normalized_rect.has_value()) {
    result = *normalized_rect;
  }
  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
                         result.has_width() || result.has_height();
  if (has_coordinates) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GestureRecognizer does not support region-of-interest.",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  result.set_x_center(0.5);
  result.set_y_center(0.5);
  result.set_width(1);
  result.set_height(1);
  return result;
}

// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running
// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the

@ -136,57 +113,38 @@ CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<GestureRecognizerGraphOptionsProto>
ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
  auto options_proto = std::make_unique<GestureRecognizerGraphOptionsProto>();
  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
  options_proto->mutable_base_options()->Swap(base_options_proto.get());
  options_proto->mutable_base_options()->set_use_stream_mode(
      options->running_mode != core::RunningMode::IMAGE);

  bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE;

  // TODO remove these workarounds for base options of subgraphs.
  // Configure hand detector options.
  auto base_options_proto_for_hand_detector =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(options->base_options_for_hand_detector)));
  base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode);
  auto* hand_detector_graph_options =
      options_proto->mutable_hand_landmarker_graph_options()
          ->mutable_hand_detector_graph_options();
  hand_detector_graph_options->mutable_base_options()->Swap(
      base_options_proto_for_hand_detector.get());
  hand_detector_graph_options->set_num_hands(options->num_hands);
  hand_detector_graph_options->set_min_detection_confidence(
      options->min_hand_detection_confidence);

  // Configure hand landmark detector options.
  auto base_options_proto_for_hand_landmarker =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(options->base_options_for_hand_landmarker)));
  base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode);
  auto* hand_landmarks_detector_graph_options =
      options_proto->mutable_hand_landmarker_graph_options()
          ->mutable_hand_landmarks_detector_graph_options();
  hand_landmarks_detector_graph_options->mutable_base_options()->Swap(
      base_options_proto_for_hand_landmarker.get());
  hand_landmarks_detector_graph_options->set_min_detection_confidence(
      options->min_hand_presence_confidence);

  auto* hand_landmarker_graph_options =
      options_proto->mutable_hand_landmarker_graph_options();
  hand_landmarker_graph_options->set_min_tracking_confidence(
      options->min_tracking_confidence);
  auto* hand_landmarks_detector_graph_options =
      hand_landmarker_graph_options
          ->mutable_hand_landmarks_detector_graph_options();
  hand_landmarks_detector_graph_options->set_min_detection_confidence(
      options->min_hand_presence_confidence);

  // Configure hand gesture recognizer options.
  auto base_options_proto_for_gesture_recognizer =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(options->base_options_for_gesture_recognizer)));
  base_options_proto_for_gesture_recognizer->set_use_stream_mode(
      use_stream_mode);
  auto* hand_gesture_recognizer_graph_options =
      options_proto->mutable_hand_gesture_recognizer_graph_options();
  hand_gesture_recognizer_graph_options->mutable_base_options()->Swap(
      base_options_proto_for_gesture_recognizer.get());
  if (options->min_gesture_confidence >= 0) {
    hand_gesture_recognizer_graph_options->mutable_classifier_options()
    hand_gesture_recognizer_graph_options
        ->mutable_canned_gesture_classifier_graph_options()
        ->mutable_classifier_options()
        ->set_score_threshold(options->min_gesture_confidence);
  }
  return options_proto;

@ -248,15 +206,16 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
    mediapipe::Image image,
    std::optional<mediapipe::NormalizedRect> image_processing_options) {
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   FillNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(

@ -283,15 +242,16 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<mediapipe::NormalizedRect> image_processing_options) {
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   FillNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(

@ -321,15 +281,16 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(

absl::Status GestureRecognizer::RecognizeAsync(
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<mediapipe::NormalizedRect> image_processing_options) {
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   FillNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))

@ -23,10 +23,10 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"

namespace mediapipe {

@ -39,12 +39,6 @@ struct GestureRecognizerOptions {
  // model file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // TODO: remove these. Temporary solutions before bundle asset is
  // ready.
  tasks::core::BaseOptions base_options_for_hand_landmarker;
  tasks::core::BaseOptions base_options_for_hand_detector;
  tasks::core::BaseOptions base_options_for_gesture_recognizer;

  // The running mode of the task. Default to the image mode.
  // GestureRecognizer has three running modes:
  //   1) The image mode for recognizing hand gestures on single image inputs.

@ -129,36 +123,36 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  // Only use this method when the GestureRecognizer is created with the image
  // running mode.
  //
  // image - mediapipe::Image
  //   Image to perform hand gesture recognition on.
  // imageProcessingOptions - std::optional<NormalizedRect>
  //   If provided, can be used to specify the rotation to apply to the image
  //   before performing classification, by setting its 'rotation' field in
  //   radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note that
  //   specifying a region-of-interest using the 'x_center', 'y_center',
  //   'width' and 'height' fields is NOT supported and will result in an
  //   invalid argument error being returned.
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing recognition, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describes how the input image will be preprocessed
  // after the yuv support is implemented.
  // TODO: use an ImageProcessingOptions struct instead of
  // NormalizedRect.
  absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
      Image image,
      std::optional<mediapipe::NormalizedRect> image_processing_options =
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Performs gesture recognition on the provided video frame.
  // Only use this method when the GestureRecognizer is created with the video
  // running mode.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing recognition, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide the video frame's timestamp (in milliseconds). The input
  // timestamps must be monotonically increasing.
  absl::StatusOr<components::containers::GestureRecognitionResult>
  RecognizeForVideo(Image image, int64 timestamp_ms,
                    std::optional<mediapipe::NormalizedRect>
                    std::optional<core::ImageProcessingOptions>
                        image_processing_options = std::nullopt);

  // Sends live image data to perform gesture recognition, and the results will

@ -171,6 +165,12 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  // sent to the gesture recognizer. The input timestamps must be monotonically
  // increasing.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing recognition, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // The "result_callback" provides
  //   - A vector of GestureRecognitionResult, each of which is the recognized
  //     result for an input frame.

@ -180,7 +180,7 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  //     outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
  absl::Status RecognizeAsync(Image image, int64 timestamp_ms,
                              std::optional<mediapipe::NormalizedRect>
                              std::optional<core::ImageProcessingOptions>
                                  image_processing_options = std::nullopt);

  // Shuts down the GestureRecognizer when all the work is done.
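
// A minimal usage sketch of the new ImageProcessingOptions parameter (the
// model path and the `image` variable are illustrative, not from this change;
// rotation semantics follow the comments above):
//
//   auto options = std::make_unique<GestureRecognizerOptions>();
//   options->base_options.model_asset_path = "/path/to/gesture_recognizer.task";
//   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GestureRecognizer> recognizer,
//                           GestureRecognizer::Create(std::move(options)));
//   core::ImageProcessingOptions image_processing_options;
//   image_processing_options.rotation_degrees = -90;  // 90° anti-clockwise.
//   MP_ASSERT_OK_AND_ASSIGN(
//       auto result, recognizer->Recognize(image, image_processing_options));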

@ -25,9 +25,13 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"

@ -46,6 +50,8 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
    GestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::

@ -61,6 +67,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task";
constexpr char kHandGestureRecognizerBundleAssetName[] =
    "hand_gesture_recognizer.task";

struct GestureRecognizerOutputs {
  Source<std::vector<ClassificationList>> gesture;

@ -70,6 +79,53 @@ struct GestureRecognizerOutputs {
  Source<Image> image;
};

// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
                                   GestureRecognizerGraphOptions* options,
                                   bool is_copy) {
  ASSIGN_OR_RETURN(const auto hand_landmarker_file,
                   resources.GetModelFile(kHandLandmarkerBundleAssetName));
  auto* hand_landmarker_graph_options =
      options->mutable_hand_landmarker_graph_options();
  SetExternalFile(hand_landmarker_file,
                  hand_landmarker_graph_options->mutable_base_options()
                      ->mutable_model_asset(),
                  is_copy);
  hand_landmarker_graph_options->mutable_base_options()
      ->mutable_acceleration()
      ->CopyFrom(options->base_options().acceleration());
  hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode(
      options->base_options().use_stream_mode());

  ASSIGN_OR_RETURN(
      const auto hand_gesture_recognizer_file,
      resources.GetModelFile(kHandGestureRecognizerBundleAssetName));
  auto* hand_gesture_recognizer_graph_options =
      options->mutable_hand_gesture_recognizer_graph_options();
  SetExternalFile(hand_gesture_recognizer_file,
                  hand_gesture_recognizer_graph_options->mutable_base_options()
                      ->mutable_model_asset(),
                  is_copy);
  hand_gesture_recognizer_graph_options->mutable_base_options()
      ->mutable_acceleration()
      ->CopyFrom(options->base_options().acceleration());
  if (!hand_gesture_recognizer_graph_options->base_options()
           .acceleration()
           .has_xnnpack() &&
      !hand_gesture_recognizer_graph_options->base_options()
           .acceleration()
           .has_tflite()) {
    hand_gesture_recognizer_graph_options->mutable_base_options()
        ->mutable_acceleration()
        ->mutable_xnnpack();
    LOG(WARNING) << "Hand Gesture Recognizer contains CPU-only ops. Setting "
                 << "HandGestureRecognizerGraph acceleration to XNNPACK.";
  }
  hand_gesture_recognizer_graph_options->mutable_base_options()
      ->set_use_stream_mode(options->base_options().use_stream_mode());
  return absl::OkStatus();
}

}  // namespace

// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs

@ -136,6 +192,21 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    if (sc->Options<GestureRecognizerGraphOptions>()
            .base_options()
            .has_model_asset()) {
      ASSIGN_OR_RETURN(
          const auto* model_asset_bundle_resources,
          CreateModelAssetBundleResources<GestureRecognizerGraphOptions>(sc));
      // When the model resources cache service is available, fill in the
      // file pointer meta in the subtasks' base options. Otherwise, provide
      // the file contents instead.
      MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
          *model_asset_bundle_resources,
          sc->MutableOptions<GestureRecognizerGraphOptions>(),
          !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
               .IsAvailable()));
    }
    ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
                     BuildGestureRecognizerGraph(
                         *sc->MutableOptions<GestureRecognizerGraphOptions>(),

@ -30,11 +30,17 @@ limitations under the License.
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"

@ -51,6 +57,8 @@ using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::processors::
    ConfigureTensorsToClassificationCalculator;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
    HandGestureRecognizerGraphOptions;

@ -70,6 +78,14 @@ constexpr char kVectorTag[] = "VECTOR";
constexpr char kIndexTag[] = "INDEX";
constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite";
constexpr char kCannedGestureClassifierTFLiteName[] =
    "canned_gesture_classifier.tflite";

struct SubTaskModelResources {
  const core::ModelResources* gesture_embedder_model_resource;
  const core::ModelResources* canned_gesture_classifier_model_resource;
};

Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
                                                  Graph& graph) {

@ -78,6 +94,41 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
  return node[Output<std::vector<Tensor>>{"TENSORS"}];
}

// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
                                   HandGestureRecognizerGraphOptions* options,
                                   bool is_copy) {
  ASSIGN_OR_RETURN(const auto gesture_embedder_file,
                   resources.GetModelFile(kGestureEmbedderTFLiteName));
  auto* gesture_embedder_graph_options =
      options->mutable_gesture_embedder_graph_options();
  SetExternalFile(gesture_embedder_file,
                  gesture_embedder_graph_options->mutable_base_options()
                      ->mutable_model_asset(),
                  is_copy);
  gesture_embedder_graph_options->mutable_base_options()
      ->mutable_acceleration()
      ->CopyFrom(options->base_options().acceleration());
  gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode(
      options->base_options().use_stream_mode());

  ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file,
                   resources.GetModelFile(kCannedGestureClassifierTFLiteName));
  auto* canned_gesture_classifier_graph_options =
      options->mutable_canned_gesture_classifier_graph_options();
  SetExternalFile(
      canned_gesture_classifier_file,
      canned_gesture_classifier_graph_options->mutable_base_options()
          ->mutable_model_asset(),
      is_copy);
  canned_gesture_classifier_graph_options->mutable_base_options()
      ->mutable_acceleration()
      ->CopyFrom(options->base_options().acceleration());
  canned_gesture_classifier_graph_options->mutable_base_options()
      ->set_use_stream_mode(options->base_options().use_stream_mode());
  return absl::OkStatus();
}

}  // namespace
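
// Taken together with the GestureRecognizerGraph change above, the asset
// names imply a nested bundle layout along these lines (a sketch assembled
// from the constants in this diff, not an authoritative spec):
//
//   <gesture recognizer bundle>.task
//   ├── hand_landmarker.task
//   └── hand_gesture_recognizer.task
//       ├── gesture_embedder.tflite
//       └── canned_gesture_classifier.tflite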

// A

@ -128,14 +179,29 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    if (sc->Options<HandGestureRecognizerGraphOptions>()
            .base_options()
            .has_model_asset()) {
      ASSIGN_OR_RETURN(
          const auto* model_resources,
          CreateModelResources<HandGestureRecognizerGraphOptions>(sc));
          const auto* model_asset_bundle_resources,
          CreateModelAssetBundleResources<HandGestureRecognizerGraphOptions>(
              sc));
      // When the model resources cache service is available, fill in the
      // file pointer meta in the subtasks' base options. Otherwise, provide
      // the file contents instead.
      MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
          *model_asset_bundle_resources,
          sc->MutableOptions<HandGestureRecognizerGraphOptions>(),
          !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
               .IsAvailable()));
    }
    ASSIGN_OR_RETURN(const auto sub_task_model_resources,
                     CreateSubTaskModelResources(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        auto hand_gestures,
    ASSIGN_OR_RETURN(auto hand_gestures,
                     BuildGestureRecognizerGraph(
        sc->Options<HandGestureRecognizerGraphOptions>(), *model_resources,
                         sc->Options<HandGestureRecognizerGraphOptions>(),
                         sub_task_model_resources,
                         graph[Input<ClassificationList>(kHandednessTag)],
                         graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
                         graph[Input<LandmarkList>(kWorldLandmarksTag)],

@ -146,9 +212,37 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
  }

 private:
  absl::StatusOr<SubTaskModelResources> CreateSubTaskModelResources(
      SubgraphContext* sc) {
    auto* options = sc->MutableOptions<HandGestureRecognizerGraphOptions>();
    SubTaskModelResources sub_task_model_resources;
    auto& gesture_embedder_model_asset =
        *options->mutable_gesture_embedder_graph_options()
             ->mutable_base_options()
             ->mutable_model_asset();
    ASSIGN_OR_RETURN(
        sub_task_model_resources.gesture_embedder_model_resource,
        CreateModelResources(sc,
                             std::make_unique<core::proto::ExternalFile>(
                                 std::move(gesture_embedder_model_asset)),
                             "_gesture_embedder"));
    auto& canned_gesture_classifier_model_asset =
        *options->mutable_canned_gesture_classifier_graph_options()
             ->mutable_base_options()
             ->mutable_model_asset();
    ASSIGN_OR_RETURN(
        sub_task_model_resources.canned_gesture_classifier_model_resource,
        CreateModelResources(
            sc,
            std::make_unique<core::proto::ExternalFile>(
                std::move(canned_gesture_classifier_model_asset)),
            "_canned_gesture_classifier"));
    return sub_task_model_resources;
  }

  absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
      const HandGestureRecognizerGraphOptions& graph_options,
      const core::ModelResources& model_resources,
      const SubTaskModelResources& sub_task_model_resources,
      Source<ClassificationList> handedness,
      Source<NormalizedLandmarkList> hand_landmarks,
      Source<LandmarkList> hand_world_landmarks,

@ -209,17 +303,33 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
    auto concatenated_tensors = concatenate_tensor_vector.Out("");

    // Inference for static hand gesture recognition.
    // TODO add embedding step.
    auto& inference = AddInference(
        model_resources, graph_options.base_options().acceleration(), graph);
    concatenated_tensors >> inference.In(kTensorsTag);
    auto inference_output_tensors = inference.Out(kTensorsTag);
    auto& gesture_embedder_inference =
        AddInference(*sub_task_model_resources.gesture_embedder_model_resource,
                     graph_options.gesture_embedder_graph_options()
                         .base_options()
                         .acceleration(),
                     graph);
    concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag);
    auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag);

    auto& canned_gesture_classifier_inference = AddInference(
        *sub_task_model_resources.canned_gesture_classifier_model_resource,
        graph_options.canned_gesture_classifier_graph_options()
            .base_options()
            .acceleration(),
        graph);
    embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag);
    auto inference_output_tensors =
        canned_gesture_classifier_inference.Out(kTensorsTag);

    auto& tensors_to_classification =
        graph.AddNode("TensorsToClassificationCalculator");
    MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
        graph_options.classifier_options(),
        *model_resources.GetMetadataExtractor(), 0,
        graph_options.canned_gesture_classifier_graph_options()
            .classifier_options(),
        *sub_task_model_resources.canned_gesture_classifier_model_resource
             ->GetMetadataExtractor(),
        0,
        &tensors_to_classification.GetOptions<
            mediapipe::TensorsToClassificationCalculatorOptions>()));
    inference_output_tensors >> tensors_to_classification.In(kTensorsTag);

@ -49,7 +49,6 @@ mediapipe_proto_library(
        ":gesture_embedder_graph_options_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)

@ -18,7 +18,6 @@ syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";

@ -37,15 +36,11 @@ message HandGestureRecognizerGraphOptions {
  // Options for GestureEmbedder.
  optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2;

  // Options for GestureClassifier of default gestures.
  // Options for GestureClassifier of canned gestures.
  optional GestureClassifierGraphOptions
      canned_gesture_classifier_graph_options = 3;

  // Options for GestureClassifier of custom gestures.
  optional GestureClassifierGraphOptions
      custom_gesture_classifier_graph_options = 4;

  // TODO: remove these. Temporary solutions before bundle asset is
  // ready.
  optional components.processors.proto.ClassifierOptions classifier_options = 5;
}
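
// A minimal textproto sketch of how the two classifier slots can sit side by
// side (the threshold values are illustrative only; score_threshold is the
// ClassifierOptions field set elsewhere in this diff):
//
//   canned_gesture_classifier_graph_options {
//     classifier_options { score_threshold: 0.5 }
//   }
//   custom_gesture_classifier_graph_options {
//     classifier_options { score_threshold: 0.7 }
//   }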

@ -235,8 +235,10 @@ class HandDetectorGraph : public core::ModelTaskGraph {
    image_to_tensor_options.set_keep_aspect_ratio(true);
    image_to_tensor_options.set_border_mode(
        mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO);
    bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
        subgraph_options.base_options().acceleration());
    MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
        model_resources,
        model_resources, use_gpu,
        &preprocessing
             .GetOptions<tasks::components::ImagePreprocessingOptions>()));
    image_in >> preprocessing.In("IMAGE");

@ -92,18 +92,30 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
                                   bool is_copy) {
  ASSIGN_OR_RETURN(const auto hand_detector_file,
                   resources.GetModelFile(kHandDetectorTFLiteName));
  auto* hand_detector_graph_options =
      options->mutable_hand_detector_graph_options();
  SetExternalFile(hand_detector_file,
                  options->mutable_hand_detector_graph_options()
                      ->mutable_base_options()
                  hand_detector_graph_options->mutable_base_options()
                      ->mutable_model_asset(),
                  is_copy);
  hand_detector_graph_options->mutable_base_options()
      ->mutable_acceleration()
      ->CopyFrom(options->base_options().acceleration());
  hand_detector_graph_options->mutable_base_options()->set_use_stream_mode(
      options->base_options().use_stream_mode());
  ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
                   resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
  auto* hand_landmarks_detector_graph_options =
      options->mutable_hand_landmarks_detector_graph_options();
  SetExternalFile(hand_landmarks_detector_file,
                  options->mutable_hand_landmarks_detector_graph_options()
                      ->mutable_base_options()
                  hand_landmarks_detector_graph_options->mutable_base_options()
                      ->mutable_model_asset(),
                  is_copy);
  hand_landmarks_detector_graph_options->mutable_base_options()
      ->mutable_acceleration()
      ->CopyFrom(options->base_options().acceleration());
  hand_landmarks_detector_graph_options->mutable_base_options()
      ->set_use_stream_mode(options->base_options().use_stream_mode());
  return absl::OkStatus();
}

@ -67,7 +67,7 @@ using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task";
constexpr char kHandLandmarkerModelBundle[] = "hand_landmarker.task";
constexpr char kLeftHandsImage[] = "left_hands.jpg";
constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg";

@ -283,8 +283,10 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {

    auto& preprocessing =
        graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
    bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
        subgraph_options.base_options().acceleration());
    MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
        model_resources,
        model_resources, use_gpu,
        &preprocessing
             .GetOptions<tasks::components::ImagePreprocessingOptions>()));
    image_in >> preprocessing.In("IMAGE");

@ -59,6 +59,7 @@ cc_library(
        "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",

@ -34,6 +34,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"

@ -59,26 +60,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap;

// Returns a NormalizedRect covering the full image if input is not present.
// Otherwise, makes sure the x_center, y_center, width and height are set in
// case only a rotation was provided in the input.
NormalizedRect FillNormalizedRect(
    std::optional<NormalizedRect> normalized_rect) {
  NormalizedRect result;
  if (normalized_rect.has_value()) {
    result = *normalized_rect;
  }
  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
                         result.has_width() || result.has_height();
  if (!has_coordinates) {
    result.set_x_center(0.5);
    result.set_y_center(0.5);
    result.set_width(1);
    result.set_height(1);
  }
  return result;
}

// Creates a MediaPipe graph config that contains a subgraph node of
// type "ImageClassifierGraph". If the task is running in the live stream mode,
// a "FlowLimiterCalculator" will be added to limit the number of frames in

@ -164,14 +145,16 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
}

absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
    Image image, std::optional<NormalizedRect> image_processing_options) {
    Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(

@ -183,14 +166,15 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(

absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
    Image image, int64 timestamp_ms,
    std::optional<NormalizedRect> image_processing_options) {
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(

@ -206,14 +190,15 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(

absl::Status ImageClassifier::ClassifyAsync(
    Image image, int64 timestamp_ms,
    std::optional<NormalizedRect> image_processing_options) {
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  NormalizedRect norm_rect = FillNormalizedRect(image_processing_options);
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))

@ -22,11 +22,11 @@ limitations under the License.

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"

namespace mediapipe {

@ -109,12 +109,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  //
  // The optional 'image_processing_options' parameter can be used to specify:
  //   - the rotation to apply to the image before performing classification,
  //     by setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
  //     anti-clockwise rotation).
  //     by setting its 'rotation_degrees' field.
  //   and/or
  //   - the region-of-interest on which to perform classification, by setting
  //     its 'x_center', 'y_center', 'width' and 'height' fields. If none of
  //     these is set, they will automatically be set to cover the full image.
  //     its 'region_of_interest' field. If not specified, the full image is
  //     used.
  //   If both are specified, the crop around the region-of-interest is
  //   extracted first, then the specified rotation is applied to the crop.
  //

@ -126,19 +124,17 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  // YUVToImageCalculator is integrated.
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
      mediapipe::Image image,
      std::optional<mediapipe::NormalizedRect> image_processing_options =
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Performs image classification on the provided video frame.
  //
  // The optional 'image_processing_options' parameter can be used to specify:
  //   - the rotation to apply to the image before performing classification,
  //     by setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
  //     anti-clockwise rotation).
  //     by setting its 'rotation_degrees' field.
  //   and/or
  //   - the region-of-interest on which to perform classification, by setting
  //     its 'x_center', 'y_center', 'width' and 'height' fields. If none of
  //     these is set, they will automatically be set to cover the full image.
  //     its 'region_of_interest' field. If not specified, the full image is
  //     used.
  //   If both are specified, the crop around the region-of-interest is
  //   extracted first, then the specified rotation is applied to the crop.
  //

@ -150,7 +146,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  // must be monotonically increasing.
  absl::StatusOr<components::containers::proto::ClassificationResult>
  ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
                   std::optional<mediapipe::NormalizedRect>
                   std::optional<core::ImageProcessingOptions>
                       image_processing_options = std::nullopt);

  // Sends live image data to image classification, and the results will be

@ -158,12 +154,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  //
  // The optional 'image_processing_options' parameter can be used to specify:
  //   - the rotation to apply to the image before performing classification,
  //     by setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
  //     anti-clockwise rotation).
  //     by setting its 'rotation_degrees' field.
  //   and/or
  //   - the region-of-interest on which to perform classification, by setting
  //     its 'x_center', 'y_center', 'width' and 'height' fields. If none of
  //     these is set, they will automatically be set to cover the full image.
  //     its 'region_of_interest' field. If not specified, the full image is
  //     used.
  //   If both are specified, the crop around the region-of-interest is
  //   extracted first, then the specified rotation is applied to the crop.
  //

@ -175,7 +169,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  // sent to the object detector. The input timestamps must be monotonically
  // increasing.
  //
  // The "result_callback" prvoides
  // The "result_callback" provides:
  //   - The classification results as a ClassificationResult object.
  //   - The const reference to the corresponding input image that the image
  //     classifier runs on. Note that the const reference to the image will no

@ -183,12 +177,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  //     outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
  absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms,
                             std::optional<mediapipe::NormalizedRect>
                             std::optional<core::ImageProcessingOptions>
                                 image_processing_options = std::nullopt);

  // TODO: add Classify() variants taking a region of interest as
  // additional argument.

  // Shuts down the ImageClassifier when all the work is done.
  absl::Status Close() { return runner_->Close(); }
};
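
// A minimal usage sketch of the API shape above (the model path is
// illustrative; Rect is components::containers::Rect, and the option fields
// used here all appear in the image_classifier tests later in this diff):
//
//   auto options = std::make_unique<ImageClassifierOptions>();
//   options->base_options.model_asset_path = "/path/to/model.tflite";
//   options->classifier_options.max_results = 3;
//   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> classifier,
//                           ImageClassifier::Create(std::move(options)));
//   Rect roi{/*left=*/0.1, /*top=*/0.1, /*right=*/0.9, /*bottom=*/0.9};
//   core::ImageProcessingOptions image_processing_options{
//       roi, /*rotation_degrees=*/-90};  // ROI cropped first, then rotated.
//   MP_ASSERT_OK_AND_ASSIGN(
//       auto result, classifier->Classify(image, image_processing_options));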
|
||||
|
|
|
@ -138,8 +138,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
// stream.
|
||||
auto& preprocessing =
|
||||
graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
|
||||
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
|
||||
task_options.base_options().acceleration());
|
||||
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
|
||||
model_resources,
|
||||
model_resources, use_gpu,
|
||||
&preprocessing
|
||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||
image_in >> preprocessing.In(kImageTag);
|
||||
|
|
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
|
@ -35,6 +34,8 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -49,9 +50,11 @@ namespace image_classifier {
|
|||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::Rect;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
|
@ -547,12 +550,9 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
|||
options->classifier_options.max_results = 1;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// Crop around the soccer ball.
|
||||
NormalizedRect image_processing_options;
|
||||
image_processing_options.set_x_center(0.532);
|
||||
image_processing_options.set_y_center(0.521);
|
||||
image_processing_options.set_width(0.164);
|
||||
image_processing_options.set_height(0.427);
|
||||
// Region-of-interest around the soccer ball.
|
||||
Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
|
||||
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||
image, image_processing_options));
|
||||
|
@ -572,8 +572,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
|||
ImageClassifier::Create(std::move(options)));
|
||||
|
||||
// Specify a 90° anti-clockwise rotation.
|
||||
NormalizedRect image_processing_options;
|
||||
image_processing_options.set_rotation(M_PI / 2.0);
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = -90;
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||
image, image_processing_options));
|
||||
|
@ -616,13 +616,10 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
|||
options->classifier_options.max_results = 1;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||
ImageClassifier::Create(std::move(options)));
|
||||
// Crop around the chair, with 90° anti-clockwise rotation.
|
||||
NormalizedRect image_processing_options;
|
||||
image_processing_options.set_x_center(0.2821);
|
||||
image_processing_options.set_y_center(0.2406);
|
||||
image_processing_options.set_width(0.5642);
|
||||
image_processing_options.set_height(0.1286);
|
||||
image_processing_options.set_rotation(M_PI / 2.0);
|
||||
// Region-of-interest around the chair, with 90° anti-clockwise rotation.
|
||||
Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049};
|
||||
ImageProcessingOptions image_processing_options{roi,
|
||||
/*rotation_degrees=*/-90};
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||
image, image_processing_options));
|
||||
|
@@ -633,7 +630,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
                 entries {
                   categories {
                     index: 560
-                    score: 0.6800408
+                    score: 0.6522213
                     category_name: "folding chair"
                   }
                   timestamp_ms: 0
@@ -643,6 +640,69 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
                 })pb"));
 }

+// Testing all these once with ImageClassifier.
+TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) {
+  MP_ASSERT_OK_AND_ASSIGN(Image image,
+                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                       "multi_objects.jpg")));
+  auto options = std::make_unique<ImageClassifierOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
+                          ImageClassifier::Create(std::move(options)));
+
+  // Invalid: left > right.
+  Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1};
+  ImageProcessingOptions image_processing_options{roi,
+                                                  /*rotation_degrees=*/0};
+  auto results = image_classifier->Classify(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("Expected Rect with left < right and top < bottom"));
+  EXPECT_THAT(
+      results.status().GetPayload(kMediaPipeTasksPayload),
+      Optional(absl::Cord(absl::StrCat(
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
+
+  // Invalid: top > bottom.
+  roi = {/*left=*/0, /*top=*/0.9, /*right=*/1, /*bottom=*/0.1};
+  image_processing_options = {roi,
+                              /*rotation_degrees=*/0};
+  results = image_classifier->Classify(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("Expected Rect with left < right and top < bottom"));
+  EXPECT_THAT(
+      results.status().GetPayload(kMediaPipeTasksPayload),
+      Optional(absl::Cord(absl::StrCat(
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
+
+  // Invalid: coordinates out of [0,1] range.
+  roi = {/*left=*/-0.1, /*top=*/0, /*right=*/1, /*bottom=*/1};
+  image_processing_options = {roi,
+                              /*rotation_degrees=*/0};
+  results = image_classifier->Classify(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("Expected Rect values to be in [0,1]"));
+  EXPECT_THAT(
+      results.status().GetPayload(kMediaPipeTasksPayload),
+      Optional(absl::Cord(absl::StrCat(
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
+
+  // Invalid: rotation not a multiple of 90°.
+  image_processing_options = {/*region_of_interest=*/std::nullopt,
+                              /*rotation_degrees=*/1};
+  results = image_classifier->Classify(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("Expected rotation to be a multiple of 90°"));
+  EXPECT_THAT(
+      results.status().GetPayload(kMediaPipeTasksPayload),
+      Optional(absl::Cord(absl::StrCat(
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
+}
+
 class VideoModeTest : public tflite_shims::testing::Test {};

 TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@@ -732,11 +792,9 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
-  // Crop around the soccer ball.
-  NormalizedRect image_processing_options;
-  image_processing_options.set_x_center(0.532);
-  image_processing_options.set_y_center(0.521);
-  image_processing_options.set_width(0.164);
-  image_processing_options.set_height(0.427);
+  // Region-of-interest around the soccer ball.
+  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

   for (int i = 0; i < iterations; ++i) {
     MP_ASSERT_OK_AND_ASSIGN(
@@ -877,11 +935,8 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
-  // Crop around the soccer ball.
-  NormalizedRect image_processing_options;
-  image_processing_options.set_x_center(0.532);
-  image_processing_options.set_y_center(0.521);
-  image_processing_options.set_width(0.164);
-  image_processing_options.set_height(0.427);
+  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
+  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

   for (int i = 0; i < iterations; ++i) {
     MP_ASSERT_OK(
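For readers tracking the API change in the tests above: the region-of-interest now travels through the task-level ImageProcessingOptions struct instead of a raw NormalizedRect. A minimal usage sketch against the new classifier API, assuming an image and options already populated as in the tests (only types introduced in this diff are used; the helper name is ours):

// Sketch only: mirrors the updated calling convention shown in the tests.
absl::Status ClassifySoccerBall(mediapipe::Image image,
                                std::unique_ptr<ImageClassifierOptions> options) {
  using ::mediapipe::tasks::components::containers::Rect;
  using ::mediapipe::tasks::vision::core::ImageProcessingOptions;

  auto classifier_or = ImageClassifier::Create(std::move(options));
  if (!classifier_or.ok()) return classifier_or.status();
  std::unique_ptr<ImageClassifier> classifier = std::move(*classifier_or);

  // Region-of-interest in normalized [0,1] coordinates; no rotation.
  Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345};
  ImageProcessingOptions processing_options{roi, /*rotation_degrees=*/0};

  auto result = classifier->Classify(image, processing_options);
  if (!result.ok()) return result.status();
  // ... consume the classification result ...
  return classifier->Close();
}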
@@ -58,6 +58,7 @@ cc_library(
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
         "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/cc/vision/core:image_processing_options",
         "//mediapipe/tasks/cc/vision/core:running_mode",
         "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
         "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
@@ -29,6 +29,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
 #include "mediapipe/tasks/cc/core/task_runner.h"
 #include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
 #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
 #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap;
 using ::mediapipe::tasks::vision::image_embedder::proto::
     ImageEmbedderGraphOptions;

-// Builds a NormalizedRect covering the entire image.
-NormalizedRect BuildFullImageNormRect() {
-  NormalizedRect norm_rect;
-  norm_rect.set_x_center(0.5);
-  norm_rect.set_y_center(0.5);
-  norm_rect.set_width(1);
-  norm_rect.set_height(1);
-  return norm_rect;
-}
-
 // Creates a MediaPipe graph config that contains a single node of type
 // "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is
 // running in the live stream mode, a "FlowLimiterCalculator" will be added to
@@ -148,15 +139,16 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
 }

 absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
-    Image image, std::optional<NormalizedRect> roi) {
+    Image image,
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         "GPU input images are currently not supported.",
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  NormalizedRect norm_rect =
-      roi.has_value() ? roi.value() : BuildFullImageNormRect();
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   ConvertToNormalizedRect(image_processing_options));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessImageData(
@@ -167,15 +159,16 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
 }

 absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
-    Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) {
+    Image image, int64 timestamp_ms,
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         "GPU input images are currently not supported.",
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  NormalizedRect norm_rect =
-      roi.has_value() ? roi.value() : BuildFullImageNormRect();
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   ConvertToNormalizedRect(image_processing_options));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessVideoData(
@@ -188,16 +181,17 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
   return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
 }

-absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms,
-                                       std::optional<NormalizedRect> roi) {
+absl::Status ImageEmbedder::EmbedAsync(
+    Image image, int64 timestamp_ms,
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         "GPU input images are currently not supported.",
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  NormalizedRect norm_rect =
-      roi.has_value() ? roi.value() : BuildFullImageNormRect();
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   ConvertToNormalizedRect(image_processing_options));
   return SendLiveStreamData(
       {{kImageInStreamName,
         MakePacket<Image>(std::move(image))
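The removed per-task BuildFullImageNormRect() helper is superseded by a shared ConvertToNormalizedRect(), which is defined elsewhere in the tasks core and not shown in this diff. Judging from the before/after pairs in the tests (set_rotation(M_PI / 2.0) becomes rotation_degrees = -90, and a full-frame rect remains the default), the conversion presumably looks roughly like the following sketch; the function name is ours and the real helper also validates its inputs:

#include <cmath>
#include <optional>

// Hypothetical re-implementation for illustration only.
mediapipe::NormalizedRect SketchConvertToNormalizedRect(
    std::optional<mediapipe::tasks::vision::core::ImageProcessingOptions>
        options) {
  mediapipe::NormalizedRect norm_rect;
  // Default: a rect covering the whole image (what BuildFullImageNormRect did).
  norm_rect.set_x_center(0.5);
  norm_rect.set_y_center(0.5);
  norm_rect.set_width(1);
  norm_rect.set_height(1);
  if (!options.has_value()) return norm_rect;
  // Inferred sign convention: the tests replaced set_rotation(M_PI / 2.0)
  // with rotation_degrees = -90, i.e. radians = -degrees * pi / 180.
  norm_rect.set_rotation(-options->rotation_degrees * M_PI / 180.0);
  if (options->region_of_interest.has_value()) {
    const auto& roi = *options->region_of_interest;
    norm_rect.set_x_center((roi.left + roi.right) / 2.0);
    norm_rect.set_y_center((roi.top + roi.bottom) / 2.0);
    norm_rect.set_width(roi.right - roi.left);
    norm_rect.set_height(roi.bottom - roi.top);
  }
  return norm_rect;
}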
@@ -21,11 +21,11 @@ limitations under the License.

 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/image.h"
-#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
 #include "mediapipe/tasks/cc/components/embedder_options.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"

 namespace mediapipe {
@@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi {
   static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create(
       std::unique_ptr<ImageEmbedderOptions> options);

-  // Performs embedding extraction on the provided single image. Extraction
-  // is performed on the region of interest specified by the `roi` argument if
-  // provided, or on the entire image otherwise.
+  // Performs embedding extraction on the provided single image.
+  //
+  // The optional 'image_processing_options' parameter can be used to specify:
+  //   - the rotation to apply to the image before performing embedding
+  //     extraction, by setting its 'rotation_degrees' field.
+  //   and/or
+  //   - the region-of-interest on which to perform embedding extraction, by
+  //     setting its 'region_of_interest' field. If not specified, the full image
+  //     is used.
+  // If both are specified, the crop around the region-of-interest is extracted
+  // first, then the specified rotation is applied to the crop.
+  //
   // Only use this method when the ImageEmbedder is created with the image
   // running mode.
@@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi {
   // The image can be of any size with format RGB or RGBA.
   absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
       mediapipe::Image image,
-      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+      std::optional<core::ImageProcessingOptions> image_processing_options =
+          std::nullopt);

-  // Performs embedding extraction on the provided video frame. Extraction
-  // is performed on the region of interested specified by the `roi` argument if
-  // provided, or on the entire image otherwise.
+  // Performs embedding extraction on the provided video frame.
+  //
+  // The optional 'image_processing_options' parameter can be used to specify:
+  //   - the rotation to apply to the image before performing embedding
+  //     extraction, by setting its 'rotation_degrees' field.
+  //   and/or
+  //   - the region-of-interest on which to perform embedding extraction, by
+  //     setting its 'region_of_interest' field. If not specified, the full image
+  //     is used.
+  // If both are specified, the crop around the region-of-interest is extracted
+  // first, then the specified rotation is applied to the crop.
+  //
   // Only use this method when the ImageEmbedder is created with the video
   // running mode.
|
@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
|||
// must be monotonically increasing.
|
||||
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
std::nullopt);
|
||||
|
||||
// Sends live image data to embedder, and the results will be available via
|
||||
// the "result_callback" provided in the ImageEmbedderOptions. Embedding
|
||||
// extraction is performed on the region of interested specified by the `roi`
|
||||
// argument if provided, or on the entire image otherwise.
|
||||
// the "result_callback" provided in the ImageEmbedderOptions.
|
||||
//
|
||||
// The optional 'image_processing_options' parameter can be used to specify:
|
||||
// - the rotation to apply to the image before performing embedding
|
||||
// extraction, by setting its 'rotation_degrees' field.
|
||||
// and/or
|
||||
// - the region-of-interest on which to perform embedding extraction, by
|
||||
// setting its 'region_of_interest' field. If not specified, the full image
|
||||
// is used.
|
||||
// If both are specified, the crop around the region-of-interest is extracted
|
||||
// first, then the specified rotation is applied to the crop.
|
||||
//
|
||||
// Only use this method when the ImageEmbedder is created with the live
|
||||
// stream running mode.
|
||||
|
@@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi {
   //     longer be valid when the callback returns. To access the image data
   //     outside of the callback, callers need to make a copy of the image.
   //   - The input timestamp in milliseconds.
-  absl::Status EmbedAsync(
-      mediapipe::Image image, int64 timestamp_ms,
-      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
+  absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms,
+                          std::optional<core::ImageProcessingOptions>
+                              image_processing_options = std::nullopt);

   // Shuts down the ImageEmbedder when all works are done.
   absl::Status Close() { return runner_->Close(); }
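A short usage sketch of the updated Embed() signature documented above, combining a region-of-interest with a rotation (the crop is extracted first, then rotated, per the comment). The model path is a placeholder and the error handling is elided:

// Sketch under the assumptions stated above; not the library's own example.
auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path = "/path/to/embedder.tflite";  // placeholder
auto embedder = ImageEmbedder::Create(std::move(options)).value();

components::containers::Rect roi{/*left=*/0.0, /*top=*/0.0,
                                 /*right=*/0.5, /*bottom=*/0.5};
core::ImageProcessingOptions processing_options{roi, /*rotation_degrees=*/-90};
// `image` is an already-decoded mediapipe::Image.
auto result = embedder->Embed(image, processing_options);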
@@ -134,8 +134,10 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
     // stream.
     auto& preprocessing =
         graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
+    bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
+        task_options.base_options().acceleration());
     MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
-        model_resources,
+        model_resources, use_gpu,
         &preprocessing
             .GetOptions<tasks::components::ImagePreprocessingOptions>()));
     image_in >> preprocessing.In(kImageTag);
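The same two-line change, computing use_gpu from the task's acceleration config and forwarding it to ConfigureImagePreprocessing(), recurs in the segmenter and object-detector graphs later in this commit. The pattern, isolated for reference (taken verbatim from the hunks, with a comment added):

// Pick the preprocessing backend from the task's acceleration settings
// instead of defaulting to CPU preprocessing.
bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
    task_options.base_options().acceleration());
MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
    model_resources, use_gpu,
    &preprocessing.GetOptions<tasks::components::ImagePreprocessingOptions>()));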
@@ -23,7 +23,6 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/image.h"
-#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/status_matchers.h"
@@ -42,7 +41,9 @@ namespace image_embedder {
 namespace {

 using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::containers::Rect;
 using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
+using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;

@@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(
       Image crop, DecodeImageFromFile(
                       JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
-  // Bounding box in "burger.jpg" corresponding to "burger_crop.jpg".
-  NormalizedRect roi;
-  roi.set_x_center(200.0 / 480);
-  roi.set_y_center(0.5);
-  roi.set_width(400.0 / 480);
-  roi.set_height(1.0f);
+  // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
+  Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
+  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

   // Extract both embeddings.
-  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
-                          image_embedder->Embed(image, roi));
+  MP_ASSERT_OK_AND_ASSIGN(
+      const EmbeddingResult& image_result,
+      image_embedder->Embed(image, image_processing_options));
   MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                           image_embedder->Embed(crop));

@@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
 }

+TEST_F(ImageModeTest, SucceedsWithRotation) {
+  auto options = std::make_unique<ImageEmbedderOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
+                          ImageEmbedder::Create(std::move(options)));
+  // Load images: one is a rotated version of the other.
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image,
+      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
+  MP_ASSERT_OK_AND_ASSIGN(Image rotated,
+                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                       "burger_rotated.jpg")));
+  ImageProcessingOptions image_processing_options;
+  image_processing_options.rotation_degrees = -90;
+
+  // Extract both embeddings.
+  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
+                          image_embedder->Embed(image));
+  MP_ASSERT_OK_AND_ASSIGN(
+      const EmbeddingResult& rotated_result,
+      image_embedder->Embed(rotated, image_processing_options));
+
+  // Check results.
+  CheckMobileNetV3Result(image_result, false);
+  CheckMobileNetV3Result(rotated_result, false);
+  // CheckCosineSimilarity.
+  MP_ASSERT_OK_AND_ASSIGN(
+      double similarity,
+      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
+                                      rotated_result.embeddings(0).entries(0)));
+  double expected_similarity = 0.572265;
+  EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
+}
+
+TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
+  auto options = std::make_unique<ImageEmbedderOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
+                          ImageEmbedder::Create(std::move(options)));
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image crop, DecodeImageFromFile(
+                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
+  MP_ASSERT_OK_AND_ASSIGN(Image rotated,
+                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
+                                                       "burger_rotated.jpg")));
+  // Region-of-interest corresponding to burger_crop.jpg.
+  Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
+  ImageProcessingOptions image_processing_options{roi,
+                                                  /*rotation_degrees=*/-90};
+
+  // Extract both embeddings.
+  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
+                          image_embedder->Embed(crop));
+  MP_ASSERT_OK_AND_ASSIGN(
+      const EmbeddingResult& rotated_result,
+      image_embedder->Embed(rotated, image_processing_options));
+
+  // Check results.
+  CheckMobileNetV3Result(crop_result, false);
+  CheckMobileNetV3Result(rotated_result, false);
+  // CheckCosineSimilarity.
+  MP_ASSERT_OK_AND_ASSIGN(
+      double similarity,
+      ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
+                                      rotated_result.embeddings(0).entries(0)));
+  double expected_similarity = 0.62838;
+  EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
+}
+
 class VideoModeTest : public tflite_shims::testing::Test {};

 TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
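These tests lean on ImageEmbedder::CosineSimilarity(), which is not part of this diff. For reference, cosine similarity over two float embeddings is the dot product divided by the product of the norms; a generic sketch follows (not the task's actual implementation, which also handles quantized embeddings and errors out on mismatched sizes):

#include <cmath>
#include <vector>

// Plain cosine similarity: dot(u, v) / (|u| * |v|). Returns 0 when either
// vector has zero norm, to keep this sketch total.
double CosineSimilaritySketch(const std::vector<float>& u,
                              const std::vector<float>& v) {
  double dot = 0, norm_u = 0, norm_v = 0;
  for (size_t i = 0; i < u.size() && i < v.size(); ++i) {
    dot += u[i] * v[i];
    norm_u += u[i] * u[i];
    norm_v += v[i] * v[i];
  }
  if (norm_u == 0 || norm_v == 0) return 0;
  return dot / (std::sqrt(norm_u) * std::sqrt(norm_v));
}

Identical images score 1.0; the rotation tests above expect much lower values (about 0.57) because MobileNet embeddings are not rotation-invariant.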
@@ -24,10 +24,12 @@ cc_library(
         ":image_segmenter_graph",
        "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
         "//mediapipe/tasks/cc/core:base_options",
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/cc/vision/core:image_processing_options",
         "//mediapipe/tasks/cc/vision/core:running_mode",
         "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
         "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
@@ -48,6 +50,7 @@ cc_library(
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/framework/port:status",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components:image_preprocessing",
@@ -17,8 +17,10 @@ limitations under the License.

 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
 #include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
 #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"

@@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
 constexpr char kImageInStreamName[] = "image_in";
 constexpr char kImageOutStreamName[] = "image_out";
 constexpr char kImageTag[] = "IMAGE";
+constexpr char kNormRectStreamName[] = "norm_rect_in";
+constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kSubgraphTypeName[] =
     "mediapipe.tasks.vision.ImageSegmenterGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;
@@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig(
   auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
   task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
   graph.In(kImageTag).SetName(kImageInStreamName);
+  graph.In(kNormRectTag).SetName(kNormRectStreamName);
   task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
       graph.Out(kGroupedSegmentationTag);
   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
       graph.Out(kImageTag);
   if (enable_flow_limiting) {
-    return tasks::core::AddFlowLimiterCalculator(
-        graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag);
+    return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
+                                                 {kImageTag, kNormRectTag},
+                                                 kGroupedSegmentationTag);
   }
   graph.In(kImageTag) >> task_subgraph.In(kImageTag);
+  graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
   return graph.GetConfig();
 }

@@ -139,47 +146,68 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
 }

 absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
-    mediapipe::Image image) {
+    mediapipe::Image image,
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   ASSIGN_OR_RETURN(
       auto output_packets,
-      ProcessImageData({{kImageInStreamName,
-                         mediapipe::MakePacket<Image>(std::move(image))}}));
+      ProcessImageData(
+          {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
+           {kNormRectStreamName,
+            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
   return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
 }

 absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
-    mediapipe::Image image, int64 timestamp_ms) {
+    mediapipe::Image image, int64 timestamp_ms,
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessVideoData(
           {{kImageInStreamName,
             MakePacket<Image>(std::move(image))
                 .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+           {kNormRectStreamName,
+            MakePacket<NormalizedRect>(std::move(norm_rect))
+                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
   return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
 }

-absl::Status ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) {
+absl::Status ImageSegmenter::SegmentAsync(
+    Image image, int64 timestamp_ms,
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   return SendLiveStreamData(
       {{kImageInStreamName,
         MakePacket<Image>(std::move(image))
             .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+       {kNormRectStreamName,
+        MakePacket<NormalizedRect>(std::move(norm_rect))
+            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
 }

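All three entry points above stamp the image packet and the new NORM_RECT packet with the same timestamp, converted from milliseconds to MediaPipe's microsecond Timestamp via kMicroSecondsPerMilliSecond. The conversion, isolated from the hunk with comments added:

// Both packets must carry identical timestamps so the graph synchronizes
// them onto the same frame; mismatched timestamps would never be paired.
const int64_t timestamp_us = timestamp_ms * kMicroSecondsPerMilliSecond;
auto image_packet = mediapipe::MakePacket<mediapipe::Image>(std::move(image))
                        .At(mediapipe::Timestamp(timestamp_us));
auto rect_packet =
    mediapipe::MakePacket<mediapipe::NormalizedRect>(std::move(norm_rect))
        .At(mediapipe::Timestamp(timestamp_us));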
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
 #include "tensorflow/lite/kernels/register.h"

@@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   // running mode.
   //
   // The image can be of any size with format RGB or RGBA.
   // TODO: Describes how the input image will be preprocessed
   // after the yuv support is implemented.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing segmentation, by
+  // setting its 'rotation_degrees' field. Note that specifying a
+  // region-of-interest using the 'region_of_interest' field is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // If the output_type is CATEGORY_MASK, the returned vector of images is
   // per-category segmented image mask.
   // If the output_type is CONFIDENCE_MASK, the returned vector of images
   // contains only one confidence image mask.
-  absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image);
+  absl::StatusOr<std::vector<mediapipe::Image>> Segment(
+      mediapipe::Image image,
+      std::optional<core::ImageProcessingOptions> image_processing_options =
+          std::nullopt);

   // Performs image segmentation on the provided video frame.
   // Only use this method when the ImageSegmenter is created with the video
@@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   // provide the video frame's timestamp (in milliseconds). The input timestamps
   // must be monotonically increasing.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing segmentation, by
+  // setting its 'rotation_degrees' field. Note that specifying a
+  // region-of-interest using the 'region_of_interest' field is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // If the output_type is CATEGORY_MASK, the returned vector of images is
   // per-category segmented image mask.
   // If the output_type is CONFIDENCE_MASK, the returned vector of images
   // contains only one confidence image mask.
   absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
-      mediapipe::Image image, int64 timestamp_ms);
+      mediapipe::Image image, int64 timestamp_ms,
+      std::optional<core::ImageProcessingOptions> image_processing_options =
+          std::nullopt);

   // Sends live image data to perform image segmentation, and the results will
   // be available via the "result_callback" provided in the
@@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   // sent to the image segmenter. The input timestamps must be monotonically
   // increasing.
   //
+  // The optional 'image_processing_options' parameter can be used to specify
+  // the rotation to apply to the image before performing segmentation, by
+  // setting its 'rotation_degrees' field. Note that specifying a
+  // region-of-interest using the 'region_of_interest' field is NOT supported
+  // and will result in an invalid argument error being returned.
+  //
   // The "result_callback" provides
   //   - A vector of segmented image masks.
   //     If the output_type is CATEGORY_MASK, the returned vector of images is
@@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
   //     no longer be valid when the callback returns. To access the image data
   //     outside of the callback, callers need to make a copy of the image.
   //   - The input timestamp in milliseconds.
-  absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms);
+  absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
+                            std::optional<core::ImageProcessingOptions>
+                                image_processing_options = std::nullopt);

   // Shuts down the ImageSegmenter when all works are done.
   absl::Status Close() { return runner_->Close(); }
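A usage sketch of the rotation-only contract documented above: the segmenter accepts rotation_degrees but, unlike the classifier and embedder, rejects a region_of_interest (see the new FailsWithRegionOfInterest test further down). The model path is a placeholder:

// Sketch under the assumptions stated above.
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = "/path/to/deeplab.tflite";  // placeholder
auto segmenter = ImageSegmenter::Create(std::move(options)).value();

core::ImageProcessingOptions processing_options;
processing_options.rotation_degrees = -90;  // must be a multiple of 90
// `image` is an already-decoded mediapipe::Image; supplying a
// region_of_interest here would return kInvalidArgument instead.
auto masks = segmenter->Segment(image, processing_options);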
@@ -23,6 +23,7 @@ limitations under the License.
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/status_macros.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
@@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
 constexpr char kSegmentationTag[] = "SEGMENTATION";
 constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
 constexpr char kImageTag[] = "IMAGE";
+constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";

@@ -159,6 +161,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
 // Inputs:
 //   IMAGE - Image
 //     Image to perform segmentation on.
+//   NORM_RECT - NormalizedRect @Optional
+//     Describes image rotation and region of image to perform segmentation
+//     on.
+//     @Optional: rect covering the whole image is used if not specified.
 //
 // Outputs:
 //   SEGMENTATION - mediapipe::Image @Multiple
@@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
     ASSIGN_OR_RETURN(const auto* model_resources,
                      CreateModelResources<ImageSegmenterOptions>(sc));
     Graph graph;
-    ASSIGN_OR_RETURN(auto output_streams,
-                     BuildSegmentationTask(
-                         sc->Options<ImageSegmenterOptions>(), *model_resources,
-                         graph[Input<Image>(kImageTag)], graph));
+    ASSIGN_OR_RETURN(
+        auto output_streams,
+        BuildSegmentationTask(
+            sc->Options<ImageSegmenterOptions>(), *model_resources,
+            graph[Input<Image>(kImageTag)],
+            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));

     auto& merge_images_to_vector =
         graph.AddNode("MergeImagesToVectorCalculator");
@@ -228,18 +236,21 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
   absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
       const ImageSegmenterOptions& task_options,
       const core::ModelResources& model_resources, Source<Image> image_in,
-      Graph& graph) {
+      Source<NormalizedRect> norm_rect_in, Graph& graph) {
     MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));

     // Adds preprocessing calculators and connects them to the graph input image
     // stream.
     auto& preprocessing =
         graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
+    bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
+        task_options.base_options().acceleration());
     MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
-        model_resources,
+        model_resources, use_gpu,
         &preprocessing
             .GetOptions<tasks::components::ImagePreprocessingOptions>()));
     image_in >> preprocessing.In(kImageTag);
+    norm_rect_in >> preprocessing.In(kNormRectTag);

     // Adds inference subgraph and connects its input stream to the output
     // tensors produced by the ImageToTensorCalculator.
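For graph authors: the segmenter subgraph now declares an optional NORM_RECT input and threads it into preprocessing alongside the image. The wiring, condensed from the two hunks above into one place (comments are ours):

// Optional NORM_RECT graph input, routed into the preprocessing subgraph.
Source<NormalizedRect> norm_rect_in =
    graph[Input<NormalizedRect>::Optional(kNormRectTag)];
norm_rect_in >> preprocessing.In(kNormRectTag);
// The image input keeps its existing route.
image_in >> preprocessing.In(kImageTag);

Declaring the port Optional keeps existing graphs that only feed IMAGE working; the C++ task API always feeds a NORM_RECT packet, falling back to a full-image rect when no options are given.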
@@ -29,8 +29,10 @@ limitations under the License.
 #include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
 #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
 #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@@ -44,6 +46,8 @@ namespace {

 using ::mediapipe::Image;
 using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;

@@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {

   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
-  MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
+  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
   EXPECT_EQ(confidence_masks.size(), 21);

@@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
 }

+TEST_F(ImageModeTest, SucceedsWithRotation) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image, DecodeImageFromFile(
+                       JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
+  ImageProcessingOptions image_processing_options;
+  image_processing_options.rotation_degrees = -90;
+  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
+                          segmenter->Segment(image, image_processing_options));
+  EXPECT_EQ(confidence_masks.size(), 21);
+
+  cv::Mat expected_mask =
+      cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
+                 cv::IMREAD_GRAYSCALE);
+  cv::Mat expected_mask_float;
+  expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
+
+  // Cat category index 8.
+  cv::Mat cat_mask = mediapipe::formats::MatView(
+      confidence_masks[8].GetImageFrameSharedPtr().get());
+  EXPECT_THAT(cat_mask,
+              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
+}
+
+TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image,
+      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
+  auto options = std::make_unique<ImageSegmenterOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
+  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
+
+  auto results = segmenter->Segment(image, image_processing_options);
+  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(results.status().message(),
+              HasSubstr("This task doesn't support region-of-interest"));
+  EXPECT_THAT(
+      results.status().GetPayload(kMediaPipeTasksPayload),
+      Optional(absl::Cord(absl::StrCat(
+          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
+}
+
 TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
   Image image =
       GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
@@ -75,6 +75,7 @@ cc_library(
         "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
         "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/cc/vision/core:image_processing_options",
         "//mediapipe/tasks/cc/vision/core:running_mode",
         "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
         "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto",
@@ -34,6 +34,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
 #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
 #include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
 #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
 #include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h"
@@ -58,31 +59,6 @@ constexpr int kMicroSecondsPerMilliSecond = 1000;
 using ObjectDetectorOptionsProto =
     object_detector::proto::ObjectDetectorOptions;

-// Returns a NormalizedRect filling the whole image. If input is present, its
-// rotation is set in the returned NormalizedRect and a check is performed to
-// make sure no region-of-interest was provided. Otherwise, rotation is set to
-// 0.
-absl::StatusOr<NormalizedRect> FillNormalizedRect(
-    std::optional<NormalizedRect> normalized_rect) {
-  NormalizedRect result;
-  if (normalized_rect.has_value()) {
-    result = *normalized_rect;
-  }
-  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
-                         result.has_width() || result.has_height();
-  if (has_coordinates) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "ObjectDetector does not support region-of-interest.",
-        MediaPipeTasksStatus::kInvalidArgumentError);
-  }
-  result.set_x_center(0.5);
-  result.set_y_center(0.5);
-  result.set_width(1);
-  result.set_height(1);
-  return result;
-}
-
 // Creates a MediaPipe graph config that contains a subgraph node of
 // "mediapipe.tasks.vision.ObjectDetectorGraph". If the task is running in the
 // live stream mode, a "FlowLimiterCalculator" will be added to limit the
@@ -170,15 +146,16 @@ absl::StatusOr<std::unique_ptr<ObjectDetector>> ObjectDetector::Create(

 absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(
     mediapipe::Image image,
-    std::optional<mediapipe::NormalizedRect> image_processing_options) {
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
-                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessImageData(
@@ -189,15 +166,16 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::Detect(

 absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(
     mediapipe::Image image, int64 timestamp_ms,
-    std::optional<mediapipe::NormalizedRect> image_processing_options) {
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
-                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   ASSIGN_OR_RETURN(
       auto output_packets,
       ProcessVideoData(
@@ -212,15 +190,16 @@ absl::StatusOr<std::vector<Detection>> ObjectDetector::DetectForVideo(

 absl::Status ObjectDetector::DetectAsync(
     Image image, int64 timestamp_ms,
-    std::optional<mediapipe::NormalizedRect> image_processing_options) {
+    std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,
         absl::StrCat("GPU input images are currently not supported."),
         MediaPipeTasksStatus::kRunnerUnexpectedInputError);
   }
-  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
-                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      NormalizedRect norm_rect,
+      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
   return SendLiveStreamData(
       {{kImageInStreamName,
         MakePacket<Image>(std::move(image))
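Callers can distinguish the new validation failures from other invalid-argument errors through the status payload, exactly as the tests in this commit do. A sketch of the inspection pattern, assembled from calls that appear verbatim in the test hunks:

auto detections = object_detector->Detect(image, processing_options);
if (!detections.ok()) {
  // kMediaPipeTasksPayload tags task-specific error codes onto the status.
  std::optional<absl::Cord> payload =
      detections.status().GetPayload(kMediaPipeTasksPayload);
  if (payload.has_value() &&
      *payload == absl::Cord(absl::StrCat(
                      MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))) {
    // Invalid ImageProcessingOptions, e.g. a region_of_interest, which this
    // task rejects (roi_allowed=false above).
  }
}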
@@ -27,9 +27,9 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/detection.pb.h"
 #include "mediapipe/framework/formats/image.h"
-#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"

 namespace mediapipe {
@@ -154,10 +154,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // after the yuv support is implemented.
   //
   // The optional 'image_processing_options' parameter can be used to specify
-  // the rotation to apply to the image before performing classification, by
-  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
-  // anti-clockwise rotation). Note that specifying a region-of-interest using
-  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation_degrees' field. Note that specifying a
+  // region-of-interest using the 'region_of_interest' field is NOT supported
   // and will result in an invalid argument error being returned.
   //
   // For CPU images, the returned bounding boxes are expressed in the
@@ -168,7 +167,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // images after enabling the gpu support in MediaPipe Tasks.
   absl::StatusOr<std::vector<mediapipe::Detection>> Detect(
       mediapipe::Image image,
-      std::optional<mediapipe::NormalizedRect> image_processing_options =
+      std::optional<core::ImageProcessingOptions> image_processing_options =
           std::nullopt);

   // Performs object detection on the provided video frame.
@@ -180,10 +179,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // must be monotonically increasing.
   //
   // The optional 'image_processing_options' parameter can be used to specify
-  // the rotation to apply to the image before performing classification, by
-  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
-  // anti-clockwise rotation). Note that specifying a region-of-interest using
-  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation_degrees' field. Note that specifying a
+  // region-of-interest using the 'region_of_interest' field is NOT supported
   // and will result in an invalid argument error being returned.
   //
   // For CPU images, the returned bounding boxes are expressed in the
@@ -192,7 +190,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // underlying image data.
   absl::StatusOr<std::vector<mediapipe::Detection>> DetectForVideo(
       mediapipe::Image image, int64 timestamp_ms,
-      std::optional<mediapipe::NormalizedRect> image_processing_options =
+      std::optional<core::ImageProcessingOptions> image_processing_options =
           std::nullopt);

   // Sends live image data to perform object detection, and the results will be
@@ -206,10 +204,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   // increasing.
   //
   // The optional 'image_processing_options' parameter can be used to specify
-  // the rotation to apply to the image before performing classification, by
-  // setting its 'rotation' field in radians (e.g. 'M_PI / 2' for a 90°
-  // anti-clockwise rotation). Note that specifying a region-of-interest using
-  // the 'x_center', 'y_center', 'width' and 'height' fields is NOT supported
+  // the rotation to apply to the image before performing detection, by
+  // setting its 'rotation_degrees' field. Note that specifying a
+  // region-of-interest using the 'region_of_interest' field is NOT supported
   // and will result in an invalid argument error being returned.
   //
   // The "result_callback" provides
@@ -223,7 +220,7 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi {
   //     outside of the callback, callers need to make a copy of the image.
   //   - The input timestamp in milliseconds.
   absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms,
-                           std::optional<mediapipe::NormalizedRect>
+                           std::optional<core::ImageProcessingOptions>
                                image_processing_options = std::nullopt);

   // Shuts down the ObjectDetector when all works are done.
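A live-stream usage sketch for DetectAsync(). The running-mode and callback option fields and the callback signature are assumptions inferred from the comments above (they are not shown in this diff); only the DetectAsync signature itself is taken from the hunk:

// Sketch only; field names and callback signature are assumed.
auto options = std::make_unique<ObjectDetectorOptions>();
options->base_options.model_asset_path = "/path/to/detector.tflite";  // placeholder
options->running_mode = core::RunningMode::LIVE_STREAM;               // assumed
options->result_callback =                                            // assumed
    [](absl::StatusOr<std::vector<mediapipe::Detection>> detections,
       const mediapipe::Image& image, int64 timestamp_ms) {
      // Copy `image` here if it must outlive the callback, per the comment.
    };
auto detector = ObjectDetector::Create(std::move(options)).value();
// Timestamps must be monotonically increasing across calls.
absl::Status status = detector->DetectAsync(frame, frame_timestamp_ms);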
@@ -563,8 +563,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
     // stream.
     auto& preprocessing =
         graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
+    bool use_gpu = components::DetermineImagePreprocessingGpuBackend(
+        task_options.base_options().acceleration());
     MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
-        model_resources,
+        model_resources, use_gpu,
         &preprocessing
             .GetOptions<tasks::components::ImagePreprocessingOptions>()));
     image_in >> preprocessing.In(kImageTag);
@@ -31,11 +31,12 @@ limitations under the License.
 #include "mediapipe/framework/deps/file_path.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/location_data.pb.h"
-#include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
 #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 #include "tensorflow/lite/c/common.h"
@@ -64,6 +65,8 @@ namespace vision {
 namespace {

 using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::containers::Rect;
+using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 using ::testing::HasSubstr;
 using ::testing::Optional;

@@ -532,8 +535,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                           ObjectDetector::Create(std::move(options)));
-  NormalizedRect image_processing_options;
-  image_processing_options.set_rotation(M_PI / 2.0);
+  ImageProcessingOptions image_processing_options;
+  image_processing_options.rotation_degrees = -90;
   MP_ASSERT_OK_AND_ASSIGN(
       auto results, object_detector->Detect(image, image_processing_options));
   MP_ASSERT_OK(object_detector->Close());
@@ -557,16 +560,17 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
       JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata);
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ObjectDetector> object_detector,
                           ObjectDetector::Create(std::move(options)));
-  NormalizedRect image_processing_options;
-  image_processing_options.set_x_center(0.5);
-  image_processing_options.set_y_center(0.5);
-  image_processing_options.set_width(1.0);
-  image_processing_options.set_height(1.0);
+  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
+  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

   auto results = object_detector->Detect(image, image_processing_options);
   EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
   EXPECT_THAT(results.status().message(),
-              HasSubstr("ObjectDetector does not support region-of-interest"));
+              HasSubstr("This task doesn't support region-of-interest"));
   EXPECT_THAT(
       results.status().GetPayload(kMediaPipeTasksPayload),
       Optional(absl::Cord(absl::StrCat(
           MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
 }

 class VideoModeTest : public tflite_shims::testing::Test {};
@@ -31,6 +31,7 @@ android_binary(
     multidex = "native",
     resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
     deps = [
         "//mediapipe/java/com/google/mediapipe/framework:android_framework",
         "//mediapipe/java/com/google/mediapipe/framework/image",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
@@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector;

 import android.content.Intent;
 import android.graphics.Bitmap;
-import android.graphics.Matrix;
 import android.media.MediaMetadataRetriever;
 import android.os.Bundle;
 import android.provider.MediaStore;
@@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher;
 import androidx.activity.result.contract.ActivityResultContracts;
 import androidx.exifinterface.media.ExifInterface;
 // ContentResolver dependency
+import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.image.BitmapImageBuilder;
-import com.google.mediapipe.framework.image.Image;
+import com.google.mediapipe.framework.image.MPImage;
 import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
 import com.google.mediapipe.tasks.vision.core.RunningMode;
 import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
 import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
@@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity {
             if (resultIntent != null) {
               if (result.getResultCode() == RESULT_OK) {
                 Bitmap bitmap = null;
+                int rotation = 0;
                 try {
                   bitmap =
                       downscaleBitmap(
@@ -93,13 +95,16 @@ public class MainActivity extends AppCompatActivity {
                 try {
                   InputStream imageData =
                       this.getContentResolver().openInputStream(resultIntent.getData());
-                  bitmap = rotateBitmap(bitmap, imageData);
-                } catch (IOException e) {
+                  rotation = getImageRotation(imageData);
+                } catch (IOException | MediaPipeException e) {
                   Log.e(TAG, "Bitmap rotation error:" + e);
                 }
                 if (bitmap != null) {
-                  Image image = new BitmapImageBuilder(bitmap).build();
-                  ObjectDetectionResult detectionResult = objectDetector.detect(image);
+                  MPImage image = new BitmapImageBuilder(bitmap).build();
+                  ObjectDetectionResult detectionResult =
+                      objectDetector.detect(
+                          image,
+                          ImageProcessingOptions.builder().setRotationDegrees(rotation).build());
                   imageView.setData(image, detectionResult);
                   runOnUiThread(() -> imageView.update());
                 }
@@ -144,7 +149,8 @@ public class MainActivity extends AppCompatActivity {
                 MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT));
     long frameIntervalMs = duration / numFrames;
     for (int i = 0; i < numFrames; ++i) {
-      Image image = new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build();
+      MPImage image =
+          new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build();
       ObjectDetectionResult detectionResult =
           objectDetector.detectForVideo(image, frameIntervalMs * i);
       // Currently only annotates the detection result on the first video frame and
@@ -209,28 +215,25 @@ public class MainActivity extends AppCompatActivity {
     return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
   }

-  private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
+  private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException {
     int orientation =
         new ExifInterface(imageData)
             .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
-    if (orientation == ExifInterface.ORIENTATION_NORMAL) {
-      return inputBitmap;
-    }
-    Matrix matrix = new Matrix();
     switch (orientation) {
+      case ExifInterface.ORIENTATION_NORMAL:
+        return 0;
       case ExifInterface.ORIENTATION_ROTATE_90:
-        matrix.postRotate(90);
-        break;
+        return 90;
       case ExifInterface.ORIENTATION_ROTATE_180:
-        matrix.postRotate(180);
-        break;
+        return 180;
       case ExifInterface.ORIENTATION_ROTATE_270:
-        matrix.postRotate(270);
-        break;
+        return 270;
       default:
-        matrix.postRotate(0);
+        // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip
+        // is supported.
+        throw new MediaPipeException(
+            MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(),
+            "Flipped images are not supported yet.");
     }
-    return Bitmap.createBitmap(
-        inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
   }
 }
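The TODO in the default branch points at two androidx ExifInterface helpers. A hedged sketch of what the method could collapse to once those are used (same observable behavior as the switch above; flipped inputs still rejected):

    // Sketch only: ExifInterface.getRotationDegrees() returns 0/90/180/270 and
    // isFlipped() reports mirrored orientations, which remain unsupported here.
    private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException {
      ExifInterface exif = new ExifInterface(imageData);
      if (exif.isFlipped()) {
        throw new MediaPipeException(
            MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(),
            "Flipped images are not supported yet.");
      }
      return exif.getRotationDegrees();
    }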
@@ -22,7 +22,7 @@ import android.graphics.Matrix;
 import android.graphics.Paint;
 import androidx.appcompat.widget.AppCompatImageView;
 import com.google.mediapipe.framework.image.BitmapExtractor;
-import com.google.mediapipe.framework.image.Image;
+import com.google.mediapipe.framework.image.MPImage;
 import com.google.mediapipe.tasks.components.containers.Detection;
 import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;

@@ -40,12 +40,12 @@ public class ObjectDetectionResultImageView extends AppCompatImageView {
   }

   /**
-   * Sets an {@link Image} and an {@link ObjectDetectionResult} to render.
+   * Sets a {@link MPImage} and an {@link ObjectDetectionResult} to render.
    *
-   * @param image an {@link Image} object for annotation.
+   * @param image a {@link MPImage} object for annotation.
    * @param result an {@link ObjectDetectionResult} object that contains the detection result.
    */
-  public void setData(Image image, ObjectDetectionResult result) {
+  public void setData(MPImage image, ObjectDetectionResult result) {
     if (image == null || result == null) {
       return;
     }
@@ -1 +1,15 @@
-# dummy file for tap test to find the pattern
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+licenses(["notice"])
15
mediapipe/tasks/java/com/google/mediapipe/tasks/BUILD
Normal file
@@ -0,0 +1,15 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])
@@ -36,3 +36,15 @@ android_library(
         "@maven//:com_google_guava_guava",
     ],
 )
+
+load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar")
+
+mediapipe_tasks_core_aar(
+    name = "tasks_core",
+    srcs = glob(["*.java"]) + [
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src",
+        "//mediapipe/java/com/google/mediapipe/framework/image:java_src",
+    ],
+    manifest = "AndroidManifest.xml",
+)
@@ -0,0 +1,256 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Building MediaPipe Tasks AARs."""

load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build_aar_with_jni", "mediapipe_java_proto_src_extractor", "mediapipe_java_proto_srcs")
load("@build_bazel_rules_android//android:rules.bzl", "android_library")

_CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
    "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
    "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
    "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
    "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite",
    "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
    "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
    "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
    "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
]

_VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
    "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
    "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
]

_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
    "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
]

def mediapipe_tasks_core_aar(name, srcs, manifest):
    """Builds mediapipe tasks core AAR.

    Args:
      name: The bazel target name.
      srcs: MediaPipe Tasks' core layer source files.
      manifest: The Android manifest.
    """

    mediapipe_tasks_java_proto_srcs = []
    for target in _CORE_TASKS_JAVA_PROTO_LITE_TARGETS:
        mediapipe_tasks_java_proto_srcs.append(
            _mediapipe_tasks_java_proto_src_extractor(target = target),
        )

    for target in _VISION_TASKS_JAVA_PROTO_LITE_TARGETS:
        mediapipe_tasks_java_proto_srcs.append(
            _mediapipe_tasks_java_proto_src_extractor(target = target),
        )

    for target in _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS:
        mediapipe_tasks_java_proto_srcs.append(
            _mediapipe_tasks_java_proto_src_extractor(target = target),
        )

    mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
        target = "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
        src_out = "com/google/mediapipe/calculator/proto/FlowLimiterCalculatorProto.java",
    ))

    mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor(
        target = "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
        src_out = "com/google/mediapipe/calculator/proto/InferenceCalculatorProto.java",
    ))

    android_library(
        name = name,
        srcs = srcs + [
            "//mediapipe/java/com/google/mediapipe/framework:java_src",
        ] + mediapipe_java_proto_srcs() +
            mediapipe_tasks_java_proto_srcs,
        javacopts = [
            "-Xep:AndroidJdkLibsChecker:OFF",
        ],
        manifest = manifest,
        deps = [
            "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
            "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
            "//mediapipe/framework:calculator_java_proto_lite",
            "//mediapipe/framework:calculator_profile_java_proto_lite",
            "//mediapipe/framework:calculator_options_java_proto_lite",
            "//mediapipe/framework:mediapipe_options_java_proto_lite",
            "//mediapipe/framework:packet_factory_java_proto_lite",
            "//mediapipe/framework:packet_generator_java_proto_lite",
            "//mediapipe/framework:status_handler_java_proto_lite",
            "//mediapipe/framework:stream_handler_java_proto_lite",
            "//mediapipe/framework/formats:classification_java_proto_lite",
            "//mediapipe/framework/formats:detection_java_proto_lite",
            "//mediapipe/framework/formats:landmark_java_proto_lite",
            "//mediapipe/framework/formats:location_data_java_proto_lite",
            "//mediapipe/framework/formats:rect_java_proto_lite",
            "//mediapipe/java/com/google/mediapipe/framework:android_framework",
            "//mediapipe/java/com/google/mediapipe/framework/image",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
            "//third_party:androidx_annotation",
            "//third_party:autovalue",
            "@com_google_protobuf//:protobuf_javalite",
            "@maven//:com_google_guava_guava",
            "@maven//:com_google_flogger_flogger",
            "@maven//:com_google_flogger_flogger_system_backend",
            "@maven//:com_google_code_findbugs_jsr305",
        ] +
        _CORE_TASKS_JAVA_PROTO_LITE_TARGETS +
        _VISION_TASKS_JAVA_PROTO_LITE_TARGETS +
        _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
    )

def mediapipe_tasks_vision_aar(name, srcs, native_library):
    """Builds mediapipe tasks vision AAR.

    Args:
      name: The bazel target name.
      srcs: MediaPipe Vision Tasks' source files.
      native_library: The native library that contains vision tasks' graph and calculators.
    """

    native.genrule(
        name = name + "tasks_manifest_generator",
        outs = ["AndroidManifest.xml"],
        cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision">
  <uses-sdk
      android:minSdkVersion="24"
      android:targetSdkVersion="30" />
</manifest>
EOF
""",
    )

    _mediapipe_tasks_aar(
        name = name,
        srcs = srcs,
        manifest = "AndroidManifest.xml",
        java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS,
        native_library = native_library,
    )

def mediapipe_tasks_text_aar(name, srcs, native_library):
    """Builds mediapipe tasks text AAR.

    Args:
      name: The bazel target name.
      srcs: MediaPipe Text Tasks' source files.
      native_library: The native library that contains text tasks' graph and calculators.
    """

    native.genrule(
        name = name + "tasks_manifest_generator",
        outs = ["AndroidManifest.xml"],
        cmd = """
cat > $(OUTS) <<EOF
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.text">
  <uses-sdk
      android:minSdkVersion="24"
      android:targetSdkVersion="30" />
</manifest>
EOF
""",
    )

    _mediapipe_tasks_aar(
        name = name,
        srcs = srcs,
        manifest = "AndroidManifest.xml",
        java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS,
        native_library = native_library,
    )

def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_library):
    """Builds mediapipe tasks AAR."""

    # When "--define EXCLUDE_OPENCV_SO_LIB=1" is set in the build command,
    # the OpenCV so libraries will be excluded from the AAR package to
    # save the package size.
    native.config_setting(
        name = "exclude_opencv_so_lib",
        define_values = {
            "EXCLUDE_OPENCV_SO_LIB": "1",
        },
        visibility = ["//visibility:public"],
    )

    native.cc_library(
        name = name + "_jni_opencv_cc_lib",
        srcs = select({
            "//mediapipe:android_arm64": ["@android_opencv//:libopencv_java3_so_arm64-v8a"],
            "//mediapipe:android_armeabi": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"],
            "//mediapipe:android_arm": ["@android_opencv//:libopencv_java3_so_armeabi-v7a"],
            "//mediapipe:android_x86": ["@android_opencv//:libopencv_java3_so_x86"],
            "//mediapipe:android_x86_64": ["@android_opencv//:libopencv_java3_so_x86_64"],
            "//conditions:default": [],
        }),
        alwayslink = 1,
    )

    android_library(
        name = name + "_android_lib",
        srcs = srcs,
        manifest = manifest,
        proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
        deps = java_proto_lite_targets + [native_library] + [
            "//mediapipe/java/com/google/mediapipe/framework:android_framework",
            "//mediapipe/java/com/google/mediapipe/framework/image",
            "//mediapipe/framework:calculator_options_java_proto_lite",
            "//mediapipe/framework:calculator_java_proto_lite",
            "//mediapipe/framework/formats:classification_java_proto_lite",
            "//mediapipe/framework/formats:detection_java_proto_lite",
            "//mediapipe/framework/formats:landmark_java_proto_lite",
            "//mediapipe/framework/formats:location_data_java_proto_lite",
            "//mediapipe/framework/formats:rect_java_proto_lite",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
            "//third_party:autovalue",
            "@maven//:com_google_guava_guava",
        ] + select({
            "//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
            "//mediapipe/framework/port:disable_opencv": [],
            "exclude_opencv_so_lib": [],
        }),
    )

    mediapipe_build_aar_with_jni(name, name + "_android_lib")

def _mediapipe_tasks_java_proto_src_extractor(target):
    proto_path = "com/google/" + target.split(":")[0].replace("cc/", "").replace("//", "").replace("_", "") + "/"
    proto_name = target.split(":")[-1].replace("_java_proto_lite", "").replace("_", " ").title().replace(" ", "") + "Proto.java"
    return mediapipe_java_proto_src_extractor(
        target = target,
        src_out = proto_path + proto_name,
    )
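For reference, the string pipeline in _mediapipe_tasks_java_proto_src_extractor maps each proto target to the path of its generated Java source. A hedged transliteration to Java (the helper name javaProtoSrcOut is invented for illustration):

    // Illustration only: mirrors the Starlark derivation above. For
    // "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite"
    // it yields "com/google/mediapipe/tasks/components/containers/proto/CategoryProto.java".
    static String javaProtoSrcOut(String target) {
      String pkg = target.split(":")[0].replace("cc/", "").replace("//", "").replace("_", "");
      StringBuilder name = new StringBuilder();
      for (String word : target.split(":")[1].replace("_java_proto_lite", "").split("_")) {
        name.append(Character.toUpperCase(word.charAt(0))).append(word.substring(1));
      }
      return "com/google/" + pkg + "/" + name + "Proto.java";
    }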
@@ -61,3 +61,11 @@ android_library(
         "@maven//:com_google_guava_guava",
     ],
 )
+
+load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_text_aar")
+
+mediapipe_tasks_text_aar(
+    name = "tasks_text",
+    srcs = glob(["**/*.java"]),
+    native_library = ":libmediapipe_tasks_text_jni_lib",
+)
@@ -15,11 +15,11 @@
 package com.google.mediapipe.tasks.text.textclassifier;

 import com.google.auto.value.AutoValue;
-import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
-import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
 import com.google.mediapipe.tasks.components.containers.Category;
 import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
 import com.google.mediapipe.tasks.components.containers.Classifications;
+import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
+import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 import com.google.mediapipe.tasks.core.TaskResult;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -22,7 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.Packet;
 import com.google.mediapipe.framework.PacketGetter;
 import com.google.mediapipe.framework.ProtoUtil;
-import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
+import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 import com.google.mediapipe.tasks.core.BaseOptions;
 import com.google.mediapipe.tasks.core.OutputHandler;
@@ -28,6 +28,7 @@ android_library(
         "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
         "//mediapipe/java/com/google/mediapipe/framework/image",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+        "//third_party:autovalue",
         "@maven//:com_google_guava_guava",
     ],
 )
@@ -128,6 +129,7 @@ android_library(
         "//mediapipe/java/com/google/mediapipe/framework/image",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
         "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
         "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
         "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
@@ -140,3 +142,11 @@ android_library(
         "@maven//:com_google_guava_guava",
     ],
 )
+
+load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")
+
+mediapipe_tasks_vision_aar(
+    name = "tasks_vision",
+    srcs = glob(["**/*.java"]),
+    native_library = ":libmediapipe_tasks_vision_jni_lib",
+)
@@ -19,12 +19,11 @@ import com.google.mediapipe.formats.proto.RectProto.NormalizedRect;
 import com.google.mediapipe.framework.MediaPipeException;
 import com.google.mediapipe.framework.Packet;
 import com.google.mediapipe.framework.ProtoUtil;
-import com.google.mediapipe.framework.image.Image;
+import com.google.mediapipe.framework.image.MPImage;
 import com.google.mediapipe.tasks.core.TaskResult;
 import com.google.mediapipe.tasks.core.TaskRunner;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.Optional;

 /** The base class of MediaPipe vision tasks. */
 public class BaseVisionTaskApi implements AutoCloseable {
@@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
   private final TaskRunner runner;
   private final RunningMode runningMode;
   private final String imageStreamName;
-  private final Optional<String> normRectStreamName;
+  private final String normRectStreamName;

   static {
     System.loadLibrary("mediapipe_tasks_vision_jni");
@@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
   }

   /**
-   * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
+   * Constructor to initialize a {@link BaseVisionTaskApi}.
    *
    * @param runner a {@link TaskRunner}.
    * @param runningMode a mediapipe vision task {@link RunningMode}.
    * @param imageStreamName the name of the input image stream.
-   */
-  public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
-    this.runner = runner;
-    this.runningMode = runningMode;
-    this.imageStreamName = imageStreamName;
-    this.normRectStreamName = Optional.empty();
-  }
-
-  /**
-   * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
-   * input.
-   *
-   * @param runner a {@link TaskRunner}.
-   * @param runningMode a mediapipe vision task {@link RunningMode}.
-   * @param imageStreamName the name of the input image stream.
-   * @param normRectStreamName the name of the input normalized rect image stream.
+   * @param normRectStreamName the name of the input normalized rect image stream used to provide
+   *     (mandatory) rotation and (optional) region-of-interest.
    */
   public BaseVisionTaskApi(
       TaskRunner runner,
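With the image-only constructor removed, every subclass now supplies a norm-rect stream name. A minimal sketch of the call site (class name and stream names are assumptions for illustration):

    // Sketch only: a vision task subclass wiring the single remaining constructor.
    public final class MyVisionTask extends BaseVisionTaskApi {
      MyVisionTask(TaskRunner runner) {
        super(runner, RunningMode.IMAGE, "image_in", "norm_rect_in");
      }
    }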
@@ -70,61 +55,31 @@ public class BaseVisionTaskApi implements AutoCloseable {
     this.runner = runner;
     this.runningMode = runningMode;
     this.imageStreamName = imageStreamName;
-    this.normRectStreamName = Optional.of(normRectStreamName);
+    this.normRectStreamName = normRectStreamName;
   }

   /**
    * A synchronous method to process single image inputs. The call blocks the current thread until a
    * failure status or a successful result is returned.
    *
-   * @param image a MediaPipe {@link Image} object for processing.
-   * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
-   *     input.
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference.
+   * @throws MediaPipeException if the task is not in the image mode.
    */
-  protected TaskResult processImageData(Image image) {
+  protected TaskResult processImageData(
+      MPImage image, ImageProcessingOptions imageProcessingOptions) {
     if (runningMode != RunningMode.IMAGE) {
       throw new MediaPipeException(
           MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
           "Task is not initialized with the image mode. Current running mode:"
               + runningMode.name());
     }
-    if (normRectStreamName.isPresent()) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task expects a normalized rect as input.");
-    }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
-    return runner.process(inputPackets);
-  }
-
-  /**
-   * A synchronous method to process single image inputs. The call blocks the current thread until a
-   * failure status or a successful result is returned.
-   *
-   * @param image a MediaPipe {@link Image} object for processing.
-   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
-   *     are expected to be specified as normalized values in [0,1].
-   * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
-   *     rect.
-   */
-  protected TaskResult processImageData(Image image, RectF roi) {
-    if (runningMode != RunningMode.IMAGE) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task is not initialized with the image mode. Current running mode:"
-              + runningMode.name());
-    }
-    if (!normRectStreamName.isPresent()) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task doesn't expect a normalized rect as input.");
-    }
-    Map<String, Packet> inputPackets = new HashMap<>();
-    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
     inputPackets.put(
-        normRectStreamName.get(),
-        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
+        normRectStreamName,
+        runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
     return runner.process(inputPackets);
   }

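A hedged sketch of how a concrete task sits on top of this method; the detect(image, options) call in the example app above suggests this shape, but the wrapper and cast are assumptions:

    // Sketch only: task-level wrapper over processImageData.
    public ObjectDetectionResult detect(
        MPImage image, ImageProcessingOptions imageProcessingOptions) {
      return (ObjectDetectionResult) processImageData(image, imageProcessingOptions);
    }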
@@ -132,56 +87,25 @@ public class BaseVisionTaskApi implements AutoCloseable {
    * A synchronous method to process continuous video frames. The call blocks the current thread
    * until a failure status or a successful result is returned.
    *
-   * @param image a MediaPipe {@link Image} object for processing.
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference.
    * @param timestampMs the corresponding timestamp of the input image in milliseconds.
-   * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
-   *     input.
+   * @throws MediaPipeException if the task is not in the video mode.
    */
-  protected TaskResult processVideoData(Image image, long timestampMs) {
+  protected TaskResult processVideoData(
+      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
     if (runningMode != RunningMode.VIDEO) {
       throw new MediaPipeException(
           MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
           "Task is not initialized with the video mode. Current running mode:"
               + runningMode.name());
     }
-    if (normRectStreamName.isPresent()) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task expects a normalized rect as input.");
-    }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
-    return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
-  }
-
-  /**
-   * A synchronous method to process continuous video frames. The call blocks the current thread
-   * until a failure status or a successful result is returned.
-   *
-   * @param image a MediaPipe {@link Image} object for processing.
-   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
-   *     are expected to be specified as normalized values in [0,1].
-   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
-   * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
-   *     rect.
-   */
-  protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) {
-    if (runningMode != RunningMode.VIDEO) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task is not initialized with the video mode. Current running mode:"
-              + runningMode.name());
-    }
-    if (!normRectStreamName.isPresent()) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task doesn't expect a normalized rect as input.");
-    }
-    Map<String, Packet> inputPackets = new HashMap<>();
-    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
     inputPackets.put(
-        normRectStreamName.get(),
-        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
+        normRectStreamName,
+        runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
     return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

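A worked example of the timestamp handling, assuming MICROSECONDS_PER_MILLISECOND == 1000 (implied by its name) and the frame pacing from the example app's detectForVideo loop:

    // Illustration only, with assumed clip numbers.
    long durationMs = 2000;                         // 2-second clip
    int numFrames = 10;
    long frameIntervalMs = durationMs / numFrames;  // 200 ms between frames
    long timestampMs = frameIntervalMs * 3;         // frame 3 arrives at 600 ms
    long timestampUs = timestampMs * 1000;          // 600000 us passed to the TaskRunner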
@@ -189,56 +113,25 @@ public class BaseVisionTaskApi implements AutoCloseable {
    * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
    * available in the user-defined result listener.
    *
-   * @param image a MediaPipe {@link Image} object for processing.
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference.
    * @param timestampMs the corresponding timestamp of the input image in milliseconds.
-   * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
-   *     input.
+   * @throws MediaPipeException if the task is not in the stream mode.
    */
-  protected void sendLiveStreamData(Image image, long timestampMs) {
+  protected void sendLiveStreamData(
+      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
     if (runningMode != RunningMode.LIVE_STREAM) {
       throw new MediaPipeException(
           MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
           "Task is not initialized with the live stream mode. Current running mode:"
               + runningMode.name());
     }
-    if (normRectStreamName.isPresent()) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task expects a normalized rect as input.");
-    }
     Map<String, Packet> inputPackets = new HashMap<>();
     inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
-    runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
-  }
-
-  /**
-   * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
-   * available in the user-defined result listener.
-   *
-   * @param image a MediaPipe {@link Image} object for processing.
-   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
-   *     are expected to be specified as normalized values in [0,1].
-   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
-   * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
-   *     rect.
-   */
-  protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) {
-    if (runningMode != RunningMode.LIVE_STREAM) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task is not initialized with the live stream mode. Current running mode:"
-              + runningMode.name());
-    }
-    if (!normRectStreamName.isPresent()) {
-      throw new MediaPipeException(
-          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
-          "Task doesn't expect a normalized rect as input.");
-    }
-    Map<String, Packet> inputPackets = new HashMap<>();
-    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
     inputPackets.put(
-        normRectStreamName.get(),
-        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
+        normRectStreamName,
+        runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
     runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
   }

@@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
     runner.close();
   }

-  /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */
-  private static NormalizedRect convertToNormalizedRect(RectF rect) {
+  /**
+   * Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf
+   * message.
+   */
+  private static NormalizedRect convertToNormalizedRect(
+      ImageProcessingOptions imageProcessingOptions) {
+    RectF regionOfInterest =
+        imageProcessingOptions.regionOfInterest().isPresent()
+            ? imageProcessingOptions.regionOfInterest().get()
+            : new RectF(0, 0, 1, 1);
     return NormalizedRect.newBuilder()
-        .setXCenter(rect.centerX())
-        .setYCenter(rect.centerY())
-        .setWidth(rect.width())
-        .setHeight(rect.height())
+        .setXCenter(regionOfInterest.centerX())
+        .setYCenter(regionOfInterest.centerY())
+        .setWidth(regionOfInterest.width())
+        .setHeight(regionOfInterest.height())
+        // Convert to radians anti-clockwise.
+        .setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f)
         .build();
   }
 }
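A worked example of the conversion above, with rotationDegrees = 90 (a quarter turn clockwise) and no region-of-interest set:

    // ROI defaults to RectF(0, 0, 1, 1): xCenter = yCenter = 0.5f, width = height = 1.0f.
    float rotation = -(float) Math.PI * 90 / 180.0f;  // == -1.5707964f, i.e. -pi/2 anti-clockwise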
@@ -0,0 +1,92 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.core;

import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import java.util.Optional;

// TODO: add support for image flipping.
/** Options for image processing. */
@AutoValue
public abstract class ImageProcessingOptions {

  /**
   * Builder for {@link ImageProcessingOptions}.
   *
   * <p>If both region-of-interest and rotation are specified, the crop around the
   * region-of-interest is extracted first, then the specified rotation is applied to the crop.
   */
  @AutoValue.Builder
  public abstract static class Builder {
    /**
     * Sets the optional region-of-interest to crop from the image. If not specified, the full
     * image is used.
     *
     * <p>Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must
     * be < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link
     * #build()} is called.
     */
    public abstract Builder setRegionOfInterest(RectF value);

    /**
     * Sets the rotation to apply to the image (or cropped region-of-interest), in degrees
     * clockwise. Defaults to 0.
     *
     * <p>The rotation must be a multiple (positive or negative) of 90°, otherwise an
     * IllegalArgumentException will be thrown when {@link #build()} is called.
     */
    public abstract Builder setRotationDegrees(int value);

    abstract ImageProcessingOptions autoBuild();

    /**
     * Validates and builds the {@link ImageProcessingOptions} instance.
     *
     * @throws IllegalArgumentException if some of the provided values do not meet their
     *     requirements.
     */
    public final ImageProcessingOptions build() {
      ImageProcessingOptions options = autoBuild();
      if (options.regionOfInterest().isPresent()) {
        RectF roi = options.regionOfInterest().get();
        if (roi.left >= roi.right || roi.top >= roi.bottom) {
          throw new IllegalArgumentException(
              String.format(
                  "Expected left < right and top < bottom, found: %s.", roi.toShortString()));
        }
        if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) {
          throw new IllegalArgumentException(
              String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString()));
        }
      }
      if (options.rotationDegrees() % 90 != 0) {
        throw new IllegalArgumentException(
            String.format(
                "Expected rotation to be a multiple of 90°, found: %d.",
                options.rotationDegrees()));
      }
      return options;
    }
  }

  public abstract Optional<RectF> regionOfInterest();

  public abstract int rotationDegrees();

  public static Builder builder() {
    return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0);
  }
}
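A minimal usage sketch of the new class (values assumed): a centered half-size crop rotated a quarter turn clockwise. build() enforces the documented invariants, so invalid values fail fast:

    ImageProcessingOptions options =
        ImageProcessingOptions.builder()
            .setRegionOfInterest(new RectF(0.25f, 0.25f, 0.75f, 0.75f))
            .setRotationDegrees(90)
            .build();
    // Invalid inputs throw IllegalArgumentException at build() time, e.g.
    // setRotationDegrees(45) (not a multiple of 90) or an ROI with left >= right.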