Merge branch 'master' into gesture-recognizer-python
This commit is contained in:
commit
0de97497fa
|
@ -163,7 +163,6 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
||||||
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
|
|
||||||
const auto& input_tensors = *kInTensors(cc);
|
const auto& input_tensors = *kInTensors(cc);
|
||||||
RET_CHECK_EQ(input_tensors.size(), 1);
|
RET_CHECK_EQ(input_tensors.size(), 1);
|
||||||
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
|
||||||
|
@ -182,12 +181,6 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
||||||
auto raw_scores = view.buffer<float>();
|
auto raw_scores = view.buffer<float>();
|
||||||
|
|
||||||
auto classification_list = absl::make_unique<ClassificationList>();
|
auto classification_list = absl::make_unique<ClassificationList>();
|
||||||
if (options.has_tensor_index()) {
|
|
||||||
classification_list->set_tensor_index(options.tensor_index());
|
|
||||||
}
|
|
||||||
if (options.has_tensor_name()) {
|
|
||||||
classification_list->set_tensor_name(options.tensor_name());
|
|
||||||
}
|
|
||||||
if (is_binary_classification_) {
|
if (is_binary_classification_) {
|
||||||
Classification* class_first = classification_list->add_classification();
|
Classification* class_first = classification_list->add_classification();
|
||||||
Classification* class_second = classification_list->add_classification();
|
Classification* class_second = classification_list->add_classification();
|
||||||
|
|
|
@ -72,9 +72,4 @@ message TensorsToClassificationCalculatorOptions {
|
||||||
// that are not in the `allow_classes` field will be completely ignored.
|
// that are not in the `allow_classes` field will be completely ignored.
|
||||||
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
||||||
repeated int32 allow_classes = 8 [packed = true];
|
repeated int32 allow_classes = 8 [packed = true];
|
||||||
|
|
||||||
// The optional index of the tensor these classifications originate from.
|
|
||||||
optional int32 tensor_index = 10;
|
|
||||||
// The optional name of the tensor these classifications originate from.
|
|
||||||
optional string tensor_name = 11;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -240,36 +240,6 @@ TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TensorsToClassificationCalculatorTest,
|
|
||||||
CorrectOutputWithTensorNameAndIndex) {
|
|
||||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
|
||||||
calculator: "TensorsToClassificationCalculator"
|
|
||||||
input_stream: "TENSORS:tensors"
|
|
||||||
output_stream: "CLASSIFICATIONS:classifications"
|
|
||||||
options {
|
|
||||||
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
|
||||||
tensor_index: 1
|
|
||||||
tensor_name: "foo"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)pb"));
|
|
||||||
|
|
||||||
BuildGraph(&runner, {0, 0.5, 1});
|
|
||||||
MP_ASSERT_OK(runner.Run());
|
|
||||||
|
|
||||||
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
|
||||||
|
|
||||||
EXPECT_EQ(1, output_packets_.size());
|
|
||||||
|
|
||||||
const auto& classification_list =
|
|
||||||
output_packets_[0].Get<ClassificationList>();
|
|
||||||
EXPECT_EQ(3, classification_list.classification_size());
|
|
||||||
|
|
||||||
// Verify that the tensor_index and tensor_name fields are correctly set.
|
|
||||||
EXPECT_EQ(classification_list.tensor_index(), 1);
|
|
||||||
EXPECT_EQ(classification_list.tensor_name(), "foo");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TensorsToClassificationCalculatorTest,
|
TEST_F(TensorsToClassificationCalculatorTest,
|
||||||
ClassNameAllowlistWithLabelItems) {
|
ClassNameAllowlistWithLabelItems) {
|
||||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||||
|
|
|
@ -293,7 +293,6 @@ mediapipe_proto_library(
|
||||||
name = "rect_proto",
|
name = "rect_proto",
|
||||||
srcs = ["rect.proto"],
|
srcs = ["rect.proto"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = ["//mediapipe/framework/formats:location_data_proto"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_register_type(
|
mediapipe_register_type(
|
||||||
|
|
|
@ -37,10 +37,6 @@ message Classification {
|
||||||
// Group of Classification protos.
|
// Group of Classification protos.
|
||||||
message ClassificationList {
|
message ClassificationList {
|
||||||
repeated Classification classification = 1;
|
repeated Classification classification = 1;
|
||||||
// Optional index of the tensor that produced these classifications.
|
|
||||||
optional int32 tensor_index = 2;
|
|
||||||
// Optional name of the tensor that produced these classifications.
|
|
||||||
optional string tensor_name = 3;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group of ClassificationList protos.
|
// Group of ClassificationList protos.
|
||||||
|
|
|
@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
|
||||||
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
|
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
|
||||||
|
|
||||||
static void EglThreadExitCallback(void* key_value) {
|
static void EglThreadExitCallback(void* key_value) {
|
||||||
|
#if defined(__ANDROID__)
|
||||||
|
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
|
||||||
|
EGL_NO_CONTEXT);
|
||||||
|
#else
|
||||||
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
|
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
|
||||||
// parameter for eglMakeCurrent. This behavior is not portable to all EGL
|
// parameter for eglMakeCurrent. This behavior is not portable to all EGL
|
||||||
// implementations, and should be considered as an undocumented vendor
|
// implementations, and should be considered as an undocumented vendor
|
||||||
// extension.
|
// extension.
|
||||||
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
|
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
|
||||||
|
//
|
||||||
|
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
|
||||||
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
|
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
|
||||||
EGL_NO_SURFACE, EGL_NO_CONTEXT);
|
EGL_NO_SURFACE, EGL_NO_CONTEXT);
|
||||||
|
#endif
|
||||||
eglReleaseThread();
|
eglReleaseThread();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
|
||||||
return model.fit(
|
return model.fit(
|
||||||
x=train_ds,
|
x=train_ds,
|
||||||
epochs=hparams.train_epochs,
|
epochs=hparams.train_epochs,
|
||||||
steps_per_epoch=hparams.steps_per_epoch,
|
|
||||||
validation_data=validation_ds,
|
validation_data=validation_ds,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks)
|
||||||
|
|
|
@ -87,6 +87,7 @@ cc_library(
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "builtin_task_graphs",
|
name = "builtin_task_graphs",
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
|
||||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||||
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
||||||
],
|
],
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
"""The public facing packet getter APIs."""
|
"""The public facing packet getter APIs."""
|
||||||
|
|
||||||
from typing import List, Type
|
from typing import List
|
||||||
|
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
from google.protobuf import symbol_database
|
from google.protobuf import symbol_database
|
||||||
|
@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
|
||||||
get_matrix = _packet_getter.get_matrix
|
get_matrix = _packet_getter.get_matrix
|
||||||
|
|
||||||
|
|
||||||
def get_proto(packet: mp_packet.Packet) -> Type[message.Message]:
|
def get_proto(packet: mp_packet.Packet) -> message.Message:
|
||||||
"""Get the content of a MediaPipe proto Packet as a proto message.
|
"""Get the content of a MediaPipe proto Packet as a proto message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -46,6 +46,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/gpu:gpu_origin_cc_proto",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
|
||||||
|
|
|
@ -17,7 +17,7 @@ syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.containers.proto;
|
package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||||
option java_outer_classname = "CategoryProto";
|
option java_outer_classname = "CategoryProto";
|
||||||
|
|
||||||
// A single classification result.
|
// A single classification result.
|
||||||
|
|
|
@ -19,7 +19,7 @@ package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||||
|
|
||||||
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||||
option java_outer_classname = "ClassificationsProto";
|
option java_outer_classname = "ClassificationsProto";
|
||||||
|
|
||||||
// List of predicted categories with an optional timestamp.
|
// List of predicted categories with an optional timestamp.
|
||||||
|
|
|
@ -17,6 +17,9 @@ syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.containers.proto;
|
package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||||
|
option java_outer_classname = "EmbeddingsProto";
|
||||||
|
|
||||||
// Defines a dense floating-point embedding.
|
// Defines a dense floating-point embedding.
|
||||||
message FloatEmbedding {
|
message FloatEmbedding {
|
||||||
repeated float values = 1 [packed = true];
|
repeated float values = 1 [packed = true];
|
||||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
@ -128,6 +129,9 @@ absl::Status ConfigureImageToTensorCalculator(
|
||||||
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
|
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
|
||||||
std);
|
std);
|
||||||
}
|
}
|
||||||
|
// TODO: need to.support different GPU origin on differnt
|
||||||
|
// platforms or applications.
|
||||||
|
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -73,6 +73,19 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "sentencepiece_tokenizer_test",
|
||||||
|
srcs = ["sentencepiece_tokenizer_test.cc"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:albert_model",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":sentencepiece_tokenizer",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tokenizer_utils",
|
name = "tokenizer_utils",
|
||||||
srcs = ["tokenizer_utils.cc"],
|
srcs = ["tokenizer_utils.cc"],
|
||||||
|
@ -95,6 +108,33 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "tokenizer_utils_test",
|
||||||
|
srcs = ["tokenizer_utils_test.cc"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/text:albert_model",
|
||||||
|
"//mediapipe/tasks/testdata/text:mobile_bert_model",
|
||||||
|
"//mediapipe/tasks/testdata/text:text_classifier_models",
|
||||||
|
],
|
||||||
|
linkopts = ["-ldl"],
|
||||||
|
deps = [
|
||||||
|
":bert_tokenizer",
|
||||||
|
":regex_tokenizer",
|
||||||
|
":sentencepiece_tokenizer",
|
||||||
|
":tokenizer_utils",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:cord",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "regex_tokenizer",
|
name = "regex_tokenizer",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
|
@ -58,6 +58,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||||
|
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
|
||||||
|
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
#include "mediapipe/tasks/cc/core/utils.h"
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
|
||||||
|
@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap;
|
||||||
using ::mediapipe::tasks::vision::image_embedder::proto::
|
using ::mediapipe::tasks::vision::image_embedder::proto::
|
||||||
ImageEmbedderGraphOptions;
|
ImageEmbedderGraphOptions;
|
||||||
|
|
||||||
// Builds a NormalizedRect covering the entire image.
|
|
||||||
NormalizedRect BuildFullImageNormRect() {
|
|
||||||
NormalizedRect norm_rect;
|
|
||||||
norm_rect.set_x_center(0.5);
|
|
||||||
norm_rect.set_y_center(0.5);
|
|
||||||
norm_rect.set_width(1);
|
|
||||||
norm_rect.set_height(1);
|
|
||||||
return norm_rect;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a MediaPipe graph config that contains a single node of type
|
// Creates a MediaPipe graph config that contains a single node of type
|
||||||
// "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is
|
// "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is
|
||||||
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
|
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
|
||||||
|
@ -148,15 +139,16 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
|
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
|
||||||
Image image, std::optional<NormalizedRect> roi) {
|
Image image,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
"GPU input images are currently not supported.",
|
"GPU input images are currently not supported.",
|
||||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||||
}
|
}
|
||||||
NormalizedRect norm_rect =
|
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
|
||||||
roi.has_value() ? roi.value() : BuildFullImageNormRect();
|
ConvertToNormalizedRect(image_processing_options));
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_packets,
|
auto output_packets,
|
||||||
ProcessImageData(
|
ProcessImageData(
|
||||||
|
@ -167,15 +159,16 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
|
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
|
||||||
Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) {
|
Image image, int64 timestamp_ms,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
"GPU input images are currently not supported.",
|
"GPU input images are currently not supported.",
|
||||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||||
}
|
}
|
||||||
NormalizedRect norm_rect =
|
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
|
||||||
roi.has_value() ? roi.value() : BuildFullImageNormRect();
|
ConvertToNormalizedRect(image_processing_options));
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_packets,
|
auto output_packets,
|
||||||
ProcessVideoData(
|
ProcessVideoData(
|
||||||
|
@ -188,16 +181,17 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
|
||||||
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
|
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms,
|
absl::Status ImageEmbedder::EmbedAsync(
|
||||||
std::optional<NormalizedRect> roi) {
|
Image image, int64 timestamp_ms,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
"GPU input images are currently not supported.",
|
"GPU input images are currently not supported.",
|
||||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||||
}
|
}
|
||||||
NormalizedRect norm_rect =
|
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
|
||||||
roi.has_value() ? roi.value() : BuildFullImageNormRect();
|
ConvertToNormalizedRect(image_processing_options));
|
||||||
return SendLiveStreamData(
|
return SendLiveStreamData(
|
||||||
{{kImageInStreamName,
|
{{kImageInStreamName,
|
||||||
MakePacket<Image>(std::move(image))
|
MakePacket<Image>(std::move(image))
|
||||||
|
|
|
@ -21,11 +21,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
#include "mediapipe/tasks/cc/components/embedder_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
||||||
static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create(
|
static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create(
|
||||||
std::unique_ptr<ImageEmbedderOptions> options);
|
std::unique_ptr<ImageEmbedderOptions> options);
|
||||||
|
|
||||||
// Performs embedding extraction on the provided single image. Extraction
|
// Performs embedding extraction on the provided single image.
|
||||||
// is performed on the region of interest specified by the `roi` argument if
|
//
|
||||||
// provided, or on the entire image otherwise.
|
// The optional 'image_processing_options' parameter can be used to specify:
|
||||||
|
// - the rotation to apply to the image before performing embedding
|
||||||
|
// extraction, by setting its 'rotation_degrees' field.
|
||||||
|
// and/or
|
||||||
|
// - the region-of-interest on which to perform embedding extraction, by
|
||||||
|
// setting its 'region_of_interest' field. If not specified, the full image
|
||||||
|
// is used.
|
||||||
|
// If both are specified, the crop around the region-of-interest is extracted
|
||||||
|
// first, then the specified rotation is applied to the crop.
|
||||||
//
|
//
|
||||||
// Only use this method when the ImageEmbedder is created with the image
|
// Only use this method when the ImageEmbedder is created with the image
|
||||||
// running mode.
|
// running mode.
|
||||||
|
@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA.
|
// The image can be of any size with format RGB or RGBA.
|
||||||
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
|
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
|
std::nullopt);
|
||||||
|
|
||||||
// Performs embedding extraction on the provided video frame. Extraction
|
// Performs embedding extraction on the provided video frame.
|
||||||
// is performed on the region of interested specified by the `roi` argument if
|
//
|
||||||
// provided, or on the entire image otherwise.
|
// The optional 'image_processing_options' parameter can be used to specify:
|
||||||
|
// - the rotation to apply to the image before performing embedding
|
||||||
|
// extraction, by setting its 'rotation_degrees' field.
|
||||||
|
// and/or
|
||||||
|
// - the region-of-interest on which to perform embedding extraction, by
|
||||||
|
// setting its 'region_of_interest' field. If not specified, the full image
|
||||||
|
// is used.
|
||||||
|
// If both are specified, the crop around the region-of-interest is extracted
|
||||||
|
// first, then the specified rotation is applied to the crop.
|
||||||
//
|
//
|
||||||
// Only use this method when the ImageEmbedder is created with the video
|
// Only use this method when the ImageEmbedder is created with the video
|
||||||
// running mode.
|
// running mode.
|
||||||
|
@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
||||||
// must be monotonically increasing.
|
// must be monotonically increasing.
|
||||||
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
|
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
|
||||||
mediapipe::Image image, int64 timestamp_ms,
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
|
std::nullopt);
|
||||||
|
|
||||||
// Sends live image data to embedder, and the results will be available via
|
// Sends live image data to embedder, and the results will be available via
|
||||||
// the "result_callback" provided in the ImageEmbedderOptions. Embedding
|
// the "result_callback" provided in the ImageEmbedderOptions.
|
||||||
// extraction is performed on the region of interested specified by the `roi`
|
//
|
||||||
// argument if provided, or on the entire image otherwise.
|
// The optional 'image_processing_options' parameter can be used to specify:
|
||||||
|
// - the rotation to apply to the image before performing embedding
|
||||||
|
// extraction, by setting its 'rotation_degrees' field.
|
||||||
|
// and/or
|
||||||
|
// - the region-of-interest on which to perform embedding extraction, by
|
||||||
|
// setting its 'region_of_interest' field. If not specified, the full image
|
||||||
|
// is used.
|
||||||
|
// If both are specified, the crop around the region-of-interest is extracted
|
||||||
|
// first, then the specified rotation is applied to the crop.
|
||||||
//
|
//
|
||||||
// Only use this method when the ImageEmbedder is created with the live
|
// Only use this method when the ImageEmbedder is created with the live
|
||||||
// stream running mode.
|
// stream running mode.
|
||||||
|
@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi {
|
||||||
// longer be valid when the callback returns. To access the image data
|
// longer be valid when the callback returns. To access the image data
|
||||||
// outside of the callback, callers need to make a copy of the image.
|
// outside of the callback, callers need to make a copy of the image.
|
||||||
// - The input timestamp in milliseconds.
|
// - The input timestamp in milliseconds.
|
||||||
absl::Status EmbedAsync(
|
absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms,
|
||||||
mediapipe::Image image, int64 timestamp_ms,
|
std::optional<core::ImageProcessingOptions>
|
||||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
image_processing_options = std::nullopt);
|
||||||
|
|
||||||
// Shuts down the ImageEmbedder when all works are done.
|
// Shuts down the ImageEmbedder when all works are done.
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
|
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/deps/file_path.h"
|
#include "mediapipe/framework/deps/file_path.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
@ -42,7 +41,9 @@ namespace image_embedder {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::containers::Rect;
|
||||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||||
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
|
||||||
|
@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
Image crop, DecodeImageFromFile(
|
Image crop, DecodeImageFromFile(
|
||||||
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||||
// Bounding box in "burger.jpg" corresponding to "burger_crop.jpg".
|
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
|
||||||
NormalizedRect roi;
|
Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
|
||||||
roi.set_x_center(200.0 / 480);
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
roi.set_y_center(0.5);
|
|
||||||
roi.set_width(400.0 / 480);
|
|
||||||
roi.set_height(1.0f);
|
|
||||||
|
|
||||||
// Extract both embeddings.
|
// Extract both embeddings.
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
image_embedder->Embed(image, roi));
|
const EmbeddingResult& image_result,
|
||||||
|
image_embedder->Embed(image, image_processing_options));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
||||||
image_embedder->Embed(crop));
|
image_embedder->Embed(crop));
|
||||||
|
|
||||||
|
@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
||||||
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||||
|
auto options = std::make_unique<ImageEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||||
|
ImageEmbedder::Create(std::move(options)));
|
||||||
|
// Load images: one is a rotated version of the other.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Image image,
|
||||||
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
|
||||||
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
|
"burger_rotated.jpg")));
|
||||||
|
ImageProcessingOptions image_processing_options;
|
||||||
|
image_processing_options.rotation_degrees = -90;
|
||||||
|
|
||||||
|
// Extract both embeddings.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
|
||||||
|
image_embedder->Embed(image));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
const EmbeddingResult& rotated_result,
|
||||||
|
image_embedder->Embed(rotated, image_processing_options));
|
||||||
|
|
||||||
|
// Check results.
|
||||||
|
CheckMobileNetV3Result(image_result, false);
|
||||||
|
CheckMobileNetV3Result(rotated_result, false);
|
||||||
|
// CheckCosineSimilarity.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
double similarity,
|
||||||
|
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
|
||||||
|
rotated_result.embeddings(0).entries(0)));
|
||||||
|
double expected_similarity = 0.572265;
|
||||||
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
|
auto options = std::make_unique<ImageEmbedderOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
|
||||||
|
ImageEmbedder::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Image crop, DecodeImageFromFile(
|
||||||
|
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
|
||||||
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
|
||||||
|
"burger_rotated.jpg")));
|
||||||
|
// Region-of-interest corresponding to burger_crop.jpg.
|
||||||
|
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
|
||||||
|
ImageProcessingOptions image_processing_options{roi,
|
||||||
|
/*rotation_degrees=*/-90};
|
||||||
|
|
||||||
|
// Extract both embeddings.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
|
||||||
|
image_embedder->Embed(crop));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
const EmbeddingResult& rotated_result,
|
||||||
|
image_embedder->Embed(rotated, image_processing_options));
|
||||||
|
|
||||||
|
// Check results.
|
||||||
|
CheckMobileNetV3Result(crop_result, false);
|
||||||
|
CheckMobileNetV3Result(rotated_result, false);
|
||||||
|
// CheckCosineSimilarity.
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
double similarity,
|
||||||
|
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
|
||||||
|
rotated_result.embeddings(0).entries(0)));
|
||||||
|
double expected_similarity = 0.62838;
|
||||||
|
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
|
||||||
|
}
|
||||||
|
|
||||||
class VideoModeTest : public tflite_shims::testing::Test {};
|
class VideoModeTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
|
||||||
|
|
|
@ -24,10 +24,12 @@ cc_library(
|
||||||
":image_segmenter_graph",
|
":image_segmenter_graph",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
|
"//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||||
|
"//mediapipe/tasks/cc/vision/core:image_processing_options",
|
||||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
||||||
|
@ -48,6 +50,7 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||||
|
|
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/utils.h"
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||||
|
|
||||||
|
@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||||
constexpr char kImageInStreamName[] = "image_in";
|
constexpr char kImageInStreamName[] = "image_in";
|
||||||
constexpr char kImageOutStreamName[] = "image_out";
|
constexpr char kImageOutStreamName[] = "image_out";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
constexpr char kNormRectStreamName[] = "norm_rect_in";
|
||||||
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
constexpr char kSubgraphTypeName[] =
|
constexpr char kSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.vision.ImageSegmenterGraph";
|
"mediapipe.tasks.vision.ImageSegmenterGraph";
|
||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
|
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
|
||||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||||
|
graph.In(kNormRectTag).SetName(kNormRectStreamName);
|
||||||
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
|
||||||
graph.Out(kGroupedSegmentationTag);
|
graph.Out(kGroupedSegmentationTag);
|
||||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||||
graph.Out(kImageTag);
|
graph.Out(kImageTag);
|
||||||
if (enable_flow_limiting) {
|
if (enable_flow_limiting) {
|
||||||
return tasks::core::AddFlowLimiterCalculator(
|
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
|
||||||
graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag);
|
{kImageTag, kNormRectTag},
|
||||||
|
kGroupedSegmentationTag);
|
||||||
}
|
}
|
||||||
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||||
|
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,47 +146,68 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
|
||||||
mediapipe::Image image) {
|
mediapipe::Image image,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
absl::StrCat("GPU input images are currently not supported."),
|
absl::StrCat("GPU input images are currently not supported."),
|
||||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||||
}
|
}
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
NormalizedRect norm_rect,
|
||||||
|
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_packets,
|
auto output_packets,
|
||||||
ProcessImageData({{kImageInStreamName,
|
ProcessImageData(
|
||||||
mediapipe::MakePacket<Image>(std::move(image))}}));
|
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
|
||||||
|
{kNormRectStreamName,
|
||||||
|
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||||
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
|
absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
|
||||||
mediapipe::Image image, int64 timestamp_ms) {
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
absl::StrCat("GPU input images are currently not supported."),
|
absl::StrCat("GPU input images are currently not supported."),
|
||||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||||
}
|
}
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
NormalizedRect norm_rect,
|
||||||
|
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_packets,
|
auto output_packets,
|
||||||
ProcessVideoData(
|
ProcessVideoData(
|
||||||
{{kImageInStreamName,
|
{{kImageInStreamName,
|
||||||
MakePacket<Image>(std::move(image))
|
MakePacket<Image>(std::move(image))
|
||||||
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
|
||||||
|
{kNormRectStreamName,
|
||||||
|
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||||
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) {
|
absl::Status ImageSegmenter::SegmentAsync(
|
||||||
|
Image image, int64 timestamp_ms,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
absl::StrCat("GPU input images are currently not supported."),
|
absl::StrCat("GPU input images are currently not supported."),
|
||||||
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
|
||||||
}
|
}
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
NormalizedRect norm_rect,
|
||||||
|
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
|
||||||
return SendLiveStreamData(
|
return SendLiveStreamData(
|
||||||
{{kImageInStreamName,
|
{{kImageInStreamName,
|
||||||
MakePacket<Image>(std::move(image))
|
MakePacket<Image>(std::move(image))
|
||||||
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
|
||||||
|
{kNormRectStreamName,
|
||||||
|
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||||
#include "tensorflow/lite/kernels/register.h"
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// running mode.
|
// running mode.
|
||||||
//
|
//
|
||||||
// The image can be of any size with format RGB or RGBA.
|
// The image can be of any size with format RGB or RGBA.
|
||||||
// TODO: Describes how the input image will be preprocessed
|
//
|
||||||
// after the yuv support is implemented.
|
// The optional 'image_processing_options' parameter can be used to specify
|
||||||
|
// the rotation to apply to the image before performing segmentation, by
|
||||||
|
// setting its 'rotation_degrees' field. Note that specifying a
|
||||||
|
// region-of-interest using the 'region_of_interest' field is NOT supported
|
||||||
|
// and will result in an invalid argument error being returned.
|
||||||
//
|
//
|
||||||
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||||
// per-category segmented image mask.
|
// per-category segmented image mask.
|
||||||
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
||||||
// contains only one confidence image mask.
|
// contains only one confidence image mask.
|
||||||
absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image);
|
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
|
||||||
|
mediapipe::Image image,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
|
std::nullopt);
|
||||||
|
|
||||||
// Performs image segmentation on the provided video frame.
|
// Performs image segmentation on the provided video frame.
|
||||||
// Only use this method when the ImageSegmenter is created with the video
|
// Only use this method when the ImageSegmenter is created with the video
|
||||||
|
@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
// must be monotonically increasing.
|
// must be monotonically increasing.
|
||||||
//
|
//
|
||||||
|
// The optional 'image_processing_options' parameter can be used to specify
|
||||||
|
// the rotation to apply to the image before performing segmentation, by
|
||||||
|
// setting its 'rotation_degrees' field. Note that specifying a
|
||||||
|
// region-of-interest using the 'region_of_interest' field is NOT supported
|
||||||
|
// and will result in an invalid argument error being returned.
|
||||||
|
//
|
||||||
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||||
// per-category segmented image mask.
|
// per-category segmented image mask.
|
||||||
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
// If the output_type is CONFIDENCE_MASK, the returned vector of images
|
||||||
// contains only one confidence image mask.
|
// contains only one confidence image mask.
|
||||||
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
|
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
|
||||||
mediapipe::Image image, int64 timestamp_ms);
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
|
std::nullopt);
|
||||||
|
|
||||||
// Sends live image data to perform image segmentation, and the results will
|
// Sends live image data to perform image segmentation, and the results will
|
||||||
// be available via the "result_callback" provided in the
|
// be available via the "result_callback" provided in the
|
||||||
|
@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// sent to the image segmenter. The input timestamps must be monotonically
|
// sent to the image segmenter. The input timestamps must be monotonically
|
||||||
// increasing.
|
// increasing.
|
||||||
//
|
//
|
||||||
|
// The optional 'image_processing_options' parameter can be used to specify
|
||||||
|
// the rotation to apply to the image before performing segmentation, by
|
||||||
|
// setting its 'rotation_degrees' field. Note that specifying a
|
||||||
|
// region-of-interest using the 'region_of_interest' field is NOT supported
|
||||||
|
// and will result in an invalid argument error being returned.
|
||||||
|
//
|
||||||
// The "result_callback" prvoides
|
// The "result_callback" prvoides
|
||||||
// - A vector of segmented image masks.
|
// - A vector of segmented image masks.
|
||||||
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
// If the output_type is CATEGORY_MASK, the returned vector of images is
|
||||||
|
@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// no longer be valid when the callback returns. To access the image data
|
// no longer be valid when the callback returns. To access the image data
|
||||||
// outside of the callback, callers need to make a copy of the image.
|
// outside of the callback, callers need to make a copy of the image.
|
||||||
// - The input timestamp in milliseconds.
|
// - The input timestamp in milliseconds.
|
||||||
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms);
|
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
|
||||||
|
std::optional<core::ImageProcessingOptions>
|
||||||
|
image_processing_options = std::nullopt);
|
||||||
|
|
||||||
// Shuts down the ImageSegmenter when all works are done.
|
// Shuts down the ImageSegmenter when all works are done.
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
||||||
|
@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||||
constexpr char kSegmentationTag[] = "SEGMENTATION";
|
constexpr char kSegmentationTag[] = "SEGMENTATION";
|
||||||
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
|
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
|
||||||
|
|
||||||
|
@ -159,6 +161,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
|
||||||
// Inputs:
|
// Inputs:
|
||||||
// IMAGE - Image
|
// IMAGE - Image
|
||||||
// Image to perform segmentation on.
|
// Image to perform segmentation on.
|
||||||
|
// NORM_RECT - NormalizedRect @Optional
|
||||||
|
// Describes image rotation and region of image to perform detection
|
||||||
|
// on.
|
||||||
|
// @Optional: rect covering the whole image is used if not specified.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// SEGMENTATION - mediapipe::Image @Multiple
|
// SEGMENTATION - mediapipe::Image @Multiple
|
||||||
|
@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||||
CreateModelResources<ImageSegmenterOptions>(sc));
|
CreateModelResources<ImageSegmenterOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
ASSIGN_OR_RETURN(auto output_streams,
|
ASSIGN_OR_RETURN(
|
||||||
|
auto output_streams,
|
||||||
BuildSegmentationTask(
|
BuildSegmentationTask(
|
||||||
sc->Options<ImageSegmenterOptions>(), *model_resources,
|
sc->Options<ImageSegmenterOptions>(), *model_resources,
|
||||||
graph[Input<Image>(kImageTag)], graph));
|
graph[Input<Image>(kImageTag)],
|
||||||
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||||
|
|
||||||
auto& merge_images_to_vector =
|
auto& merge_images_to_vector =
|
||||||
graph.AddNode("MergeImagesToVectorCalculator");
|
graph.AddNode("MergeImagesToVectorCalculator");
|
||||||
|
@ -228,7 +236,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
|
||||||
const ImageSegmenterOptions& task_options,
|
const ImageSegmenterOptions& task_options,
|
||||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||||
Graph& graph) {
|
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||||
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
||||||
|
|
||||||
// Adds preprocessing calculators and connects them to the graph input image
|
// Adds preprocessing calculators and connects them to the graph input image
|
||||||
|
@ -240,6 +248,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
|
||||||
&preprocessing
|
&preprocessing
|
||||||
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
|
||||||
image_in >> preprocessing.In(kImageTag);
|
image_in >> preprocessing.In(kImageTag);
|
||||||
|
norm_rect_in >> preprocessing.In(kNormRectTag);
|
||||||
|
|
||||||
// Adds inference subgraph and connects its input stream to the output
|
// Adds inference subgraph and connects its input stream to the output
|
||||||
// tensors produced by the ImageToTensorCalculator.
|
// tensors produced by the ImageToTensorCalculator.
|
||||||
|
|
|
@ -29,8 +29,10 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
@ -44,6 +46,8 @@ namespace {
|
||||||
|
|
||||||
using ::mediapipe::Image;
|
using ::mediapipe::Image;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::containers::Rect;
|
||||||
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
|
||||||
|
@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
ImageSegmenter::Create(std::move(options)));
|
ImageSegmenter::Create(std::move(options)));
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
||||||
EXPECT_EQ(confidence_masks.size(), 21);
|
EXPECT_EQ(confidence_masks.size(), 21);
|
||||||
|
|
||||||
|
@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
|
||||||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Image image, DecodeImageFromFile(
|
||||||
|
JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
|
||||||
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
|
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||||
|
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
|
ImageSegmenter::Create(std::move(options)));
|
||||||
|
ImageProcessingOptions image_processing_options;
|
||||||
|
image_processing_options.rotation_degrees = -90;
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
||||||
|
EXPECT_EQ(confidence_masks.size(), 21);
|
||||||
|
|
||||||
|
cv::Mat expected_mask =
|
||||||
|
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
|
||||||
|
cv::IMREAD_GRAYSCALE);
|
||||||
|
cv::Mat expected_mask_float;
|
||||||
|
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
|
||||||
|
|
||||||
|
// Cat category index 8.
|
||||||
|
cv::Mat cat_mask = mediapipe::formats::MatView(
|
||||||
|
confidence_masks[8].GetImageFrameSharedPtr().get());
|
||||||
|
EXPECT_THAT(cat_mask,
|
||||||
|
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
Image image,
|
||||||
|
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
|
||||||
|
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||||
|
options->base_options.model_asset_path =
|
||||||
|
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||||
|
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||||
|
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||||
|
ImageSegmenter::Create(std::move(options)));
|
||||||
|
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
|
||||||
|
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
|
||||||
|
|
||||||
|
auto results = segmenter->Segment(image, image_processing_options);
|
||||||
|
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
|
||||||
|
EXPECT_THAT(results.status().message(),
|
||||||
|
HasSubstr("This task doesn't support region-of-interest"));
|
||||||
|
EXPECT_THAT(
|
||||||
|
results.status().GetPayload(kMediaPipeTasksPayload),
|
||||||
|
Optional(absl::Cord(absl::StrCat(
|
||||||
|
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
|
TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
|
||||||
Image image =
|
Image image =
|
||||||
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
|
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
|
||||||
|
|
|
@ -31,6 +31,7 @@ android_binary(
|
||||||
multidex = "native",
|
multidex = "native",
|
||||||
resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
|
resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
|
|
@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector;
|
||||||
|
|
||||||
import android.content.Intent;
|
import android.content.Intent;
|
||||||
import android.graphics.Bitmap;
|
import android.graphics.Bitmap;
|
||||||
import android.graphics.Matrix;
|
|
||||||
import android.media.MediaMetadataRetriever;
|
import android.media.MediaMetadataRetriever;
|
||||||
import android.os.Bundle;
|
import android.os.Bundle;
|
||||||
import android.provider.MediaStore;
|
import android.provider.MediaStore;
|
||||||
|
@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher;
|
||||||
import androidx.activity.result.contract.ActivityResultContracts;
|
import androidx.activity.result.contract.ActivityResultContracts;
|
||||||
import androidx.exifinterface.media.ExifInterface;
|
import androidx.exifinterface.media.ExifInterface;
|
||||||
// ContentResolver dependency
|
// ContentResolver dependency
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
import com.google.mediapipe.framework.image.MPImage;
|
import com.google.mediapipe.framework.image.MPImage;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
|
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
|
||||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
|
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
|
||||||
|
@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity {
|
||||||
if (resultIntent != null) {
|
if (resultIntent != null) {
|
||||||
if (result.getResultCode() == RESULT_OK) {
|
if (result.getResultCode() == RESULT_OK) {
|
||||||
Bitmap bitmap = null;
|
Bitmap bitmap = null;
|
||||||
|
int rotation = 0;
|
||||||
try {
|
try {
|
||||||
bitmap =
|
bitmap =
|
||||||
downscaleBitmap(
|
downscaleBitmap(
|
||||||
|
@ -93,13 +95,16 @@ public class MainActivity extends AppCompatActivity {
|
||||||
try {
|
try {
|
||||||
InputStream imageData =
|
InputStream imageData =
|
||||||
this.getContentResolver().openInputStream(resultIntent.getData());
|
this.getContentResolver().openInputStream(resultIntent.getData());
|
||||||
bitmap = rotateBitmap(bitmap, imageData);
|
rotation = getImageRotation(imageData);
|
||||||
} catch (IOException e) {
|
} catch (IOException | MediaPipeException e) {
|
||||||
Log.e(TAG, "Bitmap rotation error:" + e);
|
Log.e(TAG, "Bitmap rotation error:" + e);
|
||||||
}
|
}
|
||||||
if (bitmap != null) {
|
if (bitmap != null) {
|
||||||
MPImage image = new BitmapImageBuilder(bitmap).build();
|
MPImage image = new BitmapImageBuilder(bitmap).build();
|
||||||
ObjectDetectionResult detectionResult = objectDetector.detect(image);
|
ObjectDetectionResult detectionResult =
|
||||||
|
objectDetector.detect(
|
||||||
|
image,
|
||||||
|
ImageProcessingOptions.builder().setRotationDegrees(rotation).build());
|
||||||
imageView.setData(image, detectionResult);
|
imageView.setData(image, detectionResult);
|
||||||
runOnUiThread(() -> imageView.update());
|
runOnUiThread(() -> imageView.update());
|
||||||
}
|
}
|
||||||
|
@ -210,28 +215,25 @@ public class MainActivity extends AppCompatActivity {
|
||||||
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
|
private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException {
|
||||||
int orientation =
|
int orientation =
|
||||||
new ExifInterface(imageData)
|
new ExifInterface(imageData)
|
||||||
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
|
||||||
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
|
|
||||||
return inputBitmap;
|
|
||||||
}
|
|
||||||
Matrix matrix = new Matrix();
|
|
||||||
switch (orientation) {
|
switch (orientation) {
|
||||||
|
case ExifInterface.ORIENTATION_NORMAL:
|
||||||
|
return 0;
|
||||||
case ExifInterface.ORIENTATION_ROTATE_90:
|
case ExifInterface.ORIENTATION_ROTATE_90:
|
||||||
matrix.postRotate(90);
|
return 90;
|
||||||
break;
|
|
||||||
case ExifInterface.ORIENTATION_ROTATE_180:
|
case ExifInterface.ORIENTATION_ROTATE_180:
|
||||||
matrix.postRotate(180);
|
return 180;
|
||||||
break;
|
|
||||||
case ExifInterface.ORIENTATION_ROTATE_270:
|
case ExifInterface.ORIENTATION_ROTATE_270:
|
||||||
matrix.postRotate(270);
|
return 270;
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
matrix.postRotate(0);
|
// TODO: use getRotationDegrees() and isFlipped() instead of switch once flip
|
||||||
}
|
// is supported.
|
||||||
return Bitmap.createBitmap(
|
throw new MediaPipeException(
|
||||||
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
|
MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(),
|
||||||
|
"Flipped images are not supported yet.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,11 +15,11 @@
|
||||||
package com.google.mediapipe.tasks.text.textclassifier;
|
package com.google.mediapipe.tasks.text.textclassifier;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
|
|
||||||
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.Category;
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
||||||
import com.google.mediapipe.tasks.components.containers.Classifications;
|
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.core.TaskResult;
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
|
@ -22,7 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
|
||||||
import com.google.mediapipe.framework.Packet;
|
import com.google.mediapipe.framework.Packet;
|
||||||
import com.google.mediapipe.framework.PacketGetter;
|
import com.google.mediapipe.framework.PacketGetter;
|
||||||
import com.google.mediapipe.framework.ProtoUtil;
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
|
|
@ -28,6 +28,7 @@ android_library(
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"//third_party:autovalue",
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -24,7 +24,6 @@ import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
/** The base class of MediaPipe vision tasks. */
|
/** The base class of MediaPipe vision tasks. */
|
||||||
public class BaseVisionTaskApi implements AutoCloseable {
|
public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
|
@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
private final TaskRunner runner;
|
private final TaskRunner runner;
|
||||||
private final RunningMode runningMode;
|
private final RunningMode runningMode;
|
||||||
private final String imageStreamName;
|
private final String imageStreamName;
|
||||||
private final Optional<String> normRectStreamName;
|
private final String normRectStreamName;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
System.loadLibrary("mediapipe_tasks_vision_jni");
|
System.loadLibrary("mediapipe_tasks_vision_jni");
|
||||||
|
@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
|
* Constructor to initialize a {@link BaseVisionTaskApi}.
|
||||||
*
|
*
|
||||||
* @param runner a {@link TaskRunner}.
|
* @param runner a {@link TaskRunner}.
|
||||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||||
* @param imageStreamName the name of the input image stream.
|
* @param imageStreamName the name of the input image stream.
|
||||||
*/
|
* @param normRectStreamName the name of the input normalized rect image stream used to provide
|
||||||
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
|
* (mandatory) rotation and (optional) region-of-interest.
|
||||||
this.runner = runner;
|
|
||||||
this.runningMode = runningMode;
|
|
||||||
this.imageStreamName = imageStreamName;
|
|
||||||
this.normRectStreamName = Optional.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
|
|
||||||
* input.
|
|
||||||
*
|
|
||||||
* @param runner a {@link TaskRunner}.
|
|
||||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
|
||||||
* @param imageStreamName the name of the input image stream.
|
|
||||||
* @param normRectStreamName the name of the input normalized rect image stream.
|
|
||||||
*/
|
*/
|
||||||
public BaseVisionTaskApi(
|
public BaseVisionTaskApi(
|
||||||
TaskRunner runner,
|
TaskRunner runner,
|
||||||
|
@ -70,7 +55,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
this.runner = runner;
|
this.runner = runner;
|
||||||
this.runningMode = runningMode;
|
this.runningMode = runningMode;
|
||||||
this.imageStreamName = imageStreamName;
|
this.imageStreamName = imageStreamName;
|
||||||
this.normRectStreamName = Optional.of(normRectStreamName);
|
this.normRectStreamName = normRectStreamName;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -78,53 +63,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
* failure status or a successful result is returned.
|
* failure status or a successful result is returned.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
* input.
|
* input image before running inference.
|
||||||
|
* @throws MediaPipeException if the task is not in the image mode.
|
||||||
*/
|
*/
|
||||||
protected TaskResult processImageData(MPImage image) {
|
protected TaskResult processImageData(
|
||||||
|
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||||
if (runningMode != RunningMode.IMAGE) {
|
if (runningMode != RunningMode.IMAGE) {
|
||||||
throw new MediaPipeException(
|
throw new MediaPipeException(
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
"Task is not initialized with the image mode. Current running mode:"
|
"Task is not initialized with the image mode. Current running mode:"
|
||||||
+ runningMode.name());
|
+ runningMode.name());
|
||||||
}
|
}
|
||||||
if (normRectStreamName.isPresent()) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task expects a normalized rect as input.");
|
|
||||||
}
|
|
||||||
Map<String, Packet> inputPackets = new HashMap<>();
|
|
||||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
|
||||||
return runner.process(inputPackets);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A synchronous method to process single image inputs. The call blocks the current thread until a
|
|
||||||
* failure status or a successful result is returned.
|
|
||||||
*
|
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
|
||||||
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
|
|
||||||
* are expected to be specified as normalized values in [0,1].
|
|
||||||
* @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
|
|
||||||
* rect.
|
|
||||||
*/
|
|
||||||
protected TaskResult processImageData(MPImage image, RectF roi) {
|
|
||||||
if (runningMode != RunningMode.IMAGE) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task is not initialized with the image mode. Current running mode:"
|
|
||||||
+ runningMode.name());
|
|
||||||
}
|
|
||||||
if (!normRectStreamName.isPresent()) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task doesn't expect a normalized rect as input.");
|
|
||||||
}
|
|
||||||
Map<String, Packet> inputPackets = new HashMap<>();
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||||
inputPackets.put(
|
inputPackets.put(
|
||||||
normRectStreamName.get(),
|
normRectStreamName,
|
||||||
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
|
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
|
||||||
return runner.process(inputPackets);
|
return runner.process(inputPackets);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -133,55 +88,24 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
* until a failure status or a successful result is returned.
|
* until a failure status or a successful result is returned.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference.
|
||||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||||
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
|
* @throws MediaPipeException if the task is not in the video mode.
|
||||||
* input.
|
|
||||||
*/
|
*/
|
||||||
protected TaskResult processVideoData(MPImage image, long timestampMs) {
|
protected TaskResult processVideoData(
|
||||||
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
if (runningMode != RunningMode.VIDEO) {
|
if (runningMode != RunningMode.VIDEO) {
|
||||||
throw new MediaPipeException(
|
throw new MediaPipeException(
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
"Task is not initialized with the video mode. Current running mode:"
|
"Task is not initialized with the video mode. Current running mode:"
|
||||||
+ runningMode.name());
|
+ runningMode.name());
|
||||||
}
|
}
|
||||||
if (normRectStreamName.isPresent()) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task expects a normalized rect as input.");
|
|
||||||
}
|
|
||||||
Map<String, Packet> inputPackets = new HashMap<>();
|
|
||||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
|
||||||
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A synchronous method to process continuous video frames. The call blocks the current thread
|
|
||||||
* until a failure status or a successful result is returned.
|
|
||||||
*
|
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
|
||||||
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
|
|
||||||
* are expected to be specified as normalized values in [0,1].
|
|
||||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
|
||||||
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
|
|
||||||
* rect.
|
|
||||||
*/
|
|
||||||
protected TaskResult processVideoData(MPImage image, RectF roi, long timestampMs) {
|
|
||||||
if (runningMode != RunningMode.VIDEO) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task is not initialized with the video mode. Current running mode:"
|
|
||||||
+ runningMode.name());
|
|
||||||
}
|
|
||||||
if (!normRectStreamName.isPresent()) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task doesn't expect a normalized rect as input.");
|
|
||||||
}
|
|
||||||
Map<String, Packet> inputPackets = new HashMap<>();
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||||
inputPackets.put(
|
inputPackets.put(
|
||||||
normRectStreamName.get(),
|
normRectStreamName,
|
||||||
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
|
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
|
||||||
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,55 +114,24 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
* available in the user-defined result listener.
|
* available in the user-defined result listener.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference.
|
||||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||||
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
|
* @throws MediaPipeException if the task is not in the stream mode.
|
||||||
* input.
|
|
||||||
*/
|
*/
|
||||||
protected void sendLiveStreamData(MPImage image, long timestampMs) {
|
protected void sendLiveStreamData(
|
||||||
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
if (runningMode != RunningMode.LIVE_STREAM) {
|
if (runningMode != RunningMode.LIVE_STREAM) {
|
||||||
throw new MediaPipeException(
|
throw new MediaPipeException(
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
"Task is not initialized with the live stream mode. Current running mode:"
|
"Task is not initialized with the live stream mode. Current running mode:"
|
||||||
+ runningMode.name());
|
+ runningMode.name());
|
||||||
}
|
}
|
||||||
if (normRectStreamName.isPresent()) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task expects a normalized rect as input.");
|
|
||||||
}
|
|
||||||
Map<String, Packet> inputPackets = new HashMap<>();
|
|
||||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
|
||||||
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
|
|
||||||
* available in the user-defined result listener.
|
|
||||||
*
|
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
|
||||||
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
|
|
||||||
* are expected to be specified as normalized values in [0,1].
|
|
||||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
|
||||||
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
|
|
||||||
* rect.
|
|
||||||
*/
|
|
||||||
protected void sendLiveStreamData(MPImage image, RectF roi, long timestampMs) {
|
|
||||||
if (runningMode != RunningMode.LIVE_STREAM) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task is not initialized with the live stream mode. Current running mode:"
|
|
||||||
+ runningMode.name());
|
|
||||||
}
|
|
||||||
if (!normRectStreamName.isPresent()) {
|
|
||||||
throw new MediaPipeException(
|
|
||||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
|
||||||
"Task doesn't expect a normalized rect as input.");
|
|
||||||
}
|
|
||||||
Map<String, Packet> inputPackets = new HashMap<>();
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||||
inputPackets.put(
|
inputPackets.put(
|
||||||
normRectStreamName.get(),
|
normRectStreamName,
|
||||||
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
|
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
|
||||||
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
||||||
runner.close();
|
runner.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */
|
/**
|
||||||
private static NormalizedRect convertToNormalizedRect(RectF rect) {
|
* Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf
|
||||||
|
* message.
|
||||||
|
*/
|
||||||
|
private static NormalizedRect convertToNormalizedRect(
|
||||||
|
ImageProcessingOptions imageProcessingOptions) {
|
||||||
|
RectF regionOfInterest =
|
||||||
|
imageProcessingOptions.regionOfInterest().isPresent()
|
||||||
|
? imageProcessingOptions.regionOfInterest().get()
|
||||||
|
: new RectF(0, 0, 1, 1);
|
||||||
return NormalizedRect.newBuilder()
|
return NormalizedRect.newBuilder()
|
||||||
.setXCenter(rect.centerX())
|
.setXCenter(regionOfInterest.centerX())
|
||||||
.setYCenter(rect.centerY())
|
.setYCenter(regionOfInterest.centerY())
|
||||||
.setWidth(rect.width())
|
.setWidth(regionOfInterest.width())
|
||||||
.setHeight(rect.height())
|
.setHeight(regionOfInterest.height())
|
||||||
|
// Convert to radians anti-clockwise.
|
||||||
|
.setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.core;
|
||||||
|
|
||||||
|
import android.graphics.RectF;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
// TODO: add support for image flipping.
|
||||||
|
/** Options for image processing. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class ImageProcessingOptions {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builder for {@link ImageProcessingOptions}.
|
||||||
|
*
|
||||||
|
* <p>If both region-of-interest and rotation are specified, the crop around the
|
||||||
|
* region-of-interest is extracted first, then the specified rotation is applied to the crop.
|
||||||
|
*/
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/**
|
||||||
|
* Sets the optional region-of-interest to crop from the image. If not specified, the full image
|
||||||
|
* is used.
|
||||||
|
*
|
||||||
|
* <p>Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must be
|
||||||
|
* < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link #build()}
|
||||||
|
* is called.
|
||||||
|
*/
|
||||||
|
public abstract Builder setRegionOfInterest(RectF value);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the rotation to apply to the image (or cropped region-of-interest), in degrees
|
||||||
|
* clockwise. Defaults to 0.
|
||||||
|
*
|
||||||
|
* <p>The rotation must be a multiple (positive or negative) of 90°, otherwise an
|
||||||
|
* IllegalArgumentException will be thrown when {@link #build()} is called.
|
||||||
|
*/
|
||||||
|
public abstract Builder setRotationDegrees(int value);
|
||||||
|
|
||||||
|
abstract ImageProcessingOptions autoBuild();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates and builds the {@link ImageProcessingOptions} instance.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if some of the provided values do not meet their
|
||||||
|
* requirements.
|
||||||
|
*/
|
||||||
|
public final ImageProcessingOptions build() {
|
||||||
|
ImageProcessingOptions options = autoBuild();
|
||||||
|
if (options.regionOfInterest().isPresent()) {
|
||||||
|
RectF roi = options.regionOfInterest().get();
|
||||||
|
if (roi.left >= roi.right || roi.top >= roi.bottom) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Expected left < right and top < bottom, found: %s.", roi.toShortString()));
|
||||||
|
}
|
||||||
|
if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (options.rotationDegrees() % 90 != 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Expected rotation to be a multiple of 90°, found: %d.",
|
||||||
|
options.rotationDegrees()));
|
||||||
|
}
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract Optional<RectF> regionOfInterest();
|
||||||
|
|
||||||
|
public abstract int rotationDegrees();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0);
|
||||||
|
}
|
||||||
|
}
|
|
@ -15,7 +15,6 @@
|
||||||
package com.google.mediapipe.tasks.vision.gesturerecognizer;
|
package com.google.mediapipe.tasks.vision.gesturerecognizer;
|
||||||
|
|
||||||
import android.content.Context;
|
import android.content.Context;
|
||||||
import android.graphics.RectF;
|
|
||||||
import android.os.ParcelFileDescriptor;
|
import android.os.ParcelFileDescriptor;
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
|
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
|
||||||
|
@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto;
|
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto;
|
||||||
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto;
|
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto;
|
||||||
|
@ -212,6 +212,25 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
|
||||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs gesture recognition on the provided single image with default image processing
|
||||||
|
* options, i.e. without any rotation applied. Only use this method when the {@link
|
||||||
|
* GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc
|
||||||
|
* for input image format.
|
||||||
|
*
|
||||||
|
* <p>{@link GestureRecognizer} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public GestureRecognitionResult recognize(MPImage image) {
|
||||||
|
return recognize(image, ImageProcessingOptions.builder().build());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs gesture recognition on the provided single image. Only use this method when the {@link
|
* Performs gesture recognition on the provided single image. Only use this method when the {@link
|
||||||
* GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc
|
* GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc
|
||||||
|
@ -223,12 +242,41 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||||
|
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||||
|
* this method throwing an IllegalArgumentException.
|
||||||
|
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||||
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public GestureRecognitionResult recognize(MPImage inputImage) {
|
public GestureRecognitionResult recognize(
|
||||||
// TODO: add proper support for rotations.
|
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||||
return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF());
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
|
return (GestureRecognitionResult) processImageData(image, imageProcessingOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs gesture recognition on the provided video frame with default image processing options,
|
||||||
|
* i.e. without any rotation applied. Only use this method when the {@link GestureRecognizer} is
|
||||||
|
* created with {@link RunningMode.VIDEO}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
|
* must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link GestureRecognizer} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public GestureRecognitionResult recognizeForVideo(MPImage image, long timestampMs) {
|
||||||
|
return recognizeForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -244,14 +292,43 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||||
|
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||||
|
* this method throwing an IllegalArgumentException.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||||
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public GestureRecognitionResult recognizeForVideo(MPImage inputImage, long inputTimestampMs) {
|
public GestureRecognitionResult recognizeForVideo(
|
||||||
// TODO: add proper support for rotations.
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
return (GestureRecognitionResult)
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
|
return (GestureRecognitionResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends live image data to perform gesture recognition with default image processing options,
|
||||||
|
* i.e. without any rotation applied, and the results will be available via the {@link
|
||||||
|
* ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method when the
|
||||||
|
* {@link GestureRecognition} is created with {@link RunningMode.LIVE_STREAM}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||||
|
* sent to the gesture recognizer. The input timestamps must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link GestureRecognizer} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public void recognizeAsync(MPImage image, long timestampMs) {
|
||||||
|
recognizeAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -268,13 +345,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||||
|
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||||
|
* this method throwing an IllegalArgumentException.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||||
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public void recognizeAsync(MPImage inputImage, long inputTimestampMs) {
|
public void recognizeAsync(
|
||||||
// TODO: add proper support for rotations.
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
|
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Options for setting up an {@link GestureRecognizer}. */
|
/** Options for setting up an {@link GestureRecognizer}. */
|
||||||
|
@ -445,8 +529,14 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Creates a RectF covering the full image. */
|
/**
|
||||||
private static RectF buildFullImageRectF() {
|
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
|
||||||
return new RectF(0, 0, 1, 1);
|
* region-of-interest.
|
||||||
|
*/
|
||||||
|
private static void validateImageProcessingOptions(
|
||||||
|
ImageProcessingOptions imageProcessingOptions) {
|
||||||
|
if (imageProcessingOptions.regionOfInterest().isPresent()) {
|
||||||
|
throw new IllegalArgumentException("GestureRecognizer doesn't support region-of-interest.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,11 +15,11 @@
|
||||||
package com.google.mediapipe.tasks.vision.imageclassifier;
|
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||||
|
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
|
|
||||||
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
|
||||||
import com.google.mediapipe.tasks.components.containers.Category;
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
|
||||||
import com.google.mediapipe.tasks.components.containers.Classifications;
|
import com.google.mediapipe.tasks.components.containers.Classifications;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.core.TaskResult;
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
package com.google.mediapipe.tasks.vision.imageclassifier;
|
package com.google.mediapipe.tasks.vision.imageclassifier;
|
||||||
|
|
||||||
import android.content.Context;
|
import android.content.Context;
|
||||||
import android.graphics.RectF;
|
|
||||||
import android.os.ParcelFileDescriptor;
|
import android.os.ParcelFileDescriptor;
|
||||||
import com.google.auto.value.AutoValue;
|
import com.google.auto.value.AutoValue;
|
||||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
|
@ -26,7 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
|
||||||
import com.google.mediapipe.framework.ProtoUtil;
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
import com.google.mediapipe.framework.image.MPImage;
|
import com.google.mediapipe.framework.image.MPImage;
|
||||||
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
|
@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
|
import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -215,6 +215,24 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs classification on the provided single image with default image processing options,
|
||||||
|
* i.e. using the whole image as region-of-interest and without any rotation applied. Only use
|
||||||
|
* this method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
|
||||||
|
*
|
||||||
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public ImageClassificationResult classify(MPImage image) {
|
||||||
|
return classify(image, ImageProcessingOptions.builder().build());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs classification on the provided single image. Only use this method when the {@link
|
* Performs classification on the provided single image. Only use this method when the {@link
|
||||||
* ImageClassifier} is created with {@link RunningMode.IMAGE}.
|
* ImageClassifier} is created with {@link RunningMode.IMAGE}.
|
||||||
|
@ -225,16 +243,23 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classify(MPImage inputImage) {
|
public ImageClassificationResult classify(
|
||||||
return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF());
|
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||||
|
return (ImageClassificationResult) processImageData(image, imageProcessingOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs classification on the provided single image and region-of-interest. Only use this
|
* Performs classification on the provided video frame with default image processing options, i.e.
|
||||||
* method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
|
* using the whole image as region-of-interest and without any rotation applied. Only use this
|
||||||
|
* method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
|
* must be monotonically increasing.
|
||||||
*
|
*
|
||||||
* <p>{@link ImageClassifier} supports the following color space types:
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
*
|
*
|
||||||
|
@ -242,13 +267,12 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param roi a {@link RectF} specifying the region of interest on which to perform
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
* classification. Coordinates are expected to be specified as normalized values in [0,1].
|
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classify(MPImage inputImage, RectF roi) {
|
public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) {
|
||||||
return (ImageClassificationResult) processImageData(inputImage, roi);
|
return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -264,21 +288,26 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classifyForVideo(MPImage inputImage, long inputTimestampMs) {
|
public ImageClassificationResult classifyForVideo(
|
||||||
return (ImageClassificationResult)
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
|
return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs classification on the provided video frame with additional region-of-interest. Only
|
* Sends live image data to perform classification with default image processing options, i.e.
|
||||||
* use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
|
* using the whole image as region-of-interest and without any rotation applied, and the results
|
||||||
|
* will be available via the {@link ResultListener} provided in the {@link
|
||||||
|
* ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
|
||||||
|
* {@link RunningMode.LIVE_STREAM}.
|
||||||
*
|
*
|
||||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||||
* must be monotonically increasing.
|
* sent to the object detector. The input timestamps must be monotonically increasing.
|
||||||
*
|
*
|
||||||
* <p>{@link ImageClassifier} supports the following color space types:
|
* <p>{@link ImageClassifier} supports the following color space types:
|
||||||
*
|
*
|
||||||
|
@ -286,15 +315,12 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param roi a {@link RectF} specifying the region of interest on which to perform
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
* classification. Coordinates are expected to be specified as normalized values in [0,1].
|
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ImageClassificationResult classifyForVideo(
|
public void classifyAsync(MPImage image, long timestampMs) {
|
||||||
MPImage inputImage, RectF roi, long inputTimestampMs) {
|
classifyAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -311,37 +337,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public void classifyAsync(MPImage inputImage, long inputTimestampMs) {
|
public void classifyAsync(
|
||||||
sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
}
|
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||||
|
|
||||||
/**
|
|
||||||
* Sends live image data and additional region-of-interest to perform classification, and the
|
|
||||||
* results will be available via the {@link ResultListener} provided in the {@link
|
|
||||||
* ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
|
|
||||||
* {@link RunningMode.LIVE_STREAM}.
|
|
||||||
*
|
|
||||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
|
||||||
* sent to the object detector. The input timestamps must be monotonically increasing.
|
|
||||||
*
|
|
||||||
* <p>{@link ImageClassifier} supports the following color space types:
|
|
||||||
*
|
|
||||||
* <ul>
|
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
|
||||||
* </ul>
|
|
||||||
*
|
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
|
||||||
* @param roi a {@link RectF} specifying the region of interest on which to perform
|
|
||||||
* classification. Coordinates are expected to be specified as normalized values in [0,1].
|
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
|
||||||
* @throws MediaPipeException if there is an internal error.
|
|
||||||
*/
|
|
||||||
public void classifyAsync(MPImage inputImage, RectF roi, long inputTimestampMs) {
|
|
||||||
sendLiveStreamData(inputImage, roi, inputTimestampMs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Options for setting up and {@link ImageClassifier}. */
|
/** Options for setting up and {@link ImageClassifier}. */
|
||||||
|
@ -447,9 +451,4 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Creates a RectF covering the full image. */
|
|
||||||
private static RectF buildFullImageRectF() {
|
|
||||||
return new RectF(0, 0, 1, 1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto;
|
import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto;
|
||||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||||
|
@ -96,8 +97,10 @@ import java.util.Optional;
|
||||||
public final class ObjectDetector extends BaseVisionTaskApi {
|
public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
private static final String TAG = ObjectDetector.class.getSimpleName();
|
private static final String TAG = ObjectDetector.class.getSimpleName();
|
||||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||||
|
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||||
private static final List<String> INPUT_STREAMS =
|
private static final List<String> INPUT_STREAMS =
|
||||||
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
|
Collections.unmodifiableList(
|
||||||
|
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
private static final List<String> OUTPUT_STREAMS =
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
|
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
|
||||||
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
|
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
|
||||||
|
@ -204,7 +207,25 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||||
*/
|
*/
|
||||||
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
|
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
|
||||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
|
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs object detection on the provided single image with default image processing options,
|
||||||
|
* i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is
|
||||||
|
* created with {@link RunningMode.IMAGE}.
|
||||||
|
*
|
||||||
|
* <p>{@link ObjectDetector} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public ObjectDetectionResult detect(MPImage image) {
|
||||||
|
return detect(image, ImageProcessingOptions.builder().build());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -217,11 +238,41 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||||
|
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||||
|
* this method throwing an IllegalArgumentException.
|
||||||
|
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||||
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionResult detect(MPImage inputImage) {
|
public ObjectDetectionResult detect(
|
||||||
return (ObjectDetectionResult) processImageData(inputImage);
|
MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||||
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
|
return (ObjectDetectionResult) processImageData(image, imageProcessingOptions);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs object detection on the provided video frame with default image processing options,
|
||||||
|
* i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is
|
||||||
|
* created with {@link RunningMode.VIDEO}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
|
* must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link ObjectDetector} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) {
|
||||||
|
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -237,12 +288,43 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||||
|
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||||
|
* this method throwing an IllegalArgumentException.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||||
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public ObjectDetectionResult detectForVideo(MPImage inputImage, long inputTimestampMs) {
|
public ObjectDetectionResult detectForVideo(
|
||||||
return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs);
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
|
return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends live image data to perform object detection with default image processing options, i.e.
|
||||||
|
* without any rotation applied, and the results will be available via the {@link ResultListener}
|
||||||
|
* provided in the {@link ObjectDetectorOptions}. Only use this method when the {@link
|
||||||
|
* ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}.
|
||||||
|
*
|
||||||
|
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||||
|
* sent to the object detector. The input timestamps must be monotonically increasing.
|
||||||
|
*
|
||||||
|
* <p>{@link ObjectDetector} supports the following color space types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public void detectAsync(MPImage image, long timestampMs) {
|
||||||
|
detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -259,12 +341,20 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* @param inputImage a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param inputTimestampMs the input timestamp (in milliseconds).
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||||
|
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||||
|
* this method throwing an IllegalArgumentException.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||||
|
* region-of-interest.
|
||||||
* @throws MediaPipeException if there is an internal error.
|
* @throws MediaPipeException if there is an internal error.
|
||||||
*/
|
*/
|
||||||
public void detectAsync(MPImage inputImage, long inputTimestampMs) {
|
public void detectAsync(
|
||||||
sendLiveStreamData(inputImage, inputTimestampMs);
|
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||||
|
validateImageProcessingOptions(imageProcessingOptions);
|
||||||
|
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Options for setting up an {@link ObjectDetector}. */
|
/** Options for setting up an {@link ObjectDetector}. */
|
||||||
|
@ -415,4 +505,15 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
|
||||||
|
* region-of-interest.
|
||||||
|
*/
|
||||||
|
private static void validateImageProcessingOptions(
|
||||||
|
ImageProcessingOptions imageProcessingOptions) {
|
||||||
|
if (imageProcessingOptions.regionOfInterest().isPresent()) {
|
||||||
|
throw new IllegalArgumentException("ObjectDetector doesn't support region-of-interest.");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.vision.coretest"
|
||||||
|
android:versionCode="1"
|
||||||
|
android:versionName="1.0" >
|
||||||
|
|
||||||
|
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
|
||||||
|
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
<application
|
||||||
|
android:label="coretest"
|
||||||
|
android:name="android.support.multidex.MultiDexApplication"
|
||||||
|
android:taskAffinity="">
|
||||||
|
<uses-library android:name="android.test.runner" />
|
||||||
|
</application>
|
||||||
|
|
||||||
|
<instrumentation
|
||||||
|
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
|
||||||
|
android:targetPackage="com.google.mediapipe.tasks.vision.coretest" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# TODO: Enable this in OSS
|
|
@ -0,0 +1,70 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.core;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
|
|
||||||
|
import android.graphics.RectF;
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
|
||||||
|
/** Test for {@link ImageProcessingOptions}/ */
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public final class ImageProcessingOptionsTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void succeedsWithValidInputs() throws Exception {
|
||||||
|
ImageProcessingOptions options =
|
||||||
|
ImageProcessingOptions.builder()
|
||||||
|
.setRegionOfInterest(new RectF(0.0f, 0.1f, 1.0f, 0.9f))
|
||||||
|
.setRotationDegrees(270)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void failsWithLeftHigherThanRight() {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ImageProcessingOptions.builder()
|
||||||
|
.setRegionOfInterest(new RectF(0.9f, 0.0f, 0.1f, 1.0f))
|
||||||
|
.build());
|
||||||
|
assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void failsWithBottomHigherThanTop() {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ImageProcessingOptions.builder()
|
||||||
|
.setRegionOfInterest(new RectF(0.0f, 0.9f, 1.0f, 0.1f))
|
||||||
|
.build());
|
||||||
|
assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void failsWithInvalidRotation() {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() -> ImageProcessingOptions.builder().setRotationDegrees(1).build());
|
||||||
|
assertThat(exception).hasMessageThat().contains("Expected rotation to be a multiple of 90°");
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,6 +19,7 @@ import static org.junit.Assert.assertThrows;
|
||||||
|
|
||||||
import android.content.res.AssetManager;
|
import android.content.res.AssetManager;
|
||||||
import android.graphics.BitmapFactory;
|
import android.graphics.BitmapFactory;
|
||||||
|
import android.graphics.RectF;
|
||||||
import androidx.test.core.app.ApplicationProvider;
|
import androidx.test.core.app.ApplicationProvider;
|
||||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
import com.google.common.truth.Correspondence;
|
import com.google.common.truth.Correspondence;
|
||||||
|
@ -30,6 +31,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.containers.Landmark;
|
import com.google.mediapipe.tasks.components.containers.Landmark;
|
||||||
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
|
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions;
|
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
@ -46,11 +48,14 @@ public class GestureRecognizerTest {
|
||||||
private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task";
|
private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task";
|
||||||
private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
|
private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
|
||||||
private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
|
private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
|
||||||
|
private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
|
||||||
private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
|
private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
|
||||||
private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
|
private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
|
||||||
private static final String TAG = "Gesture Recognizer Test";
|
private static final String TAG = "Gesture Recognizer Test";
|
||||||
private static final String THUMB_UP_LABEL = "Thumb_Up";
|
private static final String THUMB_UP_LABEL = "Thumb_Up";
|
||||||
private static final int THUMB_UP_INDEX = 5;
|
private static final int THUMB_UP_INDEX = 5;
|
||||||
|
private static final String POINTING_UP_LABEL = "Pointing_Up";
|
||||||
|
private static final int POINTING_UP_INDEX = 3;
|
||||||
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
|
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
|
||||||
private static final int IMAGE_WIDTH = 382;
|
private static final int IMAGE_WIDTH = 382;
|
||||||
private static final int IMAGE_HEIGHT = 406;
|
private static final int IMAGE_HEIGHT = 406;
|
||||||
|
@ -135,6 +140,53 @@ public class GestureRecognizerTest {
|
||||||
gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE));
|
gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE));
|
||||||
assertThat(actualResult.handednesses()).hasSize(2);
|
assertThat(actualResult.handednesses()).hasSize(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_successWithRotation() throws Exception {
|
||||||
|
GestureRecognizerOptions options =
|
||||||
|
GestureRecognizerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder()
|
||||||
|
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
|
||||||
|
.build())
|
||||||
|
.setNumHands(1)
|
||||||
|
.build();
|
||||||
|
GestureRecognizer gestureRecognizer =
|
||||||
|
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
|
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||||
|
GestureRecognitionResult actualResult =
|
||||||
|
gestureRecognizer.recognize(
|
||||||
|
getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
|
||||||
|
assertThat(actualResult.gestures()).hasSize(1);
|
||||||
|
assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX);
|
||||||
|
assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_failsWithRegionOfInterest() throws Exception {
|
||||||
|
GestureRecognizerOptions options =
|
||||||
|
GestureRecognizerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder()
|
||||||
|
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
|
||||||
|
.build())
|
||||||
|
.setNumHands(1)
|
||||||
|
.build();
|
||||||
|
GestureRecognizer gestureRecognizer =
|
||||||
|
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
|
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
gestureRecognizer.recognize(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("GestureRecognizer doesn't support region-of-interest");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@RunWith(AndroidJUnit4.class)
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
@ -195,12 +247,16 @@ public class GestureRecognizerTest {
|
||||||
MediaPipeException exception =
|
MediaPipeException exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
|
() ->
|
||||||
|
gestureRecognizer.recognizeForVideo(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
|
() ->
|
||||||
|
gestureRecognizer.recognizeAsync(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,7 +281,9 @@ public class GestureRecognizerTest {
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
|
() ->
|
||||||
|
gestureRecognizer.recognizeAsync(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,7 +309,9 @@ public class GestureRecognizerTest {
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
|
() ->
|
||||||
|
gestureRecognizer.recognizeForVideo(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,7 +351,8 @@ public class GestureRecognizerTest {
|
||||||
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
|
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
GestureRecognitionResult actualResult =
|
GestureRecognitionResult actualResult =
|
||||||
gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i);
|
gestureRecognizer.recognizeForVideo(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ i);
|
||||||
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -317,9 +378,11 @@ public class GestureRecognizerTest {
|
||||||
.build();
|
.build();
|
||||||
try (GestureRecognizer gestureRecognizer =
|
try (GestureRecognizer gestureRecognizer =
|
||||||
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
gestureRecognizer.recognizeAsync(image, 1);
|
gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 1);
|
||||||
MediaPipeException exception =
|
MediaPipeException exception =
|
||||||
assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0));
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 0));
|
||||||
assertThat(exception)
|
assertThat(exception)
|
||||||
.hasMessageThat()
|
.hasMessageThat()
|
||||||
.contains("having a smaller timestamp than the processed timestamp");
|
.contains("having a smaller timestamp than the processed timestamp");
|
||||||
|
@ -348,7 +411,7 @@ public class GestureRecognizerTest {
|
||||||
try (GestureRecognizer gestureRecognizer =
|
try (GestureRecognizer gestureRecognizer =
|
||||||
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
gestureRecognizer.recognizeAsync(image, i);
|
gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.TestUtils;
|
import com.google.mediapipe.tasks.core.TestUtils;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;
|
import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
@ -47,7 +48,9 @@ public class ImageClassifierTest {
|
||||||
private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite";
|
private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite";
|
||||||
private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite";
|
private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite";
|
||||||
private static final String BURGER_IMAGE = "burger.jpg";
|
private static final String BURGER_IMAGE = "burger.jpg";
|
||||||
|
private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg";
|
||||||
private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg";
|
private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg";
|
||||||
|
private static final String MULTI_OBJECTS_ROTATED_IMAGE = "multi_objects_rotated.jpg";
|
||||||
|
|
||||||
@RunWith(AndroidJUnit4.class)
|
@RunWith(AndroidJUnit4.class)
|
||||||
public static final class General extends ImageClassifierTest {
|
public static final class General extends ImageClassifierTest {
|
||||||
|
@ -209,13 +212,60 @@ public class ImageClassifierTest {
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
// RectF around the soccer ball.
|
// RectF around the soccer ball.
|
||||||
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
|
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
|
||||||
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
|
ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
|
||||||
ImageClassificationResult results =
|
ImageClassificationResult results =
|
||||||
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi);
|
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
assertHasOneHeadAndOneTimestamp(results, 0);
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
|
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithRotation() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
|
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||||
|
ImageClassificationResult results =
|
||||||
|
imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results,
|
||||||
|
Arrays.asList(
|
||||||
|
Category.create(0.6390683f, 934, "cheeseburger", ""),
|
||||||
|
Category.create(0.0495407f, 963, "meat loaf", ""),
|
||||||
|
Category.create(0.0469720f, 925, "guacamole", "")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void classify_succeedsWithRegionOfInterestAndRotation() throws Exception {
|
||||||
|
ImageClassifierOptions options =
|
||||||
|
ImageClassifierOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
|
||||||
|
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
|
||||||
|
.build();
|
||||||
|
ImageClassifier imageClassifier =
|
||||||
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
// RectF around the chair.
|
||||||
|
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
|
||||||
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
|
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
|
||||||
|
ImageClassificationResult results =
|
||||||
|
imageClassifier.classify(
|
||||||
|
getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
|
assertHasOneHeadAndOneTimestamp(results, 0);
|
||||||
|
assertCategoriesAre(
|
||||||
|
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@RunWith(AndroidJUnit4.class)
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
@ -269,12 +319,16 @@ public class ImageClassifierTest {
|
||||||
MediaPipeException exception =
|
MediaPipeException exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
|
() ->
|
||||||
|
imageClassifier.classifyForVideo(
|
||||||
|
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
|
() ->
|
||||||
|
imageClassifier.classifyAsync(
|
||||||
|
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,7 +350,9 @@ public class ImageClassifierTest {
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
|
() ->
|
||||||
|
imageClassifier.classifyAsync(
|
||||||
|
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,7 +376,9 @@ public class ImageClassifierTest {
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
|
() ->
|
||||||
|
imageClassifier.classifyForVideo(
|
||||||
|
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -352,7 +410,8 @@ public class ImageClassifierTest {
|
||||||
ImageClassifier imageClassifier =
|
ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
ImageClassificationResult results = imageClassifier.classifyForVideo(image, i);
|
ImageClassificationResult results =
|
||||||
|
imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
|
||||||
assertHasOneHeadAndOneTimestamp(results, i);
|
assertHasOneHeadAndOneTimestamp(results, i);
|
||||||
assertCategoriesAre(
|
assertCategoriesAre(
|
||||||
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
|
||||||
|
@ -377,9 +436,11 @@ public class ImageClassifierTest {
|
||||||
.build();
|
.build();
|
||||||
try (ImageClassifier imageClassifier =
|
try (ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1);
|
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
|
||||||
MediaPipeException exception =
|
MediaPipeException exception =
|
||||||
assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0));
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0));
|
||||||
assertThat(exception)
|
assertThat(exception)
|
||||||
.hasMessageThat()
|
.hasMessageThat()
|
||||||
.contains("having a smaller timestamp than the processed timestamp");
|
.contains("having a smaller timestamp than the processed timestamp");
|
||||||
|
@ -405,7 +466,7 @@ public class ImageClassifierTest {
|
||||||
try (ImageClassifier imageClassifier =
|
try (ImageClassifier imageClassifier =
|
||||||
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
for (int i = 0; i < 3; ++i) {
|
for (int i = 0; i < 3; ++i) {
|
||||||
imageClassifier.classifyAsync(image, i);
|
imageClassifier.classifyAsync(image, /*timestampMs=*/ i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
import com.google.mediapipe.tasks.components.containers.Detection;
|
import com.google.mediapipe.tasks.components.containers.Detection;
|
||||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
import com.google.mediapipe.tasks.core.TestUtils;
|
import com.google.mediapipe.tasks.core.TestUtils;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
|
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
@ -45,10 +46,11 @@ import org.junit.runners.Suite.SuiteClasses;
|
||||||
public class ObjectDetectorTest {
|
public class ObjectDetectorTest {
|
||||||
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
|
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
|
||||||
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
|
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
|
||||||
|
private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg";
|
||||||
private static final int IMAGE_WIDTH = 1200;
|
private static final int IMAGE_WIDTH = 1200;
|
||||||
private static final int IMAGE_HEIGHT = 600;
|
private static final int IMAGE_HEIGHT = 600;
|
||||||
private static final float CAT_SCORE = 0.69f;
|
private static final float CAT_SCORE = 0.69f;
|
||||||
private static final RectF catBoundingBox = new RectF(611, 164, 986, 596);
|
private static final RectF CAT_BOUNDING_BOX = new RectF(611, 164, 986, 596);
|
||||||
// TODO: Figure out why android_x86 and android_arm tests have slightly different
|
// TODO: Figure out why android_x86 and android_arm tests have slightly different
|
||||||
// scores (0.6875 vs 0.69921875).
|
// scores (0.6875 vs 0.69921875).
|
||||||
private static final float SCORE_DIFF_TOLERANCE = 0.01f;
|
private static final float SCORE_DIFF_TOLERANCE = 0.01f;
|
||||||
|
@ -67,7 +69,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -104,7 +106,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
// The score threshold should block all other other objects, except cat.
|
// The score threshold should block all other other objects, except cat.
|
||||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -175,7 +177,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -228,6 +230,46 @@ public class ObjectDetectorTest {
|
||||||
.contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
|
.contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithRotation() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setMaxResults(1)
|
||||||
|
.setCategoryAllowlist(Arrays.asList("cat"))
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
|
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
|
||||||
|
ObjectDetectionResult results =
|
||||||
|
objectDetector.detect(
|
||||||
|
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
|
||||||
|
|
||||||
|
assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_failsWithRegionOfInterest() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ImageProcessingOptions imageProcessingOptions =
|
||||||
|
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
objectDetector.detect(
|
||||||
|
getImageFromAsset(CAT_AND_DOG_IMAGE), imageProcessingOptions));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("ObjectDetector doesn't support region-of-interest");
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
|
// TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
|
||||||
// detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
|
// detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
|
||||||
// detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.
|
// detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.
|
||||||
|
@ -282,12 +324,16 @@ public class ObjectDetectorTest {
|
||||||
MediaPipeException exception =
|
MediaPipeException exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
() ->
|
||||||
|
objectDetector.detectForVideo(
|
||||||
|
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
() ->
|
||||||
|
objectDetector.detectAsync(
|
||||||
|
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,7 +355,9 @@ public class ObjectDetectorTest {
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
() ->
|
||||||
|
objectDetector.detectAsync(
|
||||||
|
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -333,7 +381,9 @@ public class ObjectDetectorTest {
|
||||||
exception =
|
exception =
|
||||||
assertThrows(
|
assertThrows(
|
||||||
MediaPipeException.class,
|
MediaPipeException.class,
|
||||||
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
|
() ->
|
||||||
|
objectDetector.detectForVideo(
|
||||||
|
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
|
||||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,7 +398,7 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector objectDetector =
|
ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -363,8 +413,9 @@ public class ObjectDetectorTest {
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
ObjectDetectionResult results =
|
ObjectDetectionResult results =
|
||||||
objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i);
|
objectDetector.detectForVideo(
|
||||||
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i);
|
||||||
|
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -377,16 +428,18 @@ public class ObjectDetectorTest {
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(objectDetectionResult, inputImage) -> {
|
(objectDetectionResult, inputImage) -> {
|
||||||
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
|
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
assertImageSizeIsExpected(inputImage);
|
assertImageSizeIsExpected(inputImage);
|
||||||
})
|
})
|
||||||
.setMaxResults(1)
|
.setMaxResults(1)
|
||||||
.build();
|
.build();
|
||||||
try (ObjectDetector objectDetector =
|
try (ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
objectDetector.detectAsync(image, 1);
|
objectDetector.detectAsync(image, /*timestampsMs=*/ 1);
|
||||||
MediaPipeException exception =
|
MediaPipeException exception =
|
||||||
assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0));
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0));
|
||||||
assertThat(exception)
|
assertThat(exception)
|
||||||
.hasMessageThat()
|
.hasMessageThat()
|
||||||
.contains("having a smaller timestamp than the processed timestamp");
|
.contains("having a smaller timestamp than the processed timestamp");
|
||||||
|
@ -402,7 +455,7 @@ public class ObjectDetectorTest {
|
||||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
.setResultListener(
|
.setResultListener(
|
||||||
(objectDetectionResult, inputImage) -> {
|
(objectDetectionResult, inputImage) -> {
|
||||||
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
|
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
|
||||||
assertImageSizeIsExpected(inputImage);
|
assertImageSizeIsExpected(inputImage);
|
||||||
})
|
})
|
||||||
.setMaxResults(1)
|
.setMaxResults(1)
|
||||||
|
@ -410,7 +463,7 @@ public class ObjectDetectorTest {
|
||||||
try (ObjectDetector objectDetector =
|
try (ObjectDetector objectDetector =
|
||||||
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
objectDetector.detectAsync(image, i);
|
objectDetector.detectAsync(image, /*timestampsMs=*/ i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,3 +86,13 @@ py_library(
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "classifications",
|
||||||
|
srcs = ["classifications.py"],
|
||||||
|
deps = [
|
||||||
|
":category",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
168
mediapipe/tasks/python/components/containers/classifications.py
Normal file
168
mediapipe/tasks/python/components/containers/classifications.py
Normal file
|
@ -0,0 +1,168 @@
|
||||||
|
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Classifications data class."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
|
from mediapipe.tasks.python.components.containers import category as category_module
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
|
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
|
||||||
|
_ClassificationsProto = classifications_pb2.Classifications
|
||||||
|
_ClassificationResultProto = classifications_pb2.ClassificationResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ClassificationEntry:
|
||||||
|
"""List of predicted classes (aka labels) for a given classifier head.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
categories: The array of predicted categories, usually sorted by descending
|
||||||
|
scores (e.g. from high to low probability).
|
||||||
|
timestamp_ms: The optional timestamp (in milliseconds) associated to the
|
||||||
|
classification entry. This is useful for time series use cases, e.g.,
|
||||||
|
audio classification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
categories: List[category_module.Category]
|
||||||
|
timestamp_ms: Optional[int] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _ClassificationEntryProto:
|
||||||
|
"""Generates a ClassificationEntry protobuf object."""
|
||||||
|
return _ClassificationEntryProto(
|
||||||
|
categories=[category.to_pb2() for category in self.categories],
|
||||||
|
timestamp_ms=self.timestamp_ms)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(
|
||||||
|
cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry':
|
||||||
|
"""Creates a `ClassificationEntry` object from the given protobuf object."""
|
||||||
|
return ClassificationEntry(
|
||||||
|
categories=[
|
||||||
|
category_module.Category.create_from_pb2(category)
|
||||||
|
for category in pb2_obj.categories
|
||||||
|
],
|
||||||
|
timestamp_ms=pb2_obj.timestamp_ms)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, ClassificationEntry):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Classifications:
|
||||||
|
"""Represents the classifications for a given classifier head.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
entries: A list of `ClassificationEntry` objects.
|
||||||
|
head_index: The index of the classifier head these categories refer to. This
|
||||||
|
is useful for multi-head models.
|
||||||
|
head_name: The name of the classifier head, which is the corresponding
|
||||||
|
tensor metadata name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
entries: List[ClassificationEntry]
|
||||||
|
head_index: int
|
||||||
|
head_name: str
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _ClassificationsProto:
|
||||||
|
"""Generates a Classifications protobuf object."""
|
||||||
|
return _ClassificationsProto(
|
||||||
|
entries=[entry.to_pb2() for entry in self.entries],
|
||||||
|
head_index=self.head_index,
|
||||||
|
head_name=self.head_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
||||||
|
"""Creates a `Classifications` object from the given protobuf object."""
|
||||||
|
return Classifications(
|
||||||
|
entries=[
|
||||||
|
ClassificationEntry.create_from_pb2(entry)
|
||||||
|
for entry in pb2_obj.entries
|
||||||
|
],
|
||||||
|
head_index=pb2_obj.head_index,
|
||||||
|
head_name=pb2_obj.head_name)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, Classifications):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ClassificationResult:
|
||||||
|
"""Contains one set of results per classifier head.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
classifications: A list of `Classifications` objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
classifications: List[Classifications]
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _ClassificationResultProto:
|
||||||
|
"""Generates a ClassificationResult protobuf object."""
|
||||||
|
return _ClassificationResultProto(classifications=[
|
||||||
|
classification.to_pb2() for classification in self.classifications
|
||||||
|
])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def create_from_pb2(
|
||||||
|
cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
|
||||||
|
"""Creates a `ClassificationResult` object from the given protobuf object.
|
||||||
|
"""
|
||||||
|
return ClassificationResult(classifications=[
|
||||||
|
Classifications.create_from_pb2(classification)
|
||||||
|
for classification in pb2_obj.classifications
|
||||||
|
])
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other: The object to be compared with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the objects are equal.
|
||||||
|
"""
|
||||||
|
if not isinstance(other, ClassificationResult):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.to_pb2().__eq__(other.to_pb2())
|
|
@ -26,9 +26,7 @@ _NormalizedRectProto = rect_pb2.NormalizedRect
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Rect:
|
class Rect:
|
||||||
"""A rectangle with rotation in image coordinates.
|
"""A rectangle with rotation in image coordinates.
|
||||||
|
Attributes: x_center : The X coordinate of the top-left corner, in pixels.
|
||||||
Attributes:
|
|
||||||
x_center : The X coordinate of the top-left corner, in pixels.
|
|
||||||
y_center : The Y coordinate of the top-left corner, in pixels.
|
y_center : The Y coordinate of the top-left corner, in pixels.
|
||||||
width: The width of the rectangle, in pixels.
|
width: The width of the rectangle, in pixels.
|
||||||
height: The height of the rectangle, in pixels.
|
height: The height of the rectangle, in pixels.
|
||||||
|
@ -81,11 +79,10 @@ class Rect:
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class NormalizedRect:
|
class NormalizedRect:
|
||||||
"""A rectangle with rotation in normalized coordinates. The values of box
|
"""A rectangle with rotation in normalized coordinates.
|
||||||
|
The values of box
|
||||||
center location and size are within [0, 1].
|
center location and size are within [0, 1].
|
||||||
|
Attributes: x_center : The X normalized coordinate of the top-left corner.
|
||||||
Attributes:
|
|
||||||
x_center : The X normalized coordinate of the top-left corner.
|
|
||||||
y_center : The Y normalized coordinate of the top-left corner.
|
y_center : The Y normalized coordinate of the top-left corner.
|
||||||
width: The width of the rectangle.
|
width: The width of the rectangle.
|
||||||
height: The height of the rectangle.
|
height: The height of the rectangle.
|
||||||
|
@ -110,8 +107,7 @@ class NormalizedRect:
|
||||||
width=self.width,
|
width=self.width,
|
||||||
height=self.height,
|
height=self.height,
|
||||||
rotation=self.rotation,
|
rotation=self.rotation,
|
||||||
rect_id=self.rect_id
|
rect_id=self.rect_id)
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
|
@ -123,8 +119,7 @@ class NormalizedRect:
|
||||||
width=pb2_obj.width,
|
width=pb2_obj.width,
|
||||||
height=pb2_obj.height,
|
height=pb2_obj.height,
|
||||||
rotation=pb2_obj.rotation,
|
rotation=pb2_obj.rotation,
|
||||||
rect_id=pb2_obj.rect_id
|
rect_id=pb2_obj.rect_id)
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Checks if this object is equal to the given object.
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
|
|
||||||
# Placeholder for internal Python strict library compatibility macro.
|
# Placeholder for internal Python strict library compatibility macro.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
|
@ -61,19 +61,13 @@ class ClassifierOptions:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def create_from_pb2(
|
def create_from_pb2(cls,
|
||||||
cls,
|
pb2_obj: _ClassifierOptionsProto) -> 'ClassifierOptions':
|
||||||
pb2_obj: _ClassifierOptionsProto
|
|
||||||
) -> 'ClassifierOptions':
|
|
||||||
"""Creates a `ClassifierOptions` object from the given protobuf object."""
|
"""Creates a `ClassifierOptions` object from the given protobuf object."""
|
||||||
return ClassifierOptions(
|
return ClassifierOptions(
|
||||||
score_threshold=pb2_obj.score_threshold,
|
score_threshold=pb2_obj.score_threshold,
|
||||||
category_allowlist=[
|
category_allowlist=[str(name) for name in pb2_obj.category_allowlist],
|
||||||
str(name) for name in pb2_obj.class_name_allowlist
|
category_denylist=[str(name) for name in pb2_obj.category_denylist],
|
||||||
],
|
|
||||||
category_denylist=[
|
|
||||||
str(name) for name in pb2_obj.class_name_denylist
|
|
||||||
],
|
|
||||||
display_names_locale=pb2_obj.display_names_locale,
|
display_names_locale=pb2_obj.display_names_locale,
|
||||||
max_results=pb2_obj.max_results)
|
max_results=pb2_obj.max_results)
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,26 @@ py_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "image_classifier_test",
|
||||||
|
srcs = ["image_classifier_test.py"],
|
||||||
|
data = [
|
||||||
|
"//mediapipe/tasks/testdata/vision:test_images",
|
||||||
|
"//mediapipe/tasks/testdata/vision:test_models",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/python:_framework_bindings",
|
||||||
|
"//mediapipe/tasks/python/components/containers:category",
|
||||||
|
"//mediapipe/tasks/python/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
"//mediapipe/tasks/python/vision:image_classifier",
|
||||||
|
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "gesture_recognizer_test",
|
name = "gesture_recognizer_test",
|
||||||
srcs = ["gesture_recognizer_test.py"],
|
srcs = ["gesture_recognizer_test.py"],
|
||||||
|
|
515
mediapipe/tasks/python/test/vision/image_classifier_test.py
Normal file
515
mediapipe/tasks/python/test/vision/image_classifier_test.py
Normal file
|
@ -0,0 +1,515 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tests for image classifier."""
|
||||||
|
|
||||||
|
import enum
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mediapipe.python._framework_bindings import image
|
||||||
|
from mediapipe.tasks.python.components.containers import category
|
||||||
|
from mediapipe.tasks.python.components.containers import classifications as classifications_module
|
||||||
|
from mediapipe.tasks.python.components.containers import rect
|
||||||
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
from mediapipe.tasks.python.vision import image_classifier
|
||||||
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||||
|
|
||||||
|
_NormalizedRect = rect.NormalizedRect
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
|
_Category = category.Category
|
||||||
|
_ClassificationEntry = classifications_module.ClassificationEntry
|
||||||
|
_Classifications = classifications_module.Classifications
|
||||||
|
_ClassificationResult = classifications_module.ClassificationResult
|
||||||
|
_Image = image.Image
|
||||||
|
_ImageClassifier = image_classifier.ImageClassifier
|
||||||
|
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
|
||||||
|
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
||||||
|
|
||||||
|
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
|
||||||
|
_IMAGE_FILE = 'burger.jpg'
|
||||||
|
_ALLOW_LIST = ['cheeseburger', 'guacamole']
|
||||||
|
_DENY_LIST = ['cheeseburger']
|
||||||
|
_SCORE_THRESHOLD = 0.5
|
||||||
|
_MAX_RESULTS = 3
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Port assertProtoEquals
|
||||||
|
def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
|
return _ClassificationResult(classifications=[
|
||||||
|
_Classifications(
|
||||||
|
entries=[
|
||||||
|
_ClassificationEntry(categories=[], timestamp_ms=timestamp_ms)
|
||||||
|
],
|
||||||
|
head_index=0,
|
||||||
|
head_name='probability')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
|
return _ClassificationResult(classifications=[
|
||||||
|
_Classifications(
|
||||||
|
entries=[
|
||||||
|
_ClassificationEntry(
|
||||||
|
categories=[
|
||||||
|
_Category(
|
||||||
|
index=934,
|
||||||
|
score=0.7939587831497192,
|
||||||
|
display_name='',
|
||||||
|
category_name='cheeseburger'),
|
||||||
|
_Category(
|
||||||
|
index=932,
|
||||||
|
score=0.02739289402961731,
|
||||||
|
display_name='',
|
||||||
|
category_name='bagel'),
|
||||||
|
_Category(
|
||||||
|
index=925,
|
||||||
|
score=0.01934075355529785,
|
||||||
|
display_name='',
|
||||||
|
category_name='guacamole'),
|
||||||
|
_Category(
|
||||||
|
index=963,
|
||||||
|
score=0.006327860057353973,
|
||||||
|
display_name='',
|
||||||
|
category_name='meat loaf')
|
||||||
|
],
|
||||||
|
timestamp_ms=timestamp_ms)
|
||||||
|
],
|
||||||
|
head_index=0,
|
||||||
|
head_name='probability')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
|
||||||
|
return _ClassificationResult(classifications=[
|
||||||
|
_Classifications(
|
||||||
|
entries=[
|
||||||
|
_ClassificationEntry(
|
||||||
|
categories=[
|
||||||
|
_Category(
|
||||||
|
index=806,
|
||||||
|
score=0.9965274930000305,
|
||||||
|
display_name='',
|
||||||
|
category_name='soccer ball')
|
||||||
|
],
|
||||||
|
timestamp_ms=timestamp_ms)
|
||||||
|
],
|
||||||
|
head_index=0,
|
||||||
|
head_name='probability')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFileType(enum.Enum):
|
||||||
|
FILE_CONTENT = 1
|
||||||
|
FILE_NAME = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifierTest(parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
self.test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(_IMAGE_FILE))
|
||||||
|
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
|
||||||
|
|
||||||
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with default option and valid model file successfully.
|
||||||
|
with _ImageClassifier.create_from_model_path(self.model_path) as classifier:
|
||||||
|
self.assertIsInstance(classifier, _ImageClassifier)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_path(self):
|
||||||
|
# Creates with options containing model file successfully.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _ImageClassifierOptions(base_options=base_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
self.assertIsInstance(classifier, _ImageClassifier)
|
||||||
|
|
||||||
|
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||||
|
# Invalid empty model path.
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
r"ExternalFile must specify at least one of 'file_content', "
|
||||||
|
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
|
||||||
|
base_options = _BaseOptions(model_asset_path='')
|
||||||
|
options = _ImageClassifierOptions(base_options=base_options)
|
||||||
|
_ImageClassifier.create_from_options(options)
|
||||||
|
|
||||||
|
def test_create_from_options_succeeds_with_valid_model_content(self):
|
||||||
|
# Creates with options containing model content successfully.
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=f.read())
|
||||||
|
options = _ImageClassifierOptions(base_options=base_options)
|
||||||
|
classifier = _ImageClassifier.create_from_options(options)
|
||||||
|
self.assertIsInstance(classifier, _ImageClassifier)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
|
||||||
|
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
|
||||||
|
def test_classify(self, model_file_type, max_results,
|
||||||
|
expected_classification_result):
|
||||||
|
# Creates classifier.
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
custom_classifier_options = _ClassifierOptions(max_results=max_results)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=base_options, classifier_options=custom_classifier_options)
|
||||||
|
classifier = _ImageClassifier.create_from_options(options)
|
||||||
|
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(self.test_image)
|
||||||
|
# Comparing results.
|
||||||
|
_assert_proto_equals(image_result.to_pb2(),
|
||||||
|
expected_classification_result.to_pb2())
|
||||||
|
# Closes the classifier explicitly when the classifier is not used in
|
||||||
|
# a context.
|
||||||
|
classifier.close()
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
|
||||||
|
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
|
||||||
|
def test_classify_in_context(self, model_file_type, max_results,
|
||||||
|
expected_classification_result):
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
custom_classifier_options = _ClassifierOptions(max_results=max_results)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=base_options, classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(self.test_image)
|
||||||
|
# Comparing results.
|
||||||
|
_assert_proto_equals(image_result.to_pb2(),
|
||||||
|
expected_classification_result.to_pb2())
|
||||||
|
|
||||||
|
def test_classify_succeeds_with_region_of_interest(self):
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=base_options, classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Load the test image.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||||
|
# NormalizedRect around the soccer ball.
|
||||||
|
roi = _NormalizedRect(
|
||||||
|
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(test_image, roi)
|
||||||
|
# Comparing results.
|
||||||
|
_assert_proto_equals(image_result.to_pb2(),
|
||||||
|
_generate_soccer_ball_results(0).to_pb2())
|
||||||
|
|
||||||
|
def test_score_threshold_option(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(
|
||||||
|
score_threshold=_SCORE_THRESHOLD)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(self.test_image)
|
||||||
|
classifications = image_result.classifications
|
||||||
|
|
||||||
|
for classification in classifications:
|
||||||
|
for entry in classification.entries:
|
||||||
|
score = entry.categories[0].score
|
||||||
|
self.assertGreaterEqual(
|
||||||
|
score, _SCORE_THRESHOLD,
|
||||||
|
f'Classification with score lower than threshold found. '
|
||||||
|
f'{classification}')
|
||||||
|
|
||||||
|
def test_max_results_option(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(
|
||||||
|
score_threshold=_SCORE_THRESHOLD)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(self.test_image)
|
||||||
|
categories = image_result.classifications[0].entries[0].categories
|
||||||
|
|
||||||
|
self.assertLessEqual(
|
||||||
|
len(categories), _MAX_RESULTS, 'Too many results returned.')
|
||||||
|
|
||||||
|
def test_allow_list_option(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(
|
||||||
|
category_allowlist=_ALLOW_LIST)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(self.test_image)
|
||||||
|
classifications = image_result.classifications
|
||||||
|
|
||||||
|
for classification in classifications:
|
||||||
|
for entry in classification.entries:
|
||||||
|
label = entry.categories[0].category_name
|
||||||
|
self.assertIn(label, _ALLOW_LIST,
|
||||||
|
f'Label {label} found but not in label allow list')
|
||||||
|
|
||||||
|
def test_deny_list_option(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(self.test_image)
|
||||||
|
classifications = image_result.classifications
|
||||||
|
|
||||||
|
for classification in classifications:
|
||||||
|
for entry in classification.entries:
|
||||||
|
label = entry.categories[0].category_name
|
||||||
|
self.assertNotIn(label, _DENY_LIST,
|
||||||
|
f'Label {label} found but in deny list.')
|
||||||
|
|
||||||
|
def test_combined_allowlist_and_denylist(self):
|
||||||
|
# Fails with combined allowlist and denylist
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
r'`category_allowlist` and `category_denylist` are mutually '
|
||||||
|
r'exclusive options.'):
|
||||||
|
custom_classifier_options = _ClassifierOptions(
|
||||||
|
category_allowlist=['foo'], category_denylist=['bar'])
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_empty_classification_outputs(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(score_threshold=1)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Performs image classification on the input.
|
||||||
|
image_result = classifier.classify(self.test_image)
|
||||||
|
self.assertEmpty(image_result.classifications[0].entries[0].categories)
|
||||||
|
|
||||||
|
def test_missing_result_callback(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM)
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'result callback must be provided'):
|
||||||
|
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||||
|
def test_illegal_result_callback(self, running_mode):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=running_mode,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'result callback should not be provided'):
|
||||||
|
with _ImageClassifier.create_from_options(options) as unused_classifier:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_calling_classify_for_video_in_image_mode(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.IMAGE)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the video mode'):
|
||||||
|
classifier.classify_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_classify_async_in_image_mode(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.IMAGE)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the live stream mode'):
|
||||||
|
classifier.classify_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_classify_in_video_mode(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the image mode'):
|
||||||
|
classifier.classify(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_classify_async_in_video_mode(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the live stream mode'):
|
||||||
|
classifier.classify_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_classify_for_video_with_out_of_order_timestamp(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
unused_result = classifier.classify_for_video(self.test_image, 1)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
|
classifier.classify_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_classify_for_video(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(max_results=4)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
classification_result = classifier.classify_for_video(
|
||||||
|
self.test_image, timestamp)
|
||||||
|
_assert_proto_equals(classification_result.to_pb2(),
|
||||||
|
_generate_burger_results(timestamp).to_pb2())
|
||||||
|
|
||||||
|
def test_classify_for_video_succeeds_with_region_of_interest(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
classifier_options=custom_classifier_options)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
# Load the test image.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||||
|
# NormalizedRect around the soccer ball.
|
||||||
|
roi = _NormalizedRect(
|
||||||
|
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
classification_result = classifier.classify_for_video(
|
||||||
|
test_image, timestamp, roi)
|
||||||
|
self.assertEqual(classification_result,
|
||||||
|
_generate_soccer_ball_results(timestamp))
|
||||||
|
|
||||||
|
def test_calling_classify_in_live_stream_mode(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the image mode'):
|
||||||
|
classifier.classify(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_classify_for_video_in_live_stream_mode(self):
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the video mode'):
|
||||||
|
classifier.classify_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_classify_async_calls_with_illegal_timestamp(self):
|
||||||
|
custom_classifier_options = _ClassifierOptions(max_results=4)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
classifier_options=custom_classifier_options,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
classifier.classify_async(self.test_image, 100)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
|
classifier.classify_async(self.test_image, 0)
|
||||||
|
|
||||||
|
@parameterized.parameters((0, _generate_burger_results),
|
||||||
|
(1, _generate_empty_results))
|
||||||
|
def test_classify_async_calls(self, threshold, expected_result_fn):
|
||||||
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
|
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||||
|
timestamp_ms: int):
|
||||||
|
_assert_proto_equals(result.to_pb2(),
|
||||||
|
expected_result_fn(timestamp_ms).to_pb2())
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(output_image.numpy_view(),
|
||||||
|
self.test_image.numpy_view()))
|
||||||
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
|
self.observed_timestamp_ms = timestamp_ms
|
||||||
|
|
||||||
|
custom_classifier_options = _ClassifierOptions(
|
||||||
|
max_results=4, score_threshold=threshold)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
classifier_options=custom_classifier_options,
|
||||||
|
result_callback=check_result)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
classifier.classify_async(self.test_image, timestamp)
|
||||||
|
|
||||||
|
def test_classify_async_succeeds_with_region_of_interest(self):
|
||||||
|
# Load the test image.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path('multi_objects.jpg'))
|
||||||
|
# NormalizedRect around the soccer ball.
|
||||||
|
roi = _NormalizedRect(
|
||||||
|
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
|
||||||
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
|
def check_result(result: _ClassificationResult, output_image: _Image,
|
||||||
|
timestamp_ms: int):
|
||||||
|
_assert_proto_equals(result.to_pb2(),
|
||||||
|
_generate_soccer_ball_results(timestamp_ms).to_pb2())
|
||||||
|
self.assertEqual(output_image.width, test_image.width)
|
||||||
|
self.assertEqual(output_image.height, test_image.height)
|
||||||
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
|
self.observed_timestamp_ms = timestamp_ms
|
||||||
|
|
||||||
|
custom_classifier_options = _ClassifierOptions(max_results=1)
|
||||||
|
options = _ImageClassifierOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
classifier_options=custom_classifier_options,
|
||||||
|
result_callback=check_result)
|
||||||
|
with _ImageClassifier.create_from_options(options) as classifier:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
classifier.classify_async(test_image, timestamp, roi)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
|
@ -37,6 +37,28 @@ py_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_classifier",
|
||||||
|
srcs = [
|
||||||
|
"image_classifier.py",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/python:_framework_bindings",
|
||||||
|
"//mediapipe/python:packet_creator",
|
||||||
|
"//mediapipe/python:packet_getter",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
|
||||||
|
"//mediapipe/tasks/python/components/containers:classifications",
|
||||||
|
"//mediapipe/tasks/python/components/containers:rect",
|
||||||
|
"//mediapipe/tasks/python/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/python/core:base_options",
|
||||||
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
|
"//mediapipe/tasks/python/core:task_info",
|
||||||
|
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
|
||||||
|
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "gesture_recognizer",
|
name = "gesture_recognizer",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
294
mediapipe/tasks/python/vision/image_classifier.py
Normal file
294
mediapipe/tasks/python/vision/image_classifier.py
Normal file
|
@ -0,0 +1,294 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe image classifier task."""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Callable, Mapping, Optional
|
||||||
|
|
||||||
|
from mediapipe.python import packet_creator
|
||||||
|
from mediapipe.python import packet_getter
|
||||||
|
# TODO: Import MPImage directly one we have an alias
|
||||||
|
from mediapipe.python._framework_bindings import image as image_module
|
||||||
|
from mediapipe.python._framework_bindings import packet
|
||||||
|
from mediapipe.python._framework_bindings import task_runner
|
||||||
|
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||||
|
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
|
||||||
|
from mediapipe.tasks.python.components.containers import classifications
|
||||||
|
from mediapipe.tasks.python.components.containers import rect
|
||||||
|
from mediapipe.tasks.python.components.processors import classifier_options
|
||||||
|
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||||
|
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||||
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||||
|
|
||||||
|
_NormalizedRect = rect.NormalizedRect
|
||||||
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
|
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
||||||
|
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||||
|
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||||
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
_TaskRunner = task_runner.TaskRunner
|
||||||
|
|
||||||
|
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||||
|
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
||||||
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
|
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||||
|
_IMAGE_TAG = 'IMAGE'
|
||||||
|
_NORM_RECT_NAME = 'norm_rect_in'
|
||||||
|
_NORM_RECT_TAG = 'NORM_RECT'
|
||||||
|
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
|
||||||
|
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||||
|
|
||||||
|
|
||||||
|
def _build_full_image_norm_rect() -> _NormalizedRect:
|
||||||
|
# Builds a NormalizedRect covering the entire image.
|
||||||
|
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ImageClassifierOptions:
|
||||||
|
"""Options for the image classifier task.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
base_options: Base options for the image classifier task.
|
||||||
|
running_mode: The running mode of the task. Default to the image mode. Image
|
||||||
|
classifier task has three running modes: 1) The image mode for classifying
|
||||||
|
objects on single image inputs. 2) The video mode for classifying objects
|
||||||
|
on the decoded frames of a video. 3) The live stream mode for classifying
|
||||||
|
objects on a live stream of input data, such as from camera.
|
||||||
|
classifier_options: Options for the image classification task.
|
||||||
|
result_callback: The user-defined result callback for processing live stream
|
||||||
|
data. The result callback should only be specified when the running mode
|
||||||
|
is set to the live stream mode.
|
||||||
|
"""
|
||||||
|
base_options: _BaseOptions
|
||||||
|
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||||
|
classifier_options: _ClassifierOptions = _ClassifierOptions()
|
||||||
|
result_callback: Optional[
|
||||||
|
Callable[[classifications.ClassificationResult, image_module.Image, int],
|
||||||
|
None]] = None
|
||||||
|
|
||||||
|
@doc_controls.do_not_generate_docs
|
||||||
|
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
|
||||||
|
"""Generates an ImageClassifierOptions protobuf object."""
|
||||||
|
base_options_proto = self.base_options.to_pb2()
|
||||||
|
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
|
||||||
|
classifier_options_proto = self.classifier_options.to_pb2()
|
||||||
|
|
||||||
|
return _ImageClassifierGraphOptionsProto(
|
||||||
|
base_options=base_options_proto,
|
||||||
|
classifier_options=classifier_options_proto)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
|
||||||
|
"""Class that performs image classification on images."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_model_path(cls, model_path: str) -> 'ImageClassifier':
|
||||||
|
"""Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`.
|
||||||
|
|
||||||
|
Note that the created `ImageClassifier` instance is in image mode, for
|
||||||
|
classifying objects on single image inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`ImageClassifier` object that's created from the model file and the
|
||||||
|
default `ImageClassifierOptions`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `ImageClassifier` object from the provided
|
||||||
|
file such as invalid file path.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
base_options = _BaseOptions(model_asset_path=model_path)
|
||||||
|
options = ImageClassifierOptions(
|
||||||
|
base_options=base_options, running_mode=_RunningMode.IMAGE)
|
||||||
|
return cls.create_from_options(options)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_from_options(cls,
|
||||||
|
options: ImageClassifierOptions) -> 'ImageClassifier':
|
||||||
|
"""Creates the `ImageClassifier` object from image classifier options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
options: Options for the image classifier task.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`ImageClassifier` object that's created from `options`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If failed to create `ImageClassifier` object from
|
||||||
|
`ImageClassifierOptions` such as missing the model.
|
||||||
|
RuntimeError: If other types of error occurred.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def packets_callback(output_packets: Mapping[str, packet.Packet]):
|
||||||
|
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||||
|
return
|
||||||
|
|
||||||
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
|
classification_result_proto.CopyFrom(
|
||||||
|
packet_getter.get_proto(
|
||||||
|
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||||
|
|
||||||
|
classification_result = classifications.ClassificationResult([
|
||||||
|
classifications.Classifications.create_from_pb2(classification)
|
||||||
|
for classification in classification_result_proto.classifications
|
||||||
|
])
|
||||||
|
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||||
|
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||||
|
options.result_callback(classification_result, image,
|
||||||
|
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
|
||||||
|
task_info = _TaskInfo(
|
||||||
|
task_graph=_TASK_GRAPH_NAME,
|
||||||
|
input_streams=[
|
||||||
|
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
||||||
|
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
|
||||||
|
],
|
||||||
|
output_streams=[
|
||||||
|
':'.join([
|
||||||
|
_CLASSIFICATION_RESULT_TAG,
|
||||||
|
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
|
||||||
|
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
|
||||||
|
],
|
||||||
|
task_options=options)
|
||||||
|
return cls(
|
||||||
|
task_info.generate_graph_config(
|
||||||
|
enable_flow_limiting=options.running_mode ==
|
||||||
|
_RunningMode.LIVE_STREAM), options.running_mode,
|
||||||
|
packets_callback if options.result_callback else None)
|
||||||
|
|
||||||
|
# TODO: Replace _NormalizedRect with ImageProcessingOption
|
||||||
|
def classify(
|
||||||
|
self,
|
||||||
|
image: image_module.Image,
|
||||||
|
roi: Optional[_NormalizedRect] = None
|
||||||
|
) -> classifications.ClassificationResult:
|
||||||
|
"""Performs image classification on the provided MediaPipe Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: MediaPipe Image.
|
||||||
|
roi: The region of interest.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A classification result object that contains a list of classifications.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
RuntimeError: If image classification failed to run.
|
||||||
|
"""
|
||||||
|
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||||
|
output_packets = self._process_image_data({
|
||||||
|
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
|
||||||
|
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
|
||||||
|
})
|
||||||
|
|
||||||
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
|
classification_result_proto.CopyFrom(
|
||||||
|
packet_getter.get_proto(
|
||||||
|
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||||
|
|
||||||
|
return classifications.ClassificationResult([
|
||||||
|
classifications.Classifications.create_from_pb2(classification)
|
||||||
|
for classification in classification_result_proto.classifications
|
||||||
|
])
|
||||||
|
|
||||||
|
def classify_for_video(
|
||||||
|
self,
|
||||||
|
image: image_module.Image,
|
||||||
|
timestamp_ms: int,
|
||||||
|
roi: Optional[_NormalizedRect] = None
|
||||||
|
) -> classifications.ClassificationResult:
|
||||||
|
"""Performs image classification on the provided video frames.
|
||||||
|
|
||||||
|
Only use this method when the ImageClassifier is created with the video
|
||||||
|
running mode. It's required to provide the video frame's timestamp (in
|
||||||
|
milliseconds) along with the video frame. The input timestamps should be
|
||||||
|
monotonically increasing for adjacent calls of this method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: MediaPipe Image.
|
||||||
|
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||||
|
roi: The region of interest.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A classification result object that contains a list of classifications.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any of the input arguments is invalid.
|
||||||
|
RuntimeError: If image classification failed to run.
|
||||||
|
"""
|
||||||
|
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||||
|
output_packets = self._process_video_data({
|
||||||
|
_IMAGE_IN_STREAM_NAME:
|
||||||
|
packet_creator.create_image(image).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
|
_NORM_RECT_NAME:
|
||||||
|
packet_creator.create_proto(norm_rect.to_pb2()).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
})
|
||||||
|
|
||||||
|
classification_result_proto = classifications_pb2.ClassificationResult()
|
||||||
|
classification_result_proto.CopyFrom(
|
||||||
|
packet_getter.get_proto(
|
||||||
|
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
|
||||||
|
|
||||||
|
return classifications.ClassificationResult([
|
||||||
|
classifications.Classifications.create_from_pb2(classification)
|
||||||
|
for classification in classification_result_proto.classifications
|
||||||
|
])
|
||||||
|
|
||||||
|
def classify_async(self,
|
||||||
|
image: image_module.Image,
|
||||||
|
timestamp_ms: int,
|
||||||
|
roi: Optional[_NormalizedRect] = None) -> None:
|
||||||
|
"""Sends live image data (an Image with a unique timestamp) to perform image classification.
|
||||||
|
|
||||||
|
Only use this method when the ImageClassifier is created with the live
|
||||||
|
stream running mode. The input timestamps should be monotonically increasing
|
||||||
|
for adjacent calls of this method. This method will return immediately after
|
||||||
|
the input image is accepted. The results will be available via the
|
||||||
|
`result_callback` provided in the `ImageClassifierOptions`. The
|
||||||
|
`classify_async` method is designed to process live stream data such as
|
||||||
|
camera input. To lower the overall latency, image classifier may drop the
|
||||||
|
input images if needed. In other words, it's not guaranteed to have output
|
||||||
|
per input image.
|
||||||
|
|
||||||
|
The `result_callback` provides:
|
||||||
|
- A classification result object that contains a list of classifications.
|
||||||
|
- The input image that the image classifier runs on.
|
||||||
|
- The input timestamp in milliseconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: MediaPipe Image.
|
||||||
|
timestamp_ms: The timestamp of the input image in milliseconds.
|
||||||
|
roi: The region of interest.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the current input timestamp is smaller than what the image
|
||||||
|
classifier has already processed.
|
||||||
|
"""
|
||||||
|
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
|
||||||
|
self._send_live_stream_data({
|
||||||
|
_IMAGE_IN_STREAM_NAME:
|
||||||
|
packet_creator.create_image(image).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||||
|
_NORM_RECT_NAME:
|
||||||
|
packet_creator.create_proto(norm_rect.to_pb2()).at(
|
||||||
|
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
|
||||||
|
})
|
4
mediapipe/tasks/testdata/vision/BUILD
vendored
4
mediapipe/tasks/testdata/vision/BUILD
vendored
|
@ -28,6 +28,8 @@ mediapipe_files(srcs = [
|
||||||
"burger_rotated.jpg",
|
"burger_rotated.jpg",
|
||||||
"cat.jpg",
|
"cat.jpg",
|
||||||
"cat_mask.jpg",
|
"cat_mask.jpg",
|
||||||
|
"cat_rotated.jpg",
|
||||||
|
"cat_rotated_mask.jpg",
|
||||||
"cats_and_dogs.jpg",
|
"cats_and_dogs.jpg",
|
||||||
"cats_and_dogs_no_resizing.jpg",
|
"cats_and_dogs_no_resizing.jpg",
|
||||||
"cats_and_dogs_rotated.jpg",
|
"cats_and_dogs_rotated.jpg",
|
||||||
|
@ -84,6 +86,8 @@ filegroup(
|
||||||
"burger_rotated.jpg",
|
"burger_rotated.jpg",
|
||||||
"cat.jpg",
|
"cat.jpg",
|
||||||
"cat_mask.jpg",
|
"cat_mask.jpg",
|
||||||
|
"cat_rotated.jpg",
|
||||||
|
"cat_rotated_mask.jpg",
|
||||||
"cats_and_dogs.jpg",
|
"cats_and_dogs.jpg",
|
||||||
"cats_and_dogs_no_resizing.jpg",
|
"cats_and_dogs_no_resizing.jpg",
|
||||||
"cats_and_dogs_rotated.jpg",
|
"cats_and_dogs_rotated.jpg",
|
||||||
|
|
|
@ -82,7 +82,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) {
|
||||||
// If that fails, assume it was a relative path, and try just the base name.
|
// If that fails, assume it was a relative path, and try just the base name.
|
||||||
{
|
{
|
||||||
const size_t last_slash_idx = path.find_last_of("\\/");
|
const size_t last_slash_idx = path.find_last_of("\\/");
|
||||||
CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path.
|
RET_CHECK(last_slash_idx != std::string::npos)
|
||||||
|
<< path << " doesn't have a slash in it"; // Make sure it's a path.
|
||||||
auto base_name = path.substr(last_slash_idx + 1);
|
auto base_name = path.substr(last_slash_idx + 1);
|
||||||
auto status_or_path = PathToResourceAsFileInternal(base_name);
|
auto status_or_path = PathToResourceAsFileInternal(base_name);
|
||||||
if (status_or_path.ok()) {
|
if (status_or_path.ok()) {
|
||||||
|
|
|
@ -71,7 +71,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) {
|
||||||
// If that fails, assume it was a relative path, and try just the base name.
|
// If that fails, assume it was a relative path, and try just the base name.
|
||||||
{
|
{
|
||||||
const size_t last_slash_idx = path.find_last_of("\\/");
|
const size_t last_slash_idx = path.find_last_of("\\/");
|
||||||
CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path.
|
RET_CHECK(last_slash_idx != std::string::npos)
|
||||||
|
<< path << " doesn't have a slash in it"; // Make sure it's a path.
|
||||||
auto base_name = path.substr(last_slash_idx + 1);
|
auto base_name = path.substr(last_slash_idx + 1);
|
||||||
auto status_or_path = PathToResourceAsFileInternal(base_name);
|
auto status_or_path = PathToResourceAsFileInternal(base_name);
|
||||||
if (status_or_path.ok()) {
|
if (status_or_path.ok()) {
|
||||||
|
|
40
third_party/external_files.bzl
vendored
40
third_party/external_files.bzl
vendored
|
@ -76,6 +76,18 @@ def external_files():
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_mask.jpg?generation=1661875677203533"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_mask.jpg?generation=1661875677203533"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_cat_rotated_jpg",
|
||||||
|
sha256 = "b78cee5ad14c9f36b1c25d103db371d81ca74d99030063c46a38e80bb8f38649",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated.jpg?generation=1666304165042123"],
|
||||||
|
)
|
||||||
|
|
||||||
|
http_file(
|
||||||
|
name = "com_google_mediapipe_cat_rotated_mask_jpg",
|
||||||
|
sha256 = "f336973e7621d602f2ebc9a6ab1c62d8502272d391713f369d3b99541afda861",
|
||||||
|
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated_mask.jpg?generation=1666304167148173"],
|
||||||
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_cats_and_dogs_jpg",
|
name = "com_google_mediapipe_cats_and_dogs_jpg",
|
||||||
sha256 = "a2eaa7ad3a1aae4e623dd362a5f737e8a88d122597ecd1a02b3e1444db56df9c",
|
sha256 = "a2eaa7ad3a1aae4e623dd362a5f737e8a88d122597ecd1a02b3e1444db56df9c",
|
||||||
|
@ -162,8 +174,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
|
name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
|
||||||
sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf",
|
sha256 = "c4dfdcc2e4cd366eb5f8ad227be94049eb593e3a528564611094687912463687",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666629474155924"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -174,8 +186,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
|
name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
|
||||||
sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a",
|
sha256 = "7fb2d33cf69d2da50952a45bad0c0618f30859e608958fee95948a6e0de63ccb",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -258,8 +270,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
|
name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
|
||||||
sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370",
|
sha256 = "555079c274ea91699757a0b9888c9993a8ab450069103b1bcd4ebb805a8e023c",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666629478777955"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -456,14 +468,14 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_json",
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_json",
|
||||||
sha256 = "0eb285a857b4bb1815736d0902ace0af45ea62e90c1dac98844b9ca797cd0d7b",
|
sha256 = "94613ea9539a20a3352604004be6d4d64d4d76250bc9042fcd8685c9a8498517",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1665988398778178"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1666633416316646"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json",
|
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json",
|
||||||
sha256 = "932f345ebe3d98daf0dc4c88b0f9e694e450390fb394fc217e851338dfec43e6",
|
sha256 = "3703eadcf838b65bbc2b2aa11dbb1f1bc654c7a09a7aba5ca75a26096484a8ac",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1665988401522527"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1666633418665507"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -606,8 +618,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
|
name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
|
||||||
sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de",
|
sha256 = "5ec37218d8b613436f5c10121dc689bf9ee69af0656a6ccf8c2e3e8b652e2ad6",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -798,8 +810,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
|
name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
|
||||||
sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102",
|
sha256 = "6645bbd98ea7f90b3e1ba297e16ea5280847fc5bf5400726d98c282f6c597257",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666629489421733"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user