Merge branch 'master' into gesture-recognizer-python

This commit is contained in:
Kinar R 2022-10-26 11:37:12 +05:30 committed by GitHub
commit 0de97497fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 2182 additions and 444 deletions

View File

@ -163,7 +163,6 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
}
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
const auto& input_tensors = *kInTensors(cc);
RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
@ -182,12 +181,6 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
auto raw_scores = view.buffer<float>();
auto classification_list = absl::make_unique<ClassificationList>();
if (options.has_tensor_index()) {
classification_list->set_tensor_index(options.tensor_index());
}
if (options.has_tensor_name()) {
classification_list->set_tensor_name(options.tensor_name());
}
if (is_binary_classification_) {
Classification* class_first = classification_list->add_classification();
Classification* class_second = classification_list->add_classification();

View File

@ -72,9 +72,4 @@ message TensorsToClassificationCalculatorOptions {
// that are not in the `allow_classes` field will be completely ignored.
// `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 allow_classes = 8 [packed = true];
// The optional index of the tensor these classifications originate from.
optional int32 tensor_index = 10;
// The optional name of the tensor these classifications originate from.
optional string tensor_name = 11;
}

View File

@ -240,36 +240,6 @@ TEST_F(TensorsToClassificationCalculatorTest,
}
}
TEST_F(TensorsToClassificationCalculatorTest,
CorrectOutputWithTensorNameAndIndex) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
tensor_index: 1
tensor_name: "foo"
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
EXPECT_EQ(3, classification_list.classification_size());
// Verify that the tensor_index and tensor_name fields are correctly set.
EXPECT_EQ(classification_list.tensor_index(), 1);
EXPECT_EQ(classification_list.tensor_name(), "foo");
}
TEST_F(TensorsToClassificationCalculatorTest,
ClassNameAllowlistWithLabelItems) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(

View File

@ -293,7 +293,6 @@ mediapipe_proto_library(
name = "rect_proto",
srcs = ["rect.proto"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework/formats:location_data_proto"],
)
mediapipe_register_type(

View File

@ -37,10 +37,6 @@ message Classification {
// Group of Classification protos.
message ClassificationList {
repeated Classification classification = 1;
// Optional index of the tensor that produced these classifications.
optional int32 tensor_index = 2;
// Optional name of the tensor that produced these classifications.
optional string tensor_name = 3;
}
// Group of ClassificationList protos.

View File

@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;
static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__)
eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
EGL_NO_CONTEXT);
#else
// Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
// parameter for eglMakeCurrent. This behavior is not portable to all EGL
// implementations, and should be considered as an undocumented vendor
// extension.
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
//
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
eglReleaseThread();
}

View File

@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
return model.fit(
x=train_ds,
epochs=hparams.train_epochs,
steps_per_epoch=hparams.steps_per_epoch,
validation_data=validation_ds,
callbacks=callbacks)

View File

@ -87,6 +87,7 @@ cc_library(
cc_library(
name = "builtin_task_graphs",
deps = [
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
],

View File

@ -14,7 +14,7 @@
"""The public facing packet getter APIs."""
from typing import List, Type
from typing import List
from google.protobuf import message
from google.protobuf import symbol_database
@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
get_matrix = _packet_getter.get_matrix
def get_proto(packet: mp_packet.Packet) -> Type[message.Message]:
def get_proto(packet: mp_packet.Packet) -> message.Message:
"""Get the content of a MediaPipe proto Packet as a proto message.
Args:

View File

@ -46,6 +46,7 @@ cc_library(
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/gpu:gpu_origin_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs",

View File

@ -17,7 +17,7 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto";
// A single classification result.

View File

@ -19,7 +19,7 @@ package mediapipe.tasks.components.containers.proto;
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto";
// List of predicted categories with an optional timestamp.

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "EmbeddingsProto";
// Defines a dense floating-point embedding.
message FloatEmbedding {
repeated float values = 1 [packed = true];

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
@ -128,6 +129,9 @@ absl::Status ConfigureImageToTensorCalculator(
options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
std);
}
// TODO: need to.support different GPU origin on differnt
// platforms or applications.
options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
return absl::OkStatus();
}

View File

@ -73,6 +73,19 @@ cc_library(
],
)
cc_test(
name = "sentencepiece_tokenizer_test",
srcs = ["sentencepiece_tokenizer_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:albert_model",
],
deps = [
":sentencepiece_tokenizer",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/core:utils",
],
)
cc_library(
name = "tokenizer_utils",
srcs = ["tokenizer_utils.cc"],
@ -95,6 +108,33 @@ cc_library(
],
)
cc_test(
name = "tokenizer_utils_test",
srcs = ["tokenizer_utils_test.cc"],
data = [
"//mediapipe/tasks/testdata/text:albert_model",
"//mediapipe/tasks/testdata/text:mobile_bert_model",
"//mediapipe/tasks/testdata/text:text_classifier_models",
],
linkopts = ["-ldl"],
deps = [
":bert_tokenizer",
":regex_tokenizer",
":sentencepiece_tokenizer",
":tokenizer_utils",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
],
)
cc_library(
name = "regex_tokenizer",
srcs = [

View File

@ -58,6 +58,7 @@ cc_library(
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto::
ImageEmbedderGraphOptions;
// Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() {
NormalizedRect norm_rect;
norm_rect.set_x_center(0.5);
norm_rect.set_y_center(0.5);
norm_rect.set_width(1);
norm_rect.set_height(1);
return norm_rect;
}
// Creates a MediaPipe graph config that contains a single node of type
// "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
@ -148,15 +139,16 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
}
absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
Image image, std::optional<NormalizedRect> roi) {
Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData(
@ -167,15 +159,16 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
}
absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) {
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
@ -188,16 +181,17 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
}
absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms,
std::optional<NormalizedRect> roi) {
absl::Status ImageEmbedder::EmbedAsync(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
NormalizedRect norm_rect =
roi.has_value() ? roi.value() : BuildFullImageNormRect();
ASSIGN_OR_RETURN(NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options));
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))

View File

@ -21,11 +21,11 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe {
@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi {
static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create(
std::unique_ptr<ImageEmbedderOptions> options);
// Performs embedding extraction on the provided single image. Extraction
// is performed on the region of interest specified by the `roi` argument if
// provided, or on the entire image otherwise.
// Performs embedding extraction on the provided single image.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageEmbedder is created with the image
// running mode.
@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA.
absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs embedding extraction on the provided video frame. Extraction
// is performed on the region of interested specified by the `roi` argument if
// provided, or on the entire image otherwise.
// Performs embedding extraction on the provided video frame.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageEmbedder is created with the video
// running mode.
@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// must be monotonically increasing.
absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to embedder, and the results will be available via
// the "result_callback" provided in the ImageEmbedderOptions. Embedding
// extraction is performed on the region of interested specified by the `roi`
// argument if provided, or on the entire image otherwise.
// the "result_callback" provided in the ImageEmbedderOptions.
//
// The optional 'image_processing_options' parameter can be used to specify:
// - the rotation to apply to the image before performing embedding
// extraction, by setting its 'rotation_degrees' field.
// and/or
// - the region-of-interest on which to perform embedding extraction, by
// setting its 'region_of_interest' field. If not specified, the full image
// is used.
// If both are specified, the crop around the region-of-interest is extracted
// first, then the specified rotation is applied to the crop.
//
// Only use this method when the ImageEmbedder is created with the live
// stream running mode.
@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi {
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status EmbedAsync(
mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the ImageEmbedder when all works are done.
absl::Status Close() { return runner_->Close(); }

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
@ -42,7 +41,9 @@ namespace image_embedder {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
// Bounding box in "burger.jpg" corresponding to "burger_crop.jpg".
NormalizedRect roi;
roi.set_x_center(200.0 / 480);
roi.set_y_center(0.5);
roi.set_width(400.0 / 480);
roi.set_height(1.0f);
// Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
image_embedder->Embed(image, roi));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& image_result,
image_embedder->Embed(image, image_processing_options));
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
image_embedder->Embed(crop));
@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
// Load images: one is a rotated version of the other.
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
image_embedder->Embed(image));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(image_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.572265;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
auto options = std::make_unique<ImageEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
ImageEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
Image crop, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
MP_ASSERT_OK_AND_ASSIGN(Image rotated,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
"burger_rotated.jpg")));
// Region-of-interest corresponding to burger_crop.jpg.
Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
ImageProcessingOptions image_processing_options{roi,
/*rotation_degrees=*/-90};
// Extract both embeddings.
MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
image_embedder->Embed(crop));
MP_ASSERT_OK_AND_ASSIGN(
const EmbeddingResult& rotated_result,
image_embedder->Embed(rotated, image_processing_options));
// Check results.
CheckMobileNetV3Result(crop_result, false);
CheckMobileNetV3Result(rotated_result, false);
// CheckCosineSimilarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity,
ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
rotated_result.embeddings(0).entries(0)));
double expected_similarity = 0.62838;
EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
class VideoModeTest : public tflite_shims::testing::Test {};
TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {

View File

@ -24,10 +24,12 @@ cc_library(
":image_segmenter_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:image_processing_options",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
@ -48,6 +50,7 @@ cc_library(
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",

View File

@ -17,8 +17,10 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig(
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectStreamName);
task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
graph.Out(kGroupedSegmentationTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(
graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag);
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
{kImageTag, kNormRectTag},
kGroupedSegmentationTag);
}
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
return graph.GetConfig();
}
@ -139,47 +146,68 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
}
absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
mediapipe::Image image) {
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessImageData({{kImageInStreamName,
mediapipe::MakePacket<Image>(std::move(image))}}));
ProcessImageData(
{{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}
absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms) {
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}
absl::Status ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) {
absl::Status ImageSegmenter::SegmentAsync(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
NormalizedRect norm_rect,
ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
{kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/kernels/register.h"
@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// running mode.
//
// The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image);
absl::StatusOr<std::vector<mediapipe::Image>> Segment(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Performs image segmentation on the provided video frame.
// Only use this method when the ImageSegmenter is created with the video
@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// If the output_type is CATEGORY_MASK, the returned vector of images is
// per-category segmented image mask.
// If the output_type is CONFIDENCE_MASK, the returned vector of images
// contains only one confidence image mask.
absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
mediapipe::Image image, int64 timestamp_ms);
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to perform image segmentation, and the results will
// be available via the "result_callback" provided in the
@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// sent to the image segmenter. The input timestamps must be monotonically
// increasing.
//
// The optional 'image_processing_options' parameter can be used to specify
// the rotation to apply to the image before performing segmentation, by
// setting its 'rotation_degrees' field. Note that specifying a
// region-of-interest using the 'region_of_interest' field is NOT supported
// and will result in an invalid argument error being returned.
//
// The "result_callback" prvoides
// - A vector of segmented image masks.
// If the output_type is CATEGORY_MASK, the returned vector of images is
@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
// no longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms);
absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
// Shuts down the ImageSegmenter when all works are done.
absl::Status Close() { return runner_->Close(); }

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
@ -159,6 +161,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
// Inputs:
// IMAGE - Image
// Image to perform segmentation on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection
// on.
// @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
// SEGMENTATION - mediapipe::Image @Multiple
@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageSegmenterOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto output_streams,
ASSIGN_OR_RETURN(
auto output_streams,
BuildSegmentationTask(
sc->Options<ImageSegmenterOptions>(), *model_resources,
graph[Input<Image>(kImageTag)], graph));
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
auto& merge_images_to_vector =
graph.AddNode("MergeImagesToVectorCalculator");
@ -228,7 +236,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
const ImageSegmenterOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Graph& graph) {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
// Adds preprocessing calculators and connects them to the graph input image
@ -240,6 +248,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In(kImageTag);
norm_rect_in >> preprocessing.In(kNormRectTag);
// Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator.

View File

@ -29,8 +29,10 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -44,6 +46,8 @@ namespace {
using ::mediapipe::Image;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21);
@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
EXPECT_EQ(confidence_masks.size(), 21);
cv::Mat expected_mask =
cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
cv::IMREAD_GRAYSCALE);
cv::Mat expected_mask_float;
expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);
// Cat category index 8.
cv::Mat cat_mask = mediapipe::formats::MatView(
confidence_masks[8].GetImageFrameSharedPtr().get());
EXPECT_THAT(cat_mask,
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
ImageSegmenter::Create(std::move(options)));
Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};
auto results = segmenter->Segment(image, image_processing_options);
EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(results.status().message(),
HasSubstr("This task doesn't support region-of-interest"));
EXPECT_THAT(
results.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}
TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
Image image =
GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));

View File

@ -31,6 +31,7 @@ android_binary(
multidex = "native",
resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
deps = [
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",

View File

@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector;
import android.content.Intent;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.MediaMetadataRetriever;
import android.os.Bundle;
import android.provider.MediaStore;
@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts;
import androidx.exifinterface.media.ExifInterface;
// ContentResolver dependency
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity {
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
int rotation = 0;
try {
bitmap =
downscaleBitmap(
@ -93,13 +95,16 @@ public class MainActivity extends AppCompatActivity {
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
bitmap = rotateBitmap(bitmap, imageData);
} catch (IOException e) {
rotation = getImageRotation(imageData);
} catch (IOException | MediaPipeException e) {
Log.e(TAG, "Bitmap rotation error:" + e);
}
if (bitmap != null) {
MPImage image = new BitmapImageBuilder(bitmap).build();
ObjectDetectionResult detectionResult = objectDetector.detect(image);
ObjectDetectionResult detectionResult =
objectDetector.detect(
image,
ImageProcessingOptions.builder().setRotationDegrees(rotation).build());
imageView.setData(image, detectionResult);
runOnUiThread(() -> imageView.update());
}
@ -210,28 +215,25 @@ public class MainActivity extends AppCompatActivity {
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_NORMAL:
return 0;
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
return 90;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
return 180;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
return 270;
default:
matrix.postRotate(0);
}
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
// TODO: use getRotationDegrees() and isFlipped() instead of switch once flip
// is supported.
throw new MediaPipeException(
MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(),
"Flipped images are not supported yet.");
}
}
}

View File

@ -15,11 +15,11 @@
package com.google.mediapipe.tasks.text.textclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;

View File

@ -22,7 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler;

View File

@ -28,6 +28,7 @@ android_library(
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -24,7 +24,6 @@ import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
/** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable {
@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
private final TaskRunner runner;
private final RunningMode runningMode;
private final String imageStreamName;
private final Optional<String> normRectStreamName;
private final String normRectStreamName;
static {
System.loadLibrary("mediapipe_tasks_vision_jni");
@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
}
/**
* Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
* Constructor to initialize a {@link BaseVisionTaskApi}.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream.
*/
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
this.runner = runner;
this.runningMode = runningMode;
this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.empty();
}
/**
* Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
* input.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream.
* @param normRectStreamName the name of the input normalized rect image stream.
* @param normRectStreamName the name of the input normalized rect image stream used to provide
* (mandatory) rotation and (optional) region-of-interest.
*/
public BaseVisionTaskApi(
TaskRunner runner,
@ -70,7 +55,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
this.runner = runner;
this.runningMode = runningMode;
this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.of(normRectStreamName);
this.normRectStreamName = normRectStreamName;
}
/**
@ -78,53 +63,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
* failure status or a successful result is returned.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
* input.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @throws MediaPipeException if the task is not in the image mode.
*/
protected TaskResult processImageData(MPImage image) {
protected TaskResult processImageData(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:"
+ runningMode.name());
}
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets);
}
/**
* A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
* rect.
*/
protected TaskResult processImageData(MPImage image, RectF roi) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName.get(),
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
return runner.process(inputPackets);
}
@ -133,55 +88,24 @@ public class BaseVisionTaskApi implements AutoCloseable {
* until a failure status or a successful result is returned.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
* input.
* @throws MediaPipeException if the task is not in the video mode.
*/
protected TaskResult processVideoData(MPImage image, long timestampMs) {
protected TaskResult processVideoData(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode:"
+ runningMode.name());
}
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
* A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
* rect.
*/
protected TaskResult processVideoData(MPImage image, RectF roi, long timestampMs) {
if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName.get(),
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
@ -190,55 +114,24 @@ public class BaseVisionTaskApi implements AutoCloseable {
* available in the user-defined result listener.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
* input.
* @throws MediaPipeException if the task is not in the stream mode.
*/
protected void sendLiveStreamData(MPImage image, long timestampMs) {
protected void sendLiveStreamData(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name());
}
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
* rect.
*/
protected void sendLiveStreamData(MPImage image, RectF roi, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName.get(),
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
normRectStreamName,
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
runner.close();
}
/** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */
private static NormalizedRect convertToNormalizedRect(RectF rect) {
/**
* Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf
* message.
*/
private static NormalizedRect convertToNormalizedRect(
ImageProcessingOptions imageProcessingOptions) {
RectF regionOfInterest =
imageProcessingOptions.regionOfInterest().isPresent()
? imageProcessingOptions.regionOfInterest().get()
: new RectF(0, 0, 1, 1);
return NormalizedRect.newBuilder()
.setXCenter(rect.centerX())
.setYCenter(rect.centerY())
.setWidth(rect.width())
.setHeight(rect.height())
.setXCenter(regionOfInterest.centerX())
.setYCenter(regionOfInterest.centerY())
.setWidth(regionOfInterest.width())
.setHeight(regionOfInterest.height())
// Convert to radians anti-clockwise.
.setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f)
.build();
}
}

View File

@ -0,0 +1,92 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.core;
import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import java.util.Optional;
// TODO: add support for image flipping.
/** Options for image processing. */
@AutoValue
public abstract class ImageProcessingOptions {
/**
* Builder for {@link ImageProcessingOptions}.
*
* <p>If both region-of-interest and rotation are specified, the crop around the
* region-of-interest is extracted first, then the specified rotation is applied to the crop.
*/
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets the optional region-of-interest to crop from the image. If not specified, the full image
* is used.
*
* <p>Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must be
* < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link #build()}
* is called.
*/
public abstract Builder setRegionOfInterest(RectF value);
/**
* Sets the rotation to apply to the image (or cropped region-of-interest), in degrees
* clockwise. Defaults to 0.
*
* <p>The rotation must be a multiple (positive or negative) of 90°, otherwise an
* IllegalArgumentException will be thrown when {@link #build()} is called.
*/
public abstract Builder setRotationDegrees(int value);
abstract ImageProcessingOptions autoBuild();
/**
* Validates and builds the {@link ImageProcessingOptions} instance.
*
* @throws IllegalArgumentException if some of the provided values do not meet their
* requirements.
*/
public final ImageProcessingOptions build() {
ImageProcessingOptions options = autoBuild();
if (options.regionOfInterest().isPresent()) {
RectF roi = options.regionOfInterest().get();
if (roi.left >= roi.right || roi.top >= roi.bottom) {
throw new IllegalArgumentException(
String.format(
"Expected left < right and top < bottom, found: %s.", roi.toShortString()));
}
if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) {
throw new IllegalArgumentException(
String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString()));
}
}
if (options.rotationDegrees() % 90 != 0) {
throw new IllegalArgumentException(
String.format(
"Expected rotation to be a multiple of 90°, found: %d.",
options.rotationDegrees()));
}
return options;
}
}
public abstract Optional<RectF> regionOfInterest();
public abstract int rotationDegrees();
public static Builder builder() {
return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0);
}
}

View File

@ -15,7 +15,6 @@
package com.google.mediapipe.tasks.vision.gesturerecognizer;
import android.content.Context;
import android.graphics.RectF;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto;
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto;
@ -212,6 +212,25 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
}
/**
* Performs gesture recognition on the provided single image with default image processing
* options, i.e. without any rotation applied. Only use this method when the {@link
* GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc
* for input image format.
*
* <p>{@link GestureRecognizer} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public GestureRecognitionResult recognize(MPImage image) {
return recognize(image, ImageProcessingOptions.builder().build());
}
/**
* Performs gesture recognition on the provided single image. Only use this method when the {@link
* GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc
@ -223,12 +242,41 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public GestureRecognitionResult recognize(MPImage inputImage) {
// TODO: add proper support for rotations.
return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF());
public GestureRecognitionResult recognize(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions);
return (GestureRecognitionResult) processImageData(image, imageProcessingOptions);
}
/**
* Performs gesture recognition on the provided video frame with default image processing options,
* i.e. without any rotation applied. Only use this method when the {@link GestureRecognizer} is
* created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link GestureRecognizer} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public GestureRecognitionResult recognizeForVideo(MPImage image, long timestampMs) {
return recognizeForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
@ -244,14 +292,43 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public GestureRecognitionResult recognizeForVideo(MPImage inputImage, long inputTimestampMs) {
// TODO: add proper support for rotations.
return (GestureRecognitionResult)
processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
public GestureRecognitionResult recognizeForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
return (GestureRecognitionResult) processVideoData(image, imageProcessingOptions, timestampMs);
}
/**
* Sends live image data to perform gesture recognition with default image processing options,
* i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method when the
* {@link GestureRecognition} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the gesture recognizer. The input timestamps must be monotonically increasing.
*
* <p>{@link GestureRecognizer} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void recognizeAsync(MPImage image, long timestampMs) {
recognizeAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
@ -268,13 +345,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public void recognizeAsync(MPImage inputImage, long inputTimestampMs) {
// TODO: add proper support for rotations.
sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
public void recognizeAsync(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
}
/** Options for setting up an {@link GestureRecognizer}. */
@ -445,8 +529,14 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
}
}
/** Creates a RectF covering the full image. */
private static RectF buildFullImageRectF() {
return new RectF(0, 0, 1, 1);
/**
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
* region-of-interest.
*/
private static void validateImageProcessingOptions(
ImageProcessingOptions imageProcessingOptions) {
if (imageProcessingOptions.regionOfInterest().isPresent()) {
throw new IllegalArgumentException("GestureRecognizer doesn't support region-of-interest.");
}
}
}

View File

@ -15,11 +15,11 @@
package com.google.mediapipe.tasks.vision.imageclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;

View File

@ -15,7 +15,6 @@
package com.google.mediapipe.tasks.vision.imageclassifier;
import android.content.Context;
import android.graphics.RectF;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
@ -26,7 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
import java.io.File;
@ -215,6 +215,24 @@ public final class ImageClassifier extends BaseVisionTaskApi {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
}
/**
* Performs classification on the provided single image with default image processing options,
* i.e. using the whole image as region-of-interest and without any rotation applied. Only use
* this method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classify(MPImage image) {
return classify(image, ImageProcessingOptions.builder().build());
}
/**
* Performs classification on the provided single image. Only use this method when the {@link
* ImageClassifier} is created with {@link RunningMode.IMAGE}.
@ -225,16 +243,23 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classify(MPImage inputImage) {
return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF());
public ImageClassificationResult classify(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
return (ImageClassificationResult) processImageData(image, imageProcessingOptions);
}
/**
* Performs classification on the provided single image and region-of-interest. Only use this
* method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
* Performs classification on the provided video frame with default image processing options, i.e.
* using the whole image as region-of-interest and without any rotation applied. Only use this
* method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
@ -242,13 +267,12 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RectF} specifying the region of interest on which to perform
* classification. Coordinates are expected to be specified as normalized values in [0,1].
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classify(MPImage inputImage, RectF roi) {
return (ImageClassificationResult) processImageData(inputImage, roi);
public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) {
return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
@ -264,21 +288,26 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classifyForVideo(MPImage inputImage, long inputTimestampMs) {
return (ImageClassificationResult)
processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
public ImageClassificationResult classifyForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs);
}
/**
* Performs classification on the provided video frame with additional region-of-interest. Only
* use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
* Sends live image data to perform classification with default image processing options, i.e.
* using the whole image as region-of-interest and without any rotation applied, and the results
* will be available via the {@link ResultListener} provided in the {@link
* ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
* {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the object detector. The input timestamps must be monotonically increasing.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
@ -286,15 +315,12 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RectF} specifying the region of interest on which to perform
* classification. Coordinates are expected to be specified as normalized values in [0,1].
* @param inputTimestampMs the input timestamp (in milliseconds).
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ImageClassificationResult classifyForVideo(
MPImage inputImage, RectF roi, long inputTimestampMs) {
return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs);
public void classifyAsync(MPImage image, long timestampMs) {
classifyAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
@ -311,37 +337,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void classifyAsync(MPImage inputImage, long inputTimestampMs) {
sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
}
/**
* Sends live image data and additional region-of-interest to perform classification, and the
* results will be available via the {@link ResultListener} provided in the {@link
* ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
* {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the object detector. The input timestamps must be monotonically increasing.
*
* <p>{@link ImageClassifier} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param roi a {@link RectF} specifying the region of interest on which to perform
* classification. Coordinates are expected to be specified as normalized values in [0,1].
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void classifyAsync(MPImage inputImage, RectF roi, long inputTimestampMs) {
sendLiveStreamData(inputImage, roi, inputTimestampMs);
public void classifyAsync(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
}
/** Options for setting up and {@link ImageClassifier}. */
@ -447,9 +451,4 @@ public final class ImageClassifier extends BaseVisionTaskApi {
.build();
}
}
/** Creates a RectF covering the full image. */
private static RectF buildFullImageRectF() {
return new RectF(0, 0, 1, 1);
}
}

View File

@ -32,6 +32,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
@ -96,8 +97,10 @@ import java.util.Optional;
public final class ObjectDetector extends BaseVisionTaskApi {
private static final String TAG = ObjectDetector.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
@ -204,7 +207,25 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
}
/**
* Performs object detection on the provided single image with default image processing options,
* i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is
* created with {@link RunningMode.IMAGE}.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detect(MPImage image) {
return detect(image, ImageProcessingOptions.builder().build());
}
/**
@ -217,11 +238,41 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detect(MPImage inputImage) {
return (ObjectDetectionResult) processImageData(inputImage);
public ObjectDetectionResult detect(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processImageData(image, imageProcessingOptions);
}
/**
* Performs object detection on the provided video frame with default image processing options,
* i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is
* created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) {
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
@ -237,12 +288,43 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detectForVideo(MPImage inputImage, long inputTimestampMs) {
return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs);
public ObjectDetectionResult detectForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs);
}
/**
* Sends live image data to perform object detection with default image processing options, i.e.
* without any rotation applied, and the results will be available via the {@link ResultListener}
* provided in the {@link ObjectDetectorOptions}. Only use this method when the {@link
* ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the object detector. The input timestamps must be monotonically increasing.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(MPImage image, long timestampMs) {
detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
@ -259,12 +341,20 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link MPImage} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(MPImage inputImage, long inputTimestampMs) {
sendLiveStreamData(inputImage, inputTimestampMs);
public void detectAsync(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
}
/** Options for setting up an {@link ObjectDetector}. */
@ -415,4 +505,15 @@ public final class ObjectDetector extends BaseVisionTaskApi {
.build();
}
}
/**
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
* region-of-interest.
*/
private static void validateImageProcessingOptions(
ImageProcessingOptions imageProcessingOptions) {
if (imageProcessingOptions.regionOfInterest().isPresent()) {
throw new IllegalArgumentException("ObjectDetector doesn't support region-of-interest.");
}
}
}

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.coretest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="coretest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.vision.coretest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,70 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.core;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.graphics.RectF;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import org.junit.Test;
import org.junit.runner.RunWith;
/** Test for {@link ImageProcessingOptions}/ */
@RunWith(AndroidJUnit4.class)
public final class ImageProcessingOptionsTest {
@Test
public void succeedsWithValidInputs() throws Exception {
ImageProcessingOptions options =
ImageProcessingOptions.builder()
.setRegionOfInterest(new RectF(0.0f, 0.1f, 1.0f, 0.9f))
.setRotationDegrees(270)
.build();
}
@Test
public void failsWithLeftHigherThanRight() {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageProcessingOptions.builder()
.setRegionOfInterest(new RectF(0.9f, 0.0f, 0.1f, 1.0f))
.build());
assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom");
}
@Test
public void failsWithBottomHigherThanTop() {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ImageProcessingOptions.builder()
.setRegionOfInterest(new RectF(0.0f, 0.9f, 1.0f, 0.1f))
.build());
assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom");
}
@Test
public void failsWithInvalidRotation() {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() -> ImageProcessingOptions.builder().setRotationDegrees(1).build());
assertThat(exception).hasMessageThat().contains("Expected rotation to be a multiple of 90°");
}
}

View File

@ -19,6 +19,7 @@ import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.truth.Correspondence;
@ -30,6 +31,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Landmark;
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions;
import java.io.InputStream;
@ -46,11 +48,14 @@ public class GestureRecognizerTest {
private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task";
private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
private static final String TAG = "Gesture Recognizer Test";
private static final String THUMB_UP_LABEL = "Thumb_Up";
private static final int THUMB_UP_INDEX = 5;
private static final String POINTING_UP_LABEL = "Pointing_Up";
private static final int POINTING_UP_INDEX = 3;
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
private static final int IMAGE_WIDTH = 382;
private static final int IMAGE_HEIGHT = 406;
@ -135,6 +140,53 @@ public class GestureRecognizerTest {
gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE));
assertThat(actualResult.handednesses()).hasSize(2);
}
@Test
public void recognize_successWithRotation() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
GestureRecognitionResult actualResult =
gestureRecognizer.recognize(
getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
assertThat(actualResult.gestures()).hasSize(1);
assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX);
assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL);
}
@Test
public void recognize_failsWithRegionOfInterest() throws Exception {
GestureRecognizerOptions options =
GestureRecognizerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.build();
GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
gestureRecognizer.recognize(
getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions));
assertThat(exception)
.hasMessageThat()
.contains("GestureRecognizer doesn't support region-of-interest");
}
}
@RunWith(AndroidJUnit4.class)
@ -195,12 +247,16 @@ public class GestureRecognizerTest {
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
() ->
gestureRecognizer.recognizeForVideo(
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
() ->
gestureRecognizer.recognizeAsync(
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -225,7 +281,9 @@ public class GestureRecognizerTest {
exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
() ->
gestureRecognizer.recognizeAsync(
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -251,7 +309,9 @@ public class GestureRecognizerTest {
exception =
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
() ->
gestureRecognizer.recognizeForVideo(
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@ -291,7 +351,8 @@ public class GestureRecognizerTest {
getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
for (int i = 0; i < 3; i++) {
GestureRecognitionResult actualResult =
gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i);
gestureRecognizer.recognizeForVideo(
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ i);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
}
@ -317,9 +378,11 @@ public class GestureRecognizerTest {
.build();
try (GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
gestureRecognizer.recognizeAsync(image, 1);
gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 1);
MediaPipeException exception =
assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0));
assertThrows(
MediaPipeException.class,
() -> gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
@ -348,7 +411,7 @@ public class GestureRecognizerTest {
try (GestureRecognizer gestureRecognizer =
GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
gestureRecognizer.recognizeAsync(image, i);
gestureRecognizer.recognizeAsync(image, /*timestampsMs=*/ i);
}
}
}

View File

@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;
import java.io.InputStream;
@ -47,7 +48,9 @@ public class ImageClassifierTest {
private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite";
private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite";
private static final String BURGER_IMAGE = "burger.jpg";
private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg";
private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg";
private static final String MULTI_OBJECTS_ROTATED_IMAGE = "multi_objects_rotated.jpg";
@RunWith(AndroidJUnit4.class)
public static final class General extends ImageClassifierTest {
@ -209,13 +212,60 @@ public class ImageClassifierTest {
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
// RectF around the soccer ball.
RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
ImageClassificationResult results =
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi);
imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
}
@Test
public void classify_succeedsWithRotation() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
ImageClassificationResult results =
imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results,
Arrays.asList(
Category.create(0.6390683f, 934, "cheeseburger", ""),
Category.create(0.0495407f, 963, "meat loaf", ""),
Category.create(0.0469720f, 925, "guacamole", "")));
}
@Test
public void classify_succeedsWithRegionOfInterestAndRotation() throws Exception {
ImageClassifierOptions options =
ImageClassifierOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
.build();
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
// RectF around the chair.
RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
ImageClassificationResult results =
imageClassifier.classify(
getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
assertHasOneHeadAndOneTimestamp(results, 0);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
}
}
@RunWith(AndroidJUnit4.class)
@ -269,12 +319,16 @@ public class ImageClassifierTest {
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
() ->
imageClassifier.classifyForVideo(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
() ->
imageClassifier.classifyAsync(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -296,7 +350,9 @@ public class ImageClassifierTest {
exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
() ->
imageClassifier.classifyAsync(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -320,7 +376,9 @@ public class ImageClassifierTest {
exception =
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
() ->
imageClassifier.classifyForVideo(
getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@ -352,7 +410,8 @@ public class ImageClassifierTest {
ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
ImageClassificationResult results = imageClassifier.classifyForVideo(image, i);
ImageClassificationResult results =
imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
assertHasOneHeadAndOneTimestamp(results, i);
assertCategoriesAre(
results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
@ -377,9 +436,11 @@ public class ImageClassifierTest {
.build();
try (ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1);
imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
MediaPipeException exception =
assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0));
assertThrows(
MediaPipeException.class,
() -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
@ -405,7 +466,7 @@ public class ImageClassifierTest {
try (ImageClassifier imageClassifier =
ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; ++i) {
imageClassifier.classifyAsync(image, i);
imageClassifier.classifyAsync(image, /*timestampMs=*/ i);
}
}
}

View File

@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
import java.io.InputStream;
@ -45,10 +46,11 @@ import org.junit.runners.Suite.SuiteClasses;
public class ObjectDetectorTest {
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg";
private static final int IMAGE_WIDTH = 1200;
private static final int IMAGE_HEIGHT = 600;
private static final float CAT_SCORE = 0.69f;
private static final RectF catBoundingBox = new RectF(611, 164, 986, 596);
private static final RectF CAT_BOUNDING_BOX = new RectF(611, 164, 986, 596);
// TODO: Figure out why android_x86 and android_arm tests have slightly different
// scores (0.6875 vs 0.69921875).
private static final float SCORE_DIFF_TOLERANCE = 0.01f;
@ -67,7 +69,7 @@ public class ObjectDetectorTest {
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@Test
@ -104,7 +106,7 @@ public class ObjectDetectorTest {
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should block all other other objects, except cat.
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@Test
@ -175,7 +177,7 @@ public class ObjectDetectorTest {
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@Test
@ -228,6 +230,46 @@ public class ObjectDetectorTest {
.contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
}
@Test
public void detect_succeedsWithRotation() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setMaxResults(1)
.setCategoryAllowlist(Arrays.asList("cat"))
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
ObjectDetectionResult results =
objectDetector.detect(
getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);
assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f);
}
@Test
public void detect_failsWithRegionOfInterest() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
objectDetector.detect(
getImageFromAsset(CAT_AND_DOG_IMAGE), imageProcessingOptions));
assertThat(exception)
.hasMessageThat()
.contains("ObjectDetector doesn't support region-of-interest");
}
// TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
// detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
// detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.
@ -282,12 +324,16 @@ public class ObjectDetectorTest {
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
() ->
objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
() ->
objectDetector.detectAsync(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -309,7 +355,9 @@ public class ObjectDetectorTest {
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
() ->
objectDetector.detectAsync(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@ -333,7 +381,9 @@ public class ObjectDetectorTest {
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
() ->
objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@ -348,7 +398,7 @@ public class ObjectDetectorTest {
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
@Test
@ -363,8 +413,9 @@ public class ObjectDetectorTest {
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
ObjectDetectionResult results =
objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i);
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
objectDetector.detectForVideo(
getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i);
assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
}
}
@ -377,16 +428,18 @@ public class ObjectDetectorTest {
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
.build();
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
objectDetector.detectAsync(image, 1);
objectDetector.detectAsync(image, /*timestampsMs=*/ 1);
MediaPipeException exception =
assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0));
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
@ -402,7 +455,7 @@ public class ObjectDetectorTest {
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
@ -410,7 +463,7 @@ public class ObjectDetectorTest {
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
objectDetector.detectAsync(image, i);
objectDetector.detectAsync(image, /*timestampsMs=*/ i);
}
}
}

View File

@ -86,3 +86,13 @@ py_library(
"//mediapipe/tasks/python/core:optional_dependencies",
],
)
py_library(
name = "classifications",
srcs = ["classifications.py"],
deps = [
":category",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -0,0 +1,168 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""
import dataclasses
from typing import Any, List, Optional
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_ClassificationEntryProto = classifications_pb2.ClassificationEntry
_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult
@dataclasses.dataclass
class ClassificationEntry:
"""List of predicted classes (aka labels) for a given classifier head.
Attributes:
categories: The array of predicted categories, usually sorted by descending
scores (e.g. from high to low probability).
timestamp_ms: The optional timestamp (in milliseconds) associated to the
classification entry. This is useful for time series use cases, e.g.,
audio classification.
"""
categories: List[category_module.Category]
timestamp_ms: Optional[int] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationEntryProto:
"""Generates a ClassificationEntry protobuf object."""
return _ClassificationEntryProto(
categories=[category.to_pb2() for category in self.categories],
timestamp_ms=self.timestamp_ms)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry':
"""Creates a `ClassificationEntry` object from the given protobuf object."""
return ClassificationEntry(
categories=[
category_module.Category.create_from_pb2(category)
for category in pb2_obj.categories
],
timestamp_ms=pb2_obj.timestamp_ms)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassificationEntry):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class Classifications:
"""Represents the classifications for a given classifier head.
Attributes:
entries: A list of `ClassificationEntry` objects.
head_index: The index of the classifier head these categories refer to. This
is useful for multi-head models.
head_name: The name of the classifier head, which is the corresponding
tensor metadata name.
"""
entries: List[ClassificationEntry]
head_index: int
head_name: str
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationsProto:
"""Generates a Classifications protobuf object."""
return _ClassificationsProto(
entries=[entry.to_pb2() for entry in self.entries],
head_index=self.head_index,
head_name=self.head_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
"""Creates a `Classifications` object from the given protobuf object."""
return Classifications(
entries=[
ClassificationEntry.create_from_pb2(entry)
for entry in pb2_obj.entries
],
head_index=pb2_obj.head_index,
head_name=pb2_obj.head_name)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, Classifications):
return False
return self.to_pb2().__eq__(other.to_pb2())
@dataclasses.dataclass
class ClassificationResult:
"""Contains one set of results per classifier head.
Attributes:
classifications: A list of `Classifications` objects.
"""
classifications: List[Classifications]
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ClassificationResultProto:
"""Generates a ClassificationResult protobuf object."""
return _ClassificationResultProto(classifications=[
classification.to_pb2() for classification in self.classifications
])
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
"""Creates a `ClassificationResult` object from the given protobuf object.
"""
return ClassificationResult(classifications=[
Classifications.create_from_pb2(classification)
for classification in pb2_obj.classifications
])
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, ClassificationResult):
return False
return self.to_pb2().__eq__(other.to_pb2())

View File

@ -26,9 +26,7 @@ _NormalizedRectProto = rect_pb2.NormalizedRect
@dataclasses.dataclass
class Rect:
"""A rectangle with rotation in image coordinates.
Attributes:
x_center : The X coordinate of the top-left corner, in pixels.
Attributes: x_center : The X coordinate of the top-left corner, in pixels.
y_center : The Y coordinate of the top-left corner, in pixels.
width: The width of the rectangle, in pixels.
height: The height of the rectangle, in pixels.
@ -81,11 +79,10 @@ class Rect:
@dataclasses.dataclass
class NormalizedRect:
"""A rectangle with rotation in normalized coordinates. The values of box
"""A rectangle with rotation in normalized coordinates.
The values of box
center location and size are within [0, 1].
Attributes:
x_center : The X normalized coordinate of the top-left corner.
Attributes: x_center : The X normalized coordinate of the top-left corner.
y_center : The Y normalized coordinate of the top-left corner.
width: The width of the rectangle.
height: The height of the rectangle.
@ -110,8 +107,7 @@ class NormalizedRect:
width=self.width,
height=self.height,
rotation=self.rotation,
rect_id=self.rect_id
)
rect_id=self.rect_id)
@classmethod
@doc_controls.do_not_generate_docs
@ -123,8 +119,7 @@ class NormalizedRect:
width=pb2_obj.width,
height=pb2_obj.height,
rotation=pb2_obj.rotation,
rect_id=pb2_obj.rect_id
)
rect_id=pb2_obj.rect_id)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.

View File

@ -14,6 +14,8 @@
# Placeholder for internal Python strict library compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])

View File

@ -61,19 +61,13 @@ class ClassifierOptions:
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(
cls,
pb2_obj: _ClassifierOptionsProto
) -> 'ClassifierOptions':
def create_from_pb2(cls,
pb2_obj: _ClassifierOptionsProto) -> 'ClassifierOptions':
"""Creates a `ClassifierOptions` object from the given protobuf object."""
return ClassifierOptions(
score_threshold=pb2_obj.score_threshold,
category_allowlist=[
str(name) for name in pb2_obj.class_name_allowlist
],
category_denylist=[
str(name) for name in pb2_obj.class_name_denylist
],
category_allowlist=[str(name) for name in pb2_obj.category_allowlist],
category_denylist=[str(name) for name in pb2_obj.category_denylist],
display_names_locale=pb2_obj.display_names_locale,
max_results=pb2_obj.max_results)

View File

@ -37,6 +37,26 @@ py_test(
],
)
py_test(
name = "image_classifier_test",
srcs = ["image_classifier_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:image_classifier",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_test(
name = "gesture_recognizer_test",
srcs = ["gesture_recognizer_test.py"],

View File

@ -0,0 +1,515 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for image classifier."""
import enum
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from mediapipe.python._framework_bindings import image
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import vision_task_running_mode
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
_ClassificationResult = classifications_module.ClassificationResult
_Image = image.Image
_ImageClassifier = image_classifier.ImageClassifier
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg'
_ALLOW_LIST = ['cheeseburger', 'guacamole']
_DENY_LIST = ['cheeseburger']
_SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3
# TODO: Port assertProtoEquals
def _assert_proto_equals(expected, actual): # pylint: disable=unused-argument
pass
def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(categories=[], timestamp_ms=timestamp_ms)
],
head_index=0,
head_name='probability')
])
def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=934,
score=0.7939587831497192,
display_name='',
category_name='cheeseburger'),
_Category(
index=932,
score=0.02739289402961731,
display_name='',
category_name='bagel'),
_Category(
index=925,
score=0.01934075355529785,
display_name='',
category_name='guacamole'),
_Category(
index=963,
score=0.006327860057353973,
display_name='',
category_name='meat loaf')
],
timestamp_ms=timestamp_ms)
],
head_index=0,
head_name='probability')
])
def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
return _ClassificationResult(classifications=[
_Classifications(
entries=[
_ClassificationEntry(
categories=[
_Category(
index=806,
score=0.9965274930000305,
display_name='',
category_name='soccer ball')
],
timestamp_ms=timestamp_ms)
],
head_index=0,
head_name='probability')
])
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class ImageClassifierTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_IMAGE_FILE))
self.model_path = test_utils.get_test_data_path(_MODEL_FILE)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _ImageClassifier.create_from_model_path(self.model_path) as classifier:
self.assertIsInstance(classifier, _ImageClassifier)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _ImageClassifierOptions(base_options=base_options)
with _ImageClassifier.create_from_options(options) as classifier:
self.assertIsInstance(classifier, _ImageClassifier)
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
ValueError,
r"ExternalFile must specify at least one of 'file_content', "
r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
base_options = _BaseOptions(model_asset_path='')
options = _ImageClassifierOptions(base_options=base_options)
_ImageClassifier.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _ImageClassifierOptions(base_options=base_options)
classifier = _ImageClassifier.create_from_options(options)
self.assertIsInstance(classifier, _ImageClassifier)
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
def test_classify(self, model_file_type, max_results,
expected_classification_result):
# Creates classifier.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
classifier = _ImageClassifier.create_from_options(options)
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
_assert_proto_equals(image_result.to_pb2(),
expected_classification_result.to_pb2())
# Closes the classifier explicitly when the classifier is not used in
# a context.
classifier.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
(ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
def test_classify_in_context(self, model_file_type, max_results,
expected_classification_result):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
custom_classifier_options = _ClassifierOptions(max_results=max_results)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
# Comparing results.
_assert_proto_equals(image_result.to_pb2(),
expected_classification_result.to_pb2())
def test_classify_succeeds_with_region_of_interest(self):
base_options = _BaseOptions(model_asset_path=self.model_path)
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions(
base_options=base_options, classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball.
roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
# Performs image classification on the input.
image_result = classifier.classify(test_image, roi)
# Comparing results.
_assert_proto_equals(image_result.to_pb2(),
_generate_soccer_ball_results(0).to_pb2())
def test_score_threshold_option(self):
custom_classifier_options = _ClassifierOptions(
score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
score = entry.categories[0].score
self.assertGreaterEqual(
score, _SCORE_THRESHOLD,
f'Classification with score lower than threshold found. '
f'{classification}')
def test_max_results_option(self):
custom_classifier_options = _ClassifierOptions(
score_threshold=_SCORE_THRESHOLD)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
categories = image_result.classifications[0].entries[0].categories
self.assertLessEqual(
len(categories), _MAX_RESULTS, 'Too many results returned.')
def test_allow_list_option(self):
custom_classifier_options = _ClassifierOptions(
category_allowlist=_ALLOW_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
label = entry.categories[0].category_name
self.assertIn(label, _ALLOW_LIST,
f'Label {label} found but not in label allow list')
def test_deny_list_option(self):
custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
classifications = image_result.classifications
for classification in classifications:
for entry in classification.entries:
label = entry.categories[0].category_name
self.assertNotIn(label, _DENY_LIST,
f'Label {label} found but in deny list.')
def test_combined_allowlist_and_denylist(self):
# Fails with combined allowlist and denylist
with self.assertRaisesRegex(
ValueError,
r'`category_allowlist` and `category_denylist` are mutually '
r'exclusive options.'):
custom_classifier_options = _ClassifierOptions(
category_allowlist=['foo'], category_denylist=['bar'])
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_empty_classification_outputs(self):
custom_classifier_options = _ClassifierOptions(score_threshold=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Performs image classification on the input.
image_result = classifier.classify(self.test_image)
self.assertEmpty(image_result.classifications[0].entries[0].categories)
def test_missing_result_callback(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _ImageClassifier.create_from_options(options) as unused_classifier:
pass
def test_calling_classify_for_video_in_image_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
classifier.classify_for_video(self.test_image, 0)
def test_calling_classify_async_in_image_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
classifier.classify_async(self.test_image, 0)
def test_calling_classify_in_video_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
classifier.classify(self.test_image)
def test_calling_classify_async_in_video_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
classifier.classify_async(self.test_image, 0)
def test_classify_for_video_with_out_of_order_timestamp(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _ImageClassifier.create_from_options(options) as classifier:
unused_result = classifier.classify_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_for_video(self.test_image, 0)
def test_classify_for_video(self):
custom_classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
self.test_image, timestamp)
_assert_proto_equals(classification_result.to_pb2(),
_generate_burger_results(timestamp).to_pb2())
def test_classify_for_video_succeeds_with_region_of_interest(self):
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
classifier_options=custom_classifier_options)
with _ImageClassifier.create_from_options(options) as classifier:
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball.
roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
for timestamp in range(0, 300, 30):
classification_result = classifier.classify_for_video(
test_image, timestamp, roi)
self.assertEqual(classification_result,
_generate_soccer_ball_results(timestamp))
def test_calling_classify_in_live_stream_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
classifier.classify(self.test_image)
def test_calling_classify_for_video_in_live_stream_mode(self):
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
classifier.classify_for_video(self.test_image, 0)
def test_classify_async_calls_with_illegal_timestamp(self):
custom_classifier_options = _ClassifierOptions(max_results=4)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options,
result_callback=mock.MagicMock())
with _ImageClassifier.create_from_options(options) as classifier:
classifier.classify_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
classifier.classify_async(self.test_image, 0)
@parameterized.parameters((0, _generate_burger_results),
(1, _generate_empty_results))
def test_classify_async_calls(self, threshold, expected_result_fn):
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
_assert_proto_equals(result.to_pb2(),
expected_result_fn(timestamp_ms).to_pb2())
self.assertTrue(
np.array_equal(output_image.numpy_view(),
self.test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
custom_classifier_options = _ClassifierOptions(
max_results=4, score_threshold=threshold)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options,
result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classifier.classify_async(self.test_image, timestamp)
def test_classify_async_succeeds_with_region_of_interest(self):
# Load the test image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path('multi_objects.jpg'))
# NormalizedRect around the soccer ball.
roi = _NormalizedRect(
x_center=0.532, y_center=0.521, width=0.164, height=0.427)
observed_timestamp_ms = -1
def check_result(result: _ClassificationResult, output_image: _Image,
timestamp_ms: int):
_assert_proto_equals(result.to_pb2(),
_generate_soccer_ball_results(timestamp_ms).to_pb2())
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
custom_classifier_options = _ClassifierOptions(max_results=1)
options = _ImageClassifierOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
classifier_options=custom_classifier_options,
result_callback=check_result)
with _ImageClassifier.create_from_options(options) as classifier:
for timestamp in range(0, 300, 30):
classifier.classify_async(test_image, timestamp, roi)
if __name__ == '__main__':
absltest.main()

View File

@ -37,6 +37,28 @@ py_library(
],
)
py_library(
name = "image_classifier",
srcs = [
"image_classifier.py",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
"//mediapipe/tasks/python/components/containers:classifications",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_library(
name = "gesture_recognizer",
srcs = [

View File

@ -0,0 +1,294 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe image classifier task."""
import dataclasses
from typing import Callable, Mapping, Optional
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
# TODO: Import MPImage directly one we have an alias
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet
from mediapipe.python._framework_bindings import task_runner
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
from mediapipe.tasks.python.components.containers import classifications
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import vision_task_running_mode
_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
_TaskInfo = task_info_module.TaskInfo
_TaskRunner = task_runner.TaskRunner
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
def _build_full_image_norm_rect() -> _NormalizedRect:
# Builds a NormalizedRect covering the entire image.
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
@dataclasses.dataclass
class ImageClassifierOptions:
"""Options for the image classifier task.
Attributes:
base_options: Base options for the image classifier task.
running_mode: The running mode of the task. Default to the image mode. Image
classifier task has three running modes: 1) The image mode for classifying
objects on single image inputs. 2) The video mode for classifying objects
on the decoded frames of a video. 3) The live stream mode for classifying
objects on a live stream of input data, such as from camera.
classifier_options: Options for the image classification task.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
classifier_options: _ClassifierOptions = _ClassifierOptions()
result_callback: Optional[
Callable[[classifications.ClassificationResult, image_module.Image, int],
None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
"""Generates an ImageClassifierOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
classifier_options_proto = self.classifier_options.to_pb2()
return _ImageClassifierGraphOptionsProto(
base_options=base_options_proto,
classifier_options=classifier_options_proto)
class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
"""Class that performs image classification on images."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'ImageClassifier':
"""Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`.
Note that the created `ImageClassifier` instance is in image mode, for
classifying objects on single image inputs.
Args:
model_path: Path to the model.
Returns:
`ImageClassifier` object that's created from the model file and the
default `ImageClassifierOptions`.
Raises:
ValueError: If failed to create `ImageClassifier` object from the provided
file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = ImageClassifierOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: ImageClassifierOptions) -> 'ImageClassifier':
"""Creates the `ImageClassifier` object from image classifier options.
Args:
options: Options for the image classifier task.
Returns:
`ImageClassifier` object that's created from `options`.
Raises:
ValueError: If failed to create `ImageClassifier` object from
`ImageClassifierOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
def packets_callback(output_packets: Mapping[str, packet.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
classification_result = classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback(classification_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
],
output_streams=[
':'.join([
_CLASSIFICATION_RESULT_TAG,
_CLASSIFICATION_RESULT_OUT_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
# TODO: Replace _NormalizedRect with ImageProcessingOption
def classify(
self,
image: image_module.Image,
roi: Optional[_NormalizedRect] = None
) -> classifications.ClassificationResult:
"""Performs image classification on the provided MediaPipe Image.
Args:
image: MediaPipe Image.
roi: The region of interest.
Returns:
A classification result object that contains a list of classifications.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
_NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
})
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
return classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
def classify_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
roi: Optional[_NormalizedRect] = None
) -> classifications.ClassificationResult:
"""Performs image classification on the provided video frames.
Only use this method when the ImageClassifier is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
roi: The region of interest.
Returns:
A classification result object that contains a list of classifications.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If image classification failed to run.
"""
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME:
packet_creator.create_proto(norm_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
classification_result_proto = classifications_pb2.ClassificationResult()
classification_result_proto.CopyFrom(
packet_getter.get_proto(
output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))
return classifications.ClassificationResult([
classifications.Classifications.create_from_pb2(classification)
for classification in classification_result_proto.classifications
])
def classify_async(self,
image: image_module.Image,
timestamp_ms: int,
roi: Optional[_NormalizedRect] = None) -> None:
"""Sends live image data (an Image with a unique timestamp) to perform image classification.
Only use this method when the ImageClassifier is created with the live
stream running mode. The input timestamps should be monotonically increasing
for adjacent calls of this method. This method will return immediately after
the input image is accepted. The results will be available via the
`result_callback` provided in the `ImageClassifierOptions`. The
`classify_async` method is designed to process live stream data such as
camera input. To lower the overall latency, image classifier may drop the
input images if needed. In other words, it's not guaranteed to have output
per input image.
The `result_callback` provides:
- A classification result object that contains a list of classifications.
- The input image that the image classifier runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
roi: The region of interest.
Raises:
ValueError: If the current input timestamp is smaller than what the image
classifier has already processed.
"""
norm_rect = roi if roi is not None else _build_full_image_norm_rect()
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_NAME:
packet_creator.create_proto(norm_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})

View File

@ -28,6 +28,8 @@ mediapipe_files(srcs = [
"burger_rotated.jpg",
"cat.jpg",
"cat_mask.jpg",
"cat_rotated.jpg",
"cat_rotated_mask.jpg",
"cats_and_dogs.jpg",
"cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg",
@ -84,6 +86,8 @@ filegroup(
"burger_rotated.jpg",
"cat.jpg",
"cat_mask.jpg",
"cat_rotated.jpg",
"cat_rotated_mask.jpg",
"cats_and_dogs.jpg",
"cats_and_dogs_no_resizing.jpg",
"cats_and_dogs_rotated.jpg",

View File

@ -82,7 +82,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) {
// If that fails, assume it was a relative path, and try just the base name.
{
const size_t last_slash_idx = path.find_last_of("\\/");
CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path.
RET_CHECK(last_slash_idx != std::string::npos)
<< path << " doesn't have a slash in it"; // Make sure it's a path.
auto base_name = path.substr(last_slash_idx + 1);
auto status_or_path = PathToResourceAsFileInternal(base_name);
if (status_or_path.ok()) {

View File

@ -71,7 +71,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) {
// If that fails, assume it was a relative path, and try just the base name.
{
const size_t last_slash_idx = path.find_last_of("\\/");
CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path.
RET_CHECK(last_slash_idx != std::string::npos)
<< path << " doesn't have a slash in it"; // Make sure it's a path.
auto base_name = path.substr(last_slash_idx + 1);
auto status_or_path = PathToResourceAsFileInternal(base_name);
if (status_or_path.ok()) {

View File

@ -76,6 +76,18 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_mask.jpg?generation=1661875677203533"],
)
http_file(
name = "com_google_mediapipe_cat_rotated_jpg",
sha256 = "b78cee5ad14c9f36b1c25d103db371d81ca74d99030063c46a38e80bb8f38649",
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated.jpg?generation=1666304165042123"],
)
http_file(
name = "com_google_mediapipe_cat_rotated_mask_jpg",
sha256 = "f336973e7621d602f2ebc9a6ab1c62d8502272d391713f369d3b99541afda861",
urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated_mask.jpg?generation=1666304167148173"],
)
http_file(
name = "com_google_mediapipe_cats_and_dogs_jpg",
sha256 = "a2eaa7ad3a1aae4e623dd362a5f737e8a88d122597ecd1a02b3e1444db56df9c",
@ -162,8 +174,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"],
sha256 = "c4dfdcc2e4cd366eb5f8ad227be94049eb593e3a528564611094687912463687",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666629474155924"],
)
http_file(
@ -174,8 +186,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"],
sha256 = "7fb2d33cf69d2da50952a45bad0c0618f30859e608958fee95948a6e0de63ccb",
urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"],
)
http_file(
@ -258,8 +270,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"],
sha256 = "555079c274ea91699757a0b9888c9993a8ab450069103b1bcd4ebb805a8e023c",
urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666629478777955"],
)
http_file(
@ -456,14 +468,14 @@ def external_files():
http_file(
name = "com_google_mediapipe_mobilenet_v2_1_0_224_json",
sha256 = "0eb285a857b4bb1815736d0902ace0af45ea62e90c1dac98844b9ca797cd0d7b",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1665988398778178"],
sha256 = "94613ea9539a20a3352604004be6d4d64d4d76250bc9042fcd8685c9a8498517",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1666633416316646"],
)
http_file(
name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json",
sha256 = "932f345ebe3d98daf0dc4c88b0f9e694e450390fb394fc217e851338dfec43e6",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1665988401522527"],
sha256 = "3703eadcf838b65bbc2b2aa11dbb1f1bc654c7a09a7aba5ca75a26096484a8ac",
urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1666633418665507"],
)
http_file(
@ -606,8 +618,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de",
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"],
sha256 = "5ec37218d8b613436f5c10121dc689bf9ee69af0656a6ccf8c2e3e8b652e2ad6",
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"],
)
http_file(
@ -798,8 +810,8 @@ def external_files():
http_file(
name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102",
urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"],
sha256 = "6645bbd98ea7f90b3e1ba297e16ea5280847fc5bf5400726d98c282f6c597257",
urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666629489421733"],
)
http_file(