Merge branch 'master' into gesture-recognizer-python
This commit is contained in:
commit 0de97497fa
@@ -163,7 +163,6 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
}

absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
  const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
  const auto& input_tensors = *kInTensors(cc);
  RET_CHECK_EQ(input_tensors.size(), 1);
  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
@@ -182,12 +181,6 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
  auto raw_scores = view.buffer<float>();

  auto classification_list = absl::make_unique<ClassificationList>();
  if (options.has_tensor_index()) {
    classification_list->set_tensor_index(options.tensor_index());
  }
  if (options.has_tensor_name()) {
    classification_list->set_tensor_name(options.tensor_name());
  }
  if (is_binary_classification_) {
    Classification* class_first = classification_list->add_classification();
    Classification* class_second = classification_list->add_classification();
@@ -72,9 +72,4 @@ message TensorsToClassificationCalculatorOptions {
  // that are not in the `allow_classes` field will be completely ignored.
  // `ignore_classes` and `allow_classes` are mutually exclusive.
  repeated int32 allow_classes = 8 [packed = true];

  // The optional index of the tensor these classifications originate from.
  optional int32 tensor_index = 10;
  // The optional name of the tensor these classifications originate from.
  optional string tensor_name = 11;
}
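As a hedged illustration (not part of the commit): the new fields flow from the calculator options into the output ClassificationList, so a downstream consumer can tell which model output tensor a result came from. The field accessors are real generated-proto API; the helper below is hypothetical:

#include <string>
#include <vector>

#include "mediapipe/framework/formats/classification.pb.h"

// Hypothetical consumer: picks the ClassificationList produced from the
// tensor named `name` out of several lists (sketch, not MediaPipe API).
const ClassificationList* FindByTensorName(
    const std::vector<ClassificationList>& lists, const std::string& name) {
  for (const auto& list : lists) {
    if (list.has_tensor_name() && list.tensor_name() == name) return &list;
  }
  return nullptr;
}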
@@ -240,36 +240,6 @@ TEST_F(TensorsToClassificationCalculatorTest,
  }
}

TEST_F(TensorsToClassificationCalculatorTest,
       CorrectOutputWithTensorNameAndIndex) {
  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
    calculator: "TensorsToClassificationCalculator"
    input_stream: "TENSORS:tensors"
    output_stream: "CLASSIFICATIONS:classifications"
    options {
      [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
        tensor_index: 1
        tensor_name: "foo"
      }
    }
  )pb"));

  BuildGraph(&runner, {0, 0.5, 1});
  MP_ASSERT_OK(runner.Run());

  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;

  EXPECT_EQ(1, output_packets_.size());

  const auto& classification_list =
      output_packets_[0].Get<ClassificationList>();
  EXPECT_EQ(3, classification_list.classification_size());

  // Verify that the tensor_index and tensor_name fields are correctly set.
  EXPECT_EQ(classification_list.tensor_index(), 1);
  EXPECT_EQ(classification_list.tensor_name(), "foo");
}

TEST_F(TensorsToClassificationCalculatorTest,
       ClassNameAllowlistWithLabelItems) {
  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
@@ -293,7 +293,6 @@ mediapipe_proto_library(
    name = "rect_proto",
    srcs = ["rect.proto"],
    visibility = ["//visibility:public"],
    deps = ["//mediapipe/framework/formats:location_data_proto"],
)

mediapipe_register_type(
@@ -37,10 +37,6 @@ message Classification {
// Group of Classification protos.
message ClassificationList {
  repeated Classification classification = 1;
  // Optional index of the tensor that produced these classifications.
  optional int32 tensor_index = 2;
  // Optional name of the tensor that produced these classifications.
  optional string tensor_name = 3;
}

// Group of ClassificationList protos.
@@ -38,13 +38,20 @@ static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;

static void EglThreadExitCallback(void* key_value) {
#if defined(__ANDROID__)
  eglMakeCurrent(EGL_NO_DISPLAY, EGL_NO_SURFACE, EGL_NO_SURFACE,
                 EGL_NO_CONTEXT);
#else
  // Some implementations have chosen to allow EGL_NO_DISPLAY as a valid display
  // parameter for eglMakeCurrent. This behavior is not portable to all EGL
  // implementations, and should be considered as an undocumented vendor
  // extension.
  // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
  //
  // NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
  eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
                 EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
  eglReleaseThread();
}
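For context, a hedged sketch of how a thread-exit callback like the one above is typically armed with the pthread primitives declared in this hunk; only the key/once variables and the callback name come from the diff, the registration helper is hypothetical:

#include <pthread.h>

static pthread_key_t egl_release_thread_key;
static pthread_once_t egl_release_key_once = PTHREAD_ONCE_INIT;

static void EglThreadExitCallback(void* key_value);  // as in the hunk above

static void CreateEglReleaseKey() {
  // The destructor runs on thread exit for every thread that stored a
  // non-null value under this key.
  pthread_key_create(&egl_release_thread_key, EglThreadExitCallback);
}

// Hypothetical helper: call once per GL thread to arm the callback.
static void EnsureEglThreadRelease() {
  pthread_once(&egl_release_key_once, CreateEglReleaseKey);
  pthread_setspecific(egl_release_thread_key, reinterpret_cast<void*>(1));
}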
@@ -98,6 +98,5 @@ def train_model(model: tf.keras.Model, hparams: hp.HParams,
  return model.fit(
      x=train_ds,
      epochs=hparams.train_epochs,
      steps_per_epoch=hparams.steps_per_epoch,
      validation_data=validation_ds,
      callbacks=callbacks)
@@ -87,6 +87,7 @@ cc_library(
cc_library(
    name = "builtin_task_graphs",
    deps = [
        "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
        "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
        "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
    ],
@@ -14,7 +14,7 @@
"""The public facing packet getter APIs."""

from typing import List, Type
from typing import List

from google.protobuf import message
from google.protobuf import symbol_database
@@ -39,7 +39,7 @@ get_image_frame = _packet_getter.get_image_frame
get_matrix = _packet_getter.get_matrix


def get_proto(packet: mp_packet.Packet) -> Type[message.Message]:
def get_proto(packet: mp_packet.Packet) -> message.Message:
  """Get the content of a MediaPipe proto Packet as a proto message.

  Args:
@@ -46,6 +46,7 @@ cc_library(
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/gpu:gpu_origin_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
@@ -17,7 +17,7 @@ syntax = "proto2";

package mediapipe.tasks.components.containers.proto;

option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto";

// A single classification result.
@@ -19,7 +19,7 @@ package mediapipe.tasks.components.containers.proto;

import "mediapipe/tasks/cc/components/containers/proto/category.proto";

option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto";

// List of predicted categories with an optional timestamp.
@@ -17,6 +17,9 @@ syntax = "proto2";

package mediapipe.tasks.components.containers.proto;

option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "EmbeddingsProto";

// Defines a dense floating-point embedding.
message FloatEmbedding {
  repeated float values = 1 [packed = true];
@@ -30,6 +30,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/gpu/gpu_origin.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
@@ -128,6 +129,9 @@ absl::Status ConfigureImageToTensorCalculator(
    options->mutable_output_tensor_float_range()->set_max((255.0f - mean) /
                                                          std);
  }
  // TODO: need to support different GPU origin on different
  // platforms or applications.
  options->set_gpu_origin(mediapipe::GpuOrigin::TOP_LEFT);
  return absl::OkStatus();
}
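A brief worked note (ours, not the commit's): with metadata mean/std normalization, a pixel value v in [0, 255] maps to (v - mean) / std, which is why the float range set above runs from (0 - mean) / std up to (255 - mean) / std. A minimal sketch with a hypothetical helper name:

#include <utility>

// Hypothetical helper: the output float range implied by mean/std
// normalization of 8-bit pixels, matching the set_min/set_max calls above.
std::pair<float, float> OutputFloatRange(float mean, float std) {
  const float min = (0.0f - mean) / std;    // darkest pixel after normalization
  const float max = (255.0f - mean) / std;  // brightest pixel after normalization
  return {min, max};
}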
@@ -73,6 +73,19 @@ cc_library(
    ],
)

cc_test(
    name = "sentencepiece_tokenizer_test",
    srcs = ["sentencepiece_tokenizer_test.cc"],
    data = [
        "//mediapipe/tasks/testdata/text:albert_model",
    ],
    deps = [
        ":sentencepiece_tokenizer",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/core:utils",
    ],
)

cc_library(
    name = "tokenizer_utils",
    srcs = ["tokenizer_utils.cc"],
@@ -95,6 +108,33 @@ cc_library(
    ],
)

cc_test(
    name = "tokenizer_utils_test",
    srcs = ["tokenizer_utils_test.cc"],
    data = [
        "//mediapipe/tasks/testdata/text:albert_model",
        "//mediapipe/tasks/testdata/text:mobile_bert_model",
        "//mediapipe/tasks/testdata/text:text_classifier_models",
    ],
    linkopts = ["-ldl"],
    deps = [
        ":bert_tokenizer",
        ":regex_tokenizer",
        ":sentencepiece_tokenizer",
        ":tokenizer_utils",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
    ],
)

cc_library(
    name = "regex_tokenizer",
    srcs = [
@@ -58,6 +58,7 @@ cc_library(
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto",
@@ -29,6 +29,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h"
@@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::vision::image_embedder::proto::
    ImageEmbedderGraphOptions;

// Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() {
  NormalizedRect norm_rect;
  norm_rect.set_x_center(0.5);
  norm_rect.set_y_center(0.5);
  norm_rect.set_width(1);
  norm_rect.set_height(1);
  return norm_rect;
}

// Creates a MediaPipe graph config that contains a single node of type
// "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is
// running in the live stream mode, a "FlowLimiterCalculator" will be added to
@@ -148,15 +139,16 @@ absl::StatusOr<std::unique_ptr<ImageEmbedder>> ImageEmbedder::Create(
}

absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
    Image image, std::optional<NormalizedRect> roi) {
    Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  NormalizedRect norm_rect =
      roi.has_value() ? roi.value() : BuildFullImageNormRect();
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
@@ -167,15 +159,16 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::Embed(
}

absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
    Image image, int64 timestamp_ms, std::optional<NormalizedRect> roi) {
    Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  NormalizedRect norm_rect =
      roi.has_value() ? roi.value() : BuildFullImageNormRect();
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
@@ -188,16 +181,17 @@ absl::StatusOr<EmbeddingResult> ImageEmbedder::EmbedForVideo(
  return output_packets[kEmbeddingResultStreamName].Get<EmbeddingResult>();
}

absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms,
                                       std::optional<NormalizedRect> roi) {
absl::Status ImageEmbedder::EmbedAsync(
    Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  NormalizedRect norm_rect =
      roi.has_value() ? roi.value() : BuildFullImageNormRect();
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   ConvertToNormalizedRect(image_processing_options));
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
@@ -21,11 +21,11 @@ limitations under the License.

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/embedder_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"

namespace mediapipe {
@@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi {
  static absl::StatusOr<std::unique_ptr<ImageEmbedder>> Create(
      std::unique_ptr<ImageEmbedderOptions> options);

  // Performs embedding extraction on the provided single image. Extraction
  // is performed on the region of interest specified by the `roi` argument if
  // provided, or on the entire image otherwise.
  // Performs embedding extraction on the provided single image.
  //
  // The optional 'image_processing_options' parameter can be used to specify:
  //   - the rotation to apply to the image before performing embedding
  //     extraction, by setting its 'rotation_degrees' field.
  //   and/or
  //   - the region-of-interest on which to perform embedding extraction, by
  //     setting its 'region_of_interest' field. If not specified, the full image
  //     is used.
  // If both are specified, the crop around the region-of-interest is extracted
  // first, then the specified rotation is applied to the crop.
  //
  // Only use this method when the ImageEmbedder is created with the image
  // running mode.
@@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi {
  // The image can be of any size with format RGB or RGBA.
  absl::StatusOr<components::containers::proto::EmbeddingResult> Embed(
      mediapipe::Image image,
      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
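A usage sketch for the new signature, grounded in the tests added later in this commit; the `image_embedder` and `image` variables are assumed to exist:

// Crop to the left ~83% of the frame and rotate by -90 degrees before
// extraction; both fields of ImageProcessingOptions are optional.
mediapipe::tasks::components::containers::Rect roi{
    /*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
mediapipe::tasks::vision::core::ImageProcessingOptions options{
    roi, /*rotation_degrees=*/-90};
auto result = image_embedder->Embed(std::move(image), options);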
  // Performs embedding extraction on the provided video frame. Extraction
  // is performed on the region of interest specified by the `roi` argument if
  // provided, or on the entire image otherwise.
  // Performs embedding extraction on the provided video frame.
  //
  // The optional 'image_processing_options' parameter can be used to specify:
  //   - the rotation to apply to the image before performing embedding
  //     extraction, by setting its 'rotation_degrees' field.
  //   and/or
  //   - the region-of-interest on which to perform embedding extraction, by
  //     setting its 'region_of_interest' field. If not specified, the full image
  //     is used.
  // If both are specified, the crop around the region-of-interest is extracted
  // first, then the specified rotation is applied to the crop.
  //
  // Only use this method when the ImageEmbedder is created with the video
  // running mode.
@@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi {
  // must be monotonically increasing.
  absl::StatusOr<components::containers::proto::EmbeddingResult> EmbedForVideo(
      mediapipe::Image image, int64 timestamp_ms,
      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Sends live image data to embedder, and the results will be available via
  // the "result_callback" provided in the ImageEmbedderOptions. Embedding
  // extraction is performed on the region of interest specified by the `roi`
  // argument if provided, or on the entire image otherwise.
  // the "result_callback" provided in the ImageEmbedderOptions.
  //
  // The optional 'image_processing_options' parameter can be used to specify:
  //   - the rotation to apply to the image before performing embedding
  //     extraction, by setting its 'rotation_degrees' field.
  //   and/or
  //   - the region-of-interest on which to perform embedding extraction, by
  //     setting its 'region_of_interest' field. If not specified, the full image
  //     is used.
  // If both are specified, the crop around the region-of-interest is extracted
  // first, then the specified rotation is applied to the crop.
  //
  // Only use this method when the ImageEmbedder is created with the live
  // stream running mode.
@@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi {
  // longer be valid when the callback returns. To access the image data
  // outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
  absl::Status EmbedAsync(
      mediapipe::Image image, int64 timestamp_ms,
      std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
  absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms,
                          std::optional<core::ImageProcessingOptions>
                              image_processing_options = std::nullopt);

  // Shuts down the ImageEmbedder when all work is done.
  absl::Status Close() { return runner_->Close(); }
@@ -23,7 +23,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
@@ -42,7 +41,9 @@ namespace image_embedder {
namespace {

using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image crop, DecodeImageFromFile(
                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
  // Bounding box in "burger.jpg" corresponding to "burger_crop.jpg".
  NormalizedRect roi;
  roi.set_x_center(200.0 / 480);
  roi.set_y_center(0.5);
  roi.set_width(400.0 / 480);
  roi.set_height(1.0f);
  // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg".
  Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1};
  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
                          image_embedder->Embed(image, roi));
  MP_ASSERT_OK_AND_ASSIGN(
      const EmbeddingResult& image_result,
      image_embedder->Embed(image, image_processing_options));
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                          image_embedder->Embed(crop));
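A quick worked check (ours, not the commit's) that the new edge-based Rect equals the removed center/size NormalizedRect for the 480-pixel-wide burger.jpg:

// x_center = 200/480 and width = 400/480, so:
//   left  = x_center - width/2 = (200 - 200)/480 = 0
//   right = x_center + width/2 = (200 + 200)/480 = 400/480 ~= 0.833333
// y_center = 0.5 with height 1 gives top = 0 and bottom = 1, matching
// Rect{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1} exactly.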
@@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
  EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}

TEST_F(ImageModeTest, SucceedsWithRotation) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  // Load images: one is a rotated version of the other.
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(Image rotated,
                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                       "burger_rotated.jpg")));
  ImageProcessingOptions image_processing_options;
  image_processing_options.rotation_degrees = -90;

  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result,
                          image_embedder->Embed(image));
  MP_ASSERT_OK_AND_ASSIGN(
      const EmbeddingResult& rotated_result,
      image_embedder->Embed(rotated, image_processing_options));

  // Check results.
  CheckMobileNetV3Result(image_result, false);
  CheckMobileNetV3Result(rotated_result, false);
  // Check cosine similarity.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0),
                                      rotated_result.embeddings(0).entries(0)));
  double expected_similarity = 0.572265;
  EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}

TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
  auto options = std::make_unique<ImageEmbedderOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageEmbedder> image_embedder,
                          ImageEmbedder::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      Image crop, DecodeImageFromFile(
                      JoinPath("./", kTestDataDirectory, "burger_crop.jpg")));
  MP_ASSERT_OK_AND_ASSIGN(Image rotated,
                          DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                       "burger_rotated.jpg")));
  // Region-of-interest corresponding to burger_crop.jpg.
  Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333};
  ImageProcessingOptions image_processing_options{roi,
                                                  /*rotation_degrees=*/-90};

  // Extract both embeddings.
  MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result,
                          image_embedder->Embed(crop));
  MP_ASSERT_OK_AND_ASSIGN(
      const EmbeddingResult& rotated_result,
      image_embedder->Embed(rotated, image_processing_options));

  // Check results.
  CheckMobileNetV3Result(crop_result, false);
  CheckMobileNetV3Result(rotated_result, false);
  // Check cosine similarity.
  MP_ASSERT_OK_AND_ASSIGN(
      double similarity,
      ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0),
                                      rotated_result.embeddings(0).entries(0)));
  double expected_similarity = 0.62838;
  EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy);
}
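For reference, a self-contained sketch of the cosine similarity these tests assert on; ImageEmbedder::CosineSimilarity is the real API, this standalone function is ours:

#include <cmath>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Standalone sketch: cosine similarity between two embedding vectors,
// i.e. dot(u, v) / (||u|| * ||v||), a value in [-1, 1].
absl::StatusOr<double> CosineSimilaritySketch(const std::vector<float>& u,
                                              const std::vector<float>& v) {
  if (u.size() != v.size() || u.empty()) {
    return absl::InvalidArgumentError(
        "Embeddings must have the same, non-zero size.");
  }
  double dot = 0, norm_u = 0, norm_v = 0;
  for (size_t i = 0; i < u.size(); ++i) {
    dot += u[i] * v[i];
    norm_u += u[i] * u[i];
    norm_v += v[i] * v[i];
  }
  if (norm_u == 0 || norm_v == 0) {
    return absl::InvalidArgumentError(
        "Cannot compute similarity on a zero vector.");
  }
  return dot / (std::sqrt(norm_u) * std::sqrt(norm_v));
}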
class VideoModeTest : public tflite_shims::testing::Test {};

TEST_F(VideoModeTest, FailsWithCallingWrongMethod) {
@@ -24,10 +24,12 @@ cc_library(
        ":image_segmenter_graph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:image_processing_options",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
@@ -48,6 +50,7 @@ cc_library(
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
@@ -17,8 +17,10 @@ limitations under the License.

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
@@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.vision.ImageSegmenterGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
@@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig(
  auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
  task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
      graph.Out(kGroupedSegmentationTag);
  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
      graph.Out(kImageTag);
  if (enable_flow_limiting) {
    return tasks::core::AddFlowLimiterCalculator(
        graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag);
    return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
                                                 {kImageTag, kNormRectTag},
                                                 kGroupedSegmentationTag);
  }
  graph.In(kImageTag) >> task_subgraph.In(kImageTag);
  graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
  return graph.GetConfig();
}
@@ -139,47 +146,68 @@ absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
}

absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
    mediapipe::Image image) {
    mediapipe::Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData({{kImageInStreamName,
                         mediapipe::MakePacket<Image>(std::move(image))}}));
      ProcessImageData(
          {{kImageInStreamName, mediapipe::MakePacket<Image>(std::move(image))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}

absl::StatusOr<std::vector<Image>> ImageSegmenter::SegmentForVideo(
    mediapipe::Image image, int64 timestamp_ms) {
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
          {{kImageInStreamName,
            MakePacket<Image>(std::move(image))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
}

absl::Status ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) {
absl::Status ImageSegmenter::SegmentAsync(
    Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(
      NormalizedRect norm_rect,
      ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false));
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
       {kNormRectStreamName,
        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
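The helper ConvertToNormalizedRect is called above but its definition is not part of this diff. A hedged sketch of its plausible behavior, inferred from the removed BuildFullImageNormRect() and the tests' expectations; the signature, field access, and rotation sign convention are all assumptions:

#include <cmath>
#include <optional>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Sketch only: default to a rect covering the whole image, honor an optional
// region-of-interest when allowed, and carry the rotation through.
absl::StatusOr<NormalizedRect> ConvertToNormalizedRectSketch(
    std::optional<core::ImageProcessingOptions> options,
    bool roi_allowed = true) {
  NormalizedRect norm_rect;
  norm_rect.set_x_center(0.5);
  norm_rect.set_y_center(0.5);
  norm_rect.set_width(1);
  norm_rect.set_height(1);
  if (!options.has_value()) return norm_rect;
  if (options->region_of_interest.has_value()) {
    if (!roi_allowed) {
      return absl::InvalidArgumentError(
          "This task doesn't support region-of-interest.");
    }
    const auto& roi = *options->region_of_interest;
    norm_rect.set_x_center((roi.left + roi.right) / 2);
    norm_rect.set_y_center((roi.top + roi.bottom) / 2);
    norm_rect.set_width(roi.right - roi.left);
    norm_rect.set_height(roi.bottom - roi.top);
  }
  // Sign convention is an assumption; the tests only show -90 passed through.
  norm_rect.set_rotation(-options->rotation_degrees * M_PI / 180.0);
  return norm_rect;
}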
@@ -25,6 +25,7 @@ limitations under the License.
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/kernels/register.h"
@@ -116,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
  // running mode.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describe how the input image will be preprocessed
  // after the yuv support is implemented.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing segmentation, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // If the output_type is CATEGORY_MASK, the returned vector of images is
  // per-category segmented image mask.
  // If the output_type is CONFIDENCE_MASK, the returned vector of images
  // contains only one confidence image mask.
  absl::StatusOr<std::vector<mediapipe::Image>> Segment(mediapipe::Image image);
  absl::StatusOr<std::vector<mediapipe::Image>> Segment(
      mediapipe::Image image,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
|
||||
// Only use this method when the ImageSegmenter is created with the video
|
||||
|
@@ -133,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
  // provide the video frame's timestamp (in milliseconds). The input timestamps
  // must be monotonically increasing.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing segmentation, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // If the output_type is CATEGORY_MASK, the returned vector of images is
  // per-category segmented image mask.
  // If the output_type is CONFIDENCE_MASK, the returned vector of images
  // contains only one confidence image mask.
  absl::StatusOr<std::vector<mediapipe::Image>> SegmentForVideo(
      mediapipe::Image image, int64 timestamp_ms);
      mediapipe::Image image, int64 timestamp_ms,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Sends live image data to perform image segmentation, and the results will
  // be available via the "result_callback" provided in the
@@ -150,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
  // sent to the image segmenter. The input timestamps must be monotonically
  // increasing.
  //
  // The optional 'image_processing_options' parameter can be used to specify
  // the rotation to apply to the image before performing segmentation, by
  // setting its 'rotation_degrees' field. Note that specifying a
  // region-of-interest using the 'region_of_interest' field is NOT supported
  // and will result in an invalid argument error being returned.
  //
  // The "result_callback" provides
  //   - A vector of segmented image masks.
  //     If the output_type is CATEGORY_MASK, the returned vector of images is
@@ -161,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
  // no longer be valid when the callback returns. To access the image data
  // outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
  absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms);
  absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms,
                            std::optional<core::ImageProcessingOptions>
                                image_processing_options = std::nullopt);

  // Shuts down the ImageSegmenter when all work is done.
  absl::Status Close() { return runner_->Close(); }
@@ -23,6 +23,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
@@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
constexpr char kSegmentationTag[] = "SEGMENTATION";
constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
@@ -159,6 +161,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
// Inputs:
//   IMAGE - Image
//     Image to perform segmentation on.
//   NORM_RECT - NormalizedRect @Optional
//     Describes image rotation and region of image to perform segmentation on.
//     @Optional: rect covering the whole image is used if not specified.
//
// Outputs:
//   SEGMENTATION - mediapipe::Image @Multiple
@@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
    ASSIGN_OR_RETURN(const auto* model_resources,
                     CreateModelResources<ImageSegmenterOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(auto output_streams,
                     BuildSegmentationTask(
                         sc->Options<ImageSegmenterOptions>(), *model_resources,
                         graph[Input<Image>(kImageTag)], graph));
    ASSIGN_OR_RETURN(
        auto output_streams,
        BuildSegmentationTask(
            sc->Options<ImageSegmenterOptions>(), *model_resources,
            graph[Input<Image>(kImageTag)],
            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));

    auto& merge_images_to_vector =
        graph.AddNode("MergeImagesToVectorCalculator");
@@ -228,7 +236,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
  absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
      const ImageSegmenterOptions& task_options,
      const core::ModelResources& model_resources, Source<Image> image_in,
      Graph& graph) {
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
    MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));

    // Adds preprocessing calculators and connects them to the graph input image
@@ -240,6 +248,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
        &preprocessing
             .GetOptions<tasks::components::ImagePreprocessingOptions>()));
    image_in >> preprocessing.In(kImageTag);
    norm_rect_in >> preprocessing.In(kNormRectTag);

    // Adds inference subgraph and connects its input stream to the output
    // tensors produced by the ImageToTensorCalculator.
@@ -29,8 +29,10 @@ limitations under the License.
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@@ -44,6 +46,8 @@ namespace {

using ::mediapipe::Image;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@@ -237,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
  EXPECT_EQ(confidence_masks.size(), 21);
@@ -253,6 +256,61 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}

TEST_F(ImageModeTest, SucceedsWithRotation) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(
                       JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  ImageProcessingOptions image_processing_options;
  image_processing_options.rotation_degrees = -90;
  MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
                          segmenter->Segment(image, image_processing_options));
  EXPECT_EQ(confidence_masks.size(), 21);

  cv::Mat expected_mask =
      cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"),
                 cv::IMREAD_GRAYSCALE);
  cv::Mat expected_mask_float;
  expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f);

  // Cat category index 8.
  cv::Mat cat_mask = mediapipe::formats::MatView(
      confidence_masks[8].GetImageFrameSharedPtr().get());
  EXPECT_THAT(cat_mask,
              SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}

TEST_F(ImageModeTest, FailsWithRegionOfInterest) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                          ImageSegmenter::Create(std::move(options)));
  Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1};
  ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0};

  auto results = segmenter->Segment(image, image_processing_options);
  EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(results.status().message(),
              HasSubstr("This task doesn't support region-of-interest"));
  EXPECT_THAT(
      results.status().GetPayload(kMediaPipeTasksPayload),
      Optional(absl::Cord(absl::StrCat(
          MediaPipeTasksStatus::kImageProcessingInvalidArgumentError))));
}

TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
  Image image =
      GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
@@ -31,6 +31,7 @@ android_binary(
    multidex = "native",
    resource_files = ["//mediapipe/tasks/examples/android:resource_files"],
    deps = [
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
@@ -16,7 +16,6 @@ package com.google.mediapipe.tasks.examples.objectdetector;

import android.content.Intent;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.MediaMetadataRetriever;
import android.os.Bundle;
import android.provider.MediaStore;
@@ -29,9 +28,11 @@ import androidx.activity.result.ActivityResultLauncher;
import androidx.activity.result.contract.ActivityResultContracts;
import androidx.exifinterface.media.ExifInterface;
// ContentResolver dependency
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
@@ -82,6 +83,7 @@ public class MainActivity extends AppCompatActivity {
        if (resultIntent != null) {
          if (result.getResultCode() == RESULT_OK) {
            Bitmap bitmap = null;
            int rotation = 0;
            try {
              bitmap =
                  downscaleBitmap(
@@ -93,13 +95,16 @@ public class MainActivity extends AppCompatActivity {
            try {
              InputStream imageData =
                  this.getContentResolver().openInputStream(resultIntent.getData());
              bitmap = rotateBitmap(bitmap, imageData);
            } catch (IOException e) {
              rotation = getImageRotation(imageData);
            } catch (IOException | MediaPipeException e) {
              Log.e(TAG, "Bitmap rotation error:" + e);
            }
            if (bitmap != null) {
              MPImage image = new BitmapImageBuilder(bitmap).build();
              ObjectDetectionResult detectionResult = objectDetector.detect(image);
              ObjectDetectionResult detectionResult =
                  objectDetector.detect(
                      image,
                      ImageProcessingOptions.builder().setRotationDegrees(rotation).build());
              imageView.setData(image, detectionResult);
              runOnUiThread(() -> imageView.update());
            }
@@ -210,28 +215,25 @@ public class MainActivity extends AppCompatActivity {
    return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
  }

  private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
  private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException {
    int orientation =
        new ExifInterface(imageData)
            .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
    if (orientation == ExifInterface.ORIENTATION_NORMAL) {
      return inputBitmap;
    }
    Matrix matrix = new Matrix();
    switch (orientation) {
      case ExifInterface.ORIENTATION_NORMAL:
        return 0;
      case ExifInterface.ORIENTATION_ROTATE_90:
        matrix.postRotate(90);
        break;
        return 90;
      case ExifInterface.ORIENTATION_ROTATE_180:
        matrix.postRotate(180);
        break;
        return 180;
      case ExifInterface.ORIENTATION_ROTATE_270:
        matrix.postRotate(270);
        break;
        return 270;
      default:
        matrix.postRotate(0);
        // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip
        // is supported.
        throw new MediaPipeException(
            MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(),
            "Flipped images are not supported yet.");
    }
    return Bitmap.createBitmap(
        inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
  }
}
@@ -15,11 +15,11 @@
package com.google.mediapipe.tasks.text.textclassifier;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
@@ -22,7 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler;
@@ -28,6 +28,7 @@ android_library(
        "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
@@ -24,7 +24,6 @@ import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable {
@@ -32,7 +31,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
  private final TaskRunner runner;
  private final RunningMode runningMode;
  private final String imageStreamName;
  private final Optional<String> normRectStreamName;
  private final String normRectStreamName;

  static {
    System.loadLibrary("mediapipe_tasks_vision_jni");
@@ -40,27 +39,13 @@ public class BaseVisionTaskApi implements AutoCloseable {
  }

  /**
   * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
   * Constructor to initialize a {@link BaseVisionTaskApi}.
   *
   * @param runner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   * @param imageStreamName the name of the input image stream.
   */
  public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
    this.runner = runner;
    this.runningMode = runningMode;
    this.imageStreamName = imageStreamName;
    this.normRectStreamName = Optional.empty();
  }

  /**
   * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
   * input.
   *
   * @param runner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   * @param imageStreamName the name of the input image stream.
   * @param normRectStreamName the name of the input normalized rect image stream.
   * @param normRectStreamName the name of the input normalized rect image stream used to provide
   *     (mandatory) rotation and (optional) region-of-interest.
   */
  public BaseVisionTaskApi(
      TaskRunner runner,
@@ -70,7 +55,7 @@ public class BaseVisionTaskApi implements AutoCloseable {
    this.runner = runner;
    this.runningMode = runningMode;
    this.imageStreamName = imageStreamName;
    this.normRectStreamName = Optional.of(normRectStreamName);
    this.normRectStreamName = normRectStreamName;
  }

  /**
@@ -78,53 +63,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
   * failure status or a successful result is returned.
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
   *     input.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference.
   * @throws MediaPipeException if the task is not in the image mode.
   */
  protected TaskResult processImageData(MPImage image) {
  protected TaskResult processImageData(
      MPImage image, ImageProcessingOptions imageProcessingOptions) {
    if (runningMode != RunningMode.IMAGE) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the image mode. Current running mode:"
              + runningMode.name());
    }
    if (normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task expects a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    return runner.process(inputPackets);
  }

  /**
   * A synchronous method to process single image inputs. The call blocks the current thread until a
   * failure status or a successful result is returned.
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
   *     are expected to be specified as normalized values in [0,1].
   * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
   *     rect.
   */
  protected TaskResult processImageData(MPImage image, RectF roi) {
    if (runningMode != RunningMode.IMAGE) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the image mode. Current running mode:"
              + runningMode.name());
    }
    if (!normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task doesn't expect a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    inputPackets.put(
        normRectStreamName.get(),
        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
        normRectStreamName,
        runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
    return runner.process(inputPackets);
  }
@@ -133,55 +88,24 @@ public class BaseVisionTaskApi implements AutoCloseable {
   * until a failure status or a successful result is returned.
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference.
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
   *     input.
   * @throws MediaPipeException if the task is not in the video mode.
   */
  protected TaskResult processVideoData(MPImage image, long timestampMs) {
  protected TaskResult processVideoData(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    if (runningMode != RunningMode.VIDEO) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the video mode. Current running mode:"
              + runningMode.name());
    }
    if (normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task expects a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /**
   * A synchronous method to process continuous video frames. The call blocks the current thread
   * until a failure status or a successful result is returned.
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
   *     are expected to be specified as normalized values in [0,1].
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
   *     rect.
   */
  protected TaskResult processVideoData(MPImage image, RectF roi, long timestampMs) {
    if (runningMode != RunningMode.VIDEO) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the video mode. Current running mode:"
              + runningMode.name());
    }
    if (!normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task doesn't expect a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    inputPackets.put(
        normRectStreamName.get(),
        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
        normRectStreamName,
        runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
    return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }
@ -190,55 +114,24 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
|||
* available in the user-defined result listener.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference.
|
||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||
* @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
|
||||
* input.
|
||||
* @throws MediaPipeException if the task is not in the stream mode.
|
||||
*/
|
||||
protected void sendLiveStreamData(MPImage image, long timestampMs) {
|
||||
protected void sendLiveStreamData(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
if (runningMode != RunningMode.LIVE_STREAM) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the live stream mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
if (normRectStreamName.isPresent()) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task expects a normalized rect as input.");
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
/**
|
||||
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
|
||||
* available in the user-defined result listener.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
|
||||
* are expected to be specified as normalized values in [0,1].
|
||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
|
||||
* rect.
|
||||
*/
|
||||
protected void sendLiveStreamData(MPImage image, RectF roi, long timestampMs) {
|
||||
if (runningMode != RunningMode.LIVE_STREAM) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the live stream mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
if (!normRectStreamName.isPresent()) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task doesn't expect a normalized rect as input.");
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
|
||||
inputPackets.put(
|
||||
normRectStreamName.get(),
|
||||
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
|
||||
normRectStreamName,
|
||||
runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions)));
|
||||
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
|
@ -248,13 +141,23 @@ public class BaseVisionTaskApi implements AutoCloseable {
|
|||
runner.close();
|
||||
}
|
||||
|
||||
/** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */
|
||||
private static NormalizedRect convertToNormalizedRect(RectF rect) {
|
||||
/**
|
||||
* Converts an {@link ImageProcessingOptions} instance into a {@link NormalizedRect} protobuf
|
||||
* message.
|
||||
*/
|
||||
private static NormalizedRect convertToNormalizedRect(
|
||||
ImageProcessingOptions imageProcessingOptions) {
|
||||
RectF regionOfInterest =
|
||||
imageProcessingOptions.regionOfInterest().isPresent()
|
||||
? imageProcessingOptions.regionOfInterest().get()
|
||||
: new RectF(0, 0, 1, 1);
|
||||
return NormalizedRect.newBuilder()
|
||||
.setXCenter(rect.centerX())
|
||||
.setYCenter(rect.centerY())
|
||||
.setWidth(rect.width())
|
||||
.setHeight(rect.height())
|
||||
.setXCenter(regionOfInterest.centerX())
|
||||
.setYCenter(regionOfInterest.centerY())
|
||||
.setWidth(regionOfInterest.width())
|
||||
.setHeight(regionOfInterest.height())
|
||||
// Convert to radians anti-clockwise.
|
||||
.setRotation(-(float) Math.PI * imageProcessingOptions.rotationDegrees() / 180.0f)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
|
|
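The rotation handling in convertToNormalizedRect() is easy to get wrong, so as a worked illustration: a clockwise rotation in degrees becomes a negative angle in radians in the NormalizedRect proto. A minimal sketch of the same arithmetic (the helper class and method names below are hypothetical; only the math mirrors the diff):

    import android.graphics.RectF;

    // Hypothetical helper mirroring convertToNormalizedRect() above.
    final class RotationConversionSketch {
      // 90 degrees clockwise -> -pi/2 radians (anti-clockwise proto convention).
      static float toProtoRotationRadians(int rotationDegrees) {
        return -(float) Math.PI * rotationDegrees / 180.0f;
      }

      // Defaults to the full image when no region-of-interest is set.
      static RectF roiOrFullImage(RectF roi) {
        return roi != null ? roi : new RectF(0, 0, 1, 1);
      }
    }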
@@ -0,0 +1,92 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.core;

import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import java.util.Optional;

// TODO: add support for image flipping.
/** Options for image processing. */
@AutoValue
public abstract class ImageProcessingOptions {

  /**
   * Builder for {@link ImageProcessingOptions}.
   *
   * <p>If both region-of-interest and rotation are specified, the crop around the
   * region-of-interest is extracted first, then the specified rotation is applied to the crop.
   */
  @AutoValue.Builder
  public abstract static class Builder {
    /**
     * Sets the optional region-of-interest to crop from the image. If not specified, the full image
     * is used.
     *
     * <p>Coordinates must be in [0,1], {@code left} must be < {@code right} and {@code top} must be
     * < {@code bottom}, otherwise an IllegalArgumentException will be thrown when {@link #build()}
     * is called.
     */
    public abstract Builder setRegionOfInterest(RectF value);

    /**
     * Sets the rotation to apply to the image (or cropped region-of-interest), in degrees
     * clockwise. Defaults to 0.
     *
     * <p>The rotation must be a multiple (positive or negative) of 90°, otherwise an
     * IllegalArgumentException will be thrown when {@link #build()} is called.
     */
    public abstract Builder setRotationDegrees(int value);

    abstract ImageProcessingOptions autoBuild();

    /**
     * Validates and builds the {@link ImageProcessingOptions} instance.
     *
     * @throws IllegalArgumentException if some of the provided values do not meet their
     *     requirements.
     */
    public final ImageProcessingOptions build() {
      ImageProcessingOptions options = autoBuild();
      if (options.regionOfInterest().isPresent()) {
        RectF roi = options.regionOfInterest().get();
        if (roi.left >= roi.right || roi.top >= roi.bottom) {
          throw new IllegalArgumentException(
              String.format(
                  "Expected left < right and top < bottom, found: %s.", roi.toShortString()));
        }
        if (roi.left < 0 || roi.right > 1 || roi.top < 0 || roi.bottom > 1) {
          throw new IllegalArgumentException(
              String.format("Expected RectF values in [0,1], found: %s.", roi.toShortString()));
        }
      }
      if (options.rotationDegrees() % 90 != 0) {
        throw new IllegalArgumentException(
            String.format(
                "Expected rotation to be a multiple of 90°, found: %d.",
                options.rotationDegrees()));
      }
      return options;
    }
  }

  public abstract Optional<RectF> regionOfInterest();

  public abstract int rotationDegrees();

  public static Builder builder() {
    return new AutoValue_ImageProcessingOptions.Builder().setRotationDegrees(0);
  }
}
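A minimal usage sketch of this builder, with illustrative values (the wrapper class below is hypothetical): crop to the center half of the image, then rotate the crop 90° clockwise. A rotation that is not a multiple of 90, e.g. setRotationDegrees(45), would make build() throw IllegalArgumentException.

    import android.graphics.RectF;
    import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;

    final class ImageProcessingOptionsDemo {
      static ImageProcessingOptions centerCropRotated() {
        // Crop around the normalized center half, then rotate 90° clockwise.
        return ImageProcessingOptions.builder()
            .setRegionOfInterest(new RectF(0.25f, 0.25f, 0.75f, 0.75f))
            .setRotationDegrees(90)
            .build();
      }
    }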
@@ -15,7 +15,6 @@
package com.google.mediapipe.tasks.vision.gesturerecognizer;

import android.content.Context;
import android.graphics.RectF;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;

@@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureClassifierGraphOptionsProto;
import com.google.mediapipe.tasks.vision.gesturerecognizer.proto.GestureRecognizerGraphOptionsProto;

@@ -212,6 +212,25 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
  }

  /**
   * Performs gesture recognition on the provided single image with default image processing
   * options, i.e. without any rotation applied. Only use this method when the {@link
   * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc
   * for input image format.
   *
   * <p>{@link GestureRecognizer} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @throws MediaPipeException if there is an internal error.
   */
  public GestureRecognitionResult recognize(MPImage image) {
    return recognize(image, ImageProcessingOptions.builder().build());
  }

  /**
   * Performs gesture recognition on the provided single image. Only use this method when the {@link
   * GestureRecognizer} is created with {@link RunningMode.IMAGE}. TODO update java doc

@@ -223,12 +242,41 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public GestureRecognitionResult recognize(MPImage inputImage) {
    // TODO: add proper support for rotations.
    return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF());
  public GestureRecognitionResult recognize(
      MPImage image, ImageProcessingOptions imageProcessingOptions) {
    validateImageProcessingOptions(imageProcessingOptions);
    return (GestureRecognitionResult) processImageData(image, imageProcessingOptions);
  }

  /**
   * Performs gesture recognition on the provided video frame with default image processing options,
   * i.e. without any rotation applied. Only use this method when the {@link GestureRecognizer} is
   * created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   *
   * <p>{@link GestureRecognizer} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public GestureRecognitionResult recognizeForVideo(MPImage image, long timestampMs) {
    return recognizeForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**

@@ -244,14 +292,43 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public GestureRecognitionResult recognizeForVideo(MPImage inputImage, long inputTimestampMs) {
    // TODO: add proper support for rotations.
    return (GestureRecognitionResult)
        processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
  public GestureRecognitionResult recognizeForVideo(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    validateImageProcessingOptions(imageProcessingOptions);
    return (GestureRecognitionResult) processVideoData(image, imageProcessingOptions, timestampMs);
  }

  /**
   * Sends live image data to perform gesture recognition with default image processing options,
   * i.e. without any rotation applied, and the results will be available via the {@link
   * ResultListener} provided in the {@link GestureRecognizerOptions}. Only use this method when the
   * {@link GestureRecognizer} is created with {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the gesture recognizer. The input timestamps must be monotonically increasing.
   *
   * <p>{@link GestureRecognizer} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void recognizeAsync(MPImage image, long timestampMs) {
    recognizeAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**

@@ -268,13 +345,20 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public void recognizeAsync(MPImage inputImage, long inputTimestampMs) {
    // TODO: add proper support for rotations.
    sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
  public void recognizeAsync(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    validateImageProcessingOptions(imageProcessingOptions);
    sendLiveStreamData(image, imageProcessingOptions, timestampMs);
  }

  /** Options for setting up a {@link GestureRecognizer}. */

@@ -445,8 +529,14 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
    }
  }

  /** Creates a RectF covering the full image. */
  private static RectF buildFullImageRectF() {
    return new RectF(0, 0, 1, 1);
  /**
   * Validates that the provided {@link ImageProcessingOptions} doesn't contain a
   * region-of-interest.
   */
  private static void validateImageProcessingOptions(
      ImageProcessingOptions imageProcessingOptions) {
    if (imageProcessingOptions.regionOfInterest().isPresent()) {
      throw new IllegalArgumentException("GestureRecognizer doesn't support region-of-interest.");
    }
  }
}
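Taken together with ImageProcessingOptions, a caller-side sketch of the new rotation support (the wrapper class is hypothetical; the model bundle name, setNumHands(1), and the -90° rotation mirror the tests later in this change):

    import android.content.Context;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.tasks.core.BaseOptions;
    import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
    import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognitionResult;
    import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer;
    import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions;

    final class RecognizeWithRotationSketch {
      // Recognizes gestures in an image captured 90° off from upright.
      static GestureRecognitionResult recognizeRotated(Context context, MPImage image) {
        GestureRecognizerOptions options =
            GestureRecognizerOptions.builder()
                .setBaseOptions(
                    BaseOptions.builder().setModelAssetPath("gesture_recognizer.task").build())
                .setNumHands(1)
                .build();
        GestureRecognizer recognizer = GestureRecognizer.createFromOptions(context, options);
        ImageProcessingOptions imageProcessingOptions =
            ImageProcessingOptions.builder().setRotationDegrees(-90).build();
        // Passing a region-of-interest here would throw IllegalArgumentException.
        return recognizer.recognize(image, imageProcessingOptions);
      }
    }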
@@ -15,11 +15,11 @@
package com.google.mediapipe.tasks.vision.imageclassifier;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.container.proto.CategoryProto;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
@@ -15,7 +15,6 @@
package com.google.mediapipe.tasks.vision.imageclassifier;

import android.content.Context;
import android.graphics.RectF;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;

@@ -26,7 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;

@@ -37,6 +36,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
import java.io.File;

@@ -215,6 +215,24 @@ public final class ImageClassifier extends BaseVisionTaskApi {
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
  }

  /**
   * Performs classification on the provided single image with default image processing options,
   * i.e. using the whole image as region-of-interest and without any rotation applied. Only use
   * this method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
   *
   * <p>{@link ImageClassifier} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @throws MediaPipeException if there is an internal error.
   */
  public ImageClassificationResult classify(MPImage image) {
    return classify(image, ImageProcessingOptions.builder().build());
  }

  /**
   * Performs classification on the provided single image. Only use this method when the {@link
   * ImageClassifier} is created with {@link RunningMode.IMAGE}.

@@ -225,16 +243,23 @@ public final class ImageClassifier extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference.
   * @throws MediaPipeException if there is an internal error.
   */
  public ImageClassificationResult classify(MPImage inputImage) {
    return (ImageClassificationResult) processImageData(inputImage, buildFullImageRectF());
  public ImageClassificationResult classify(
      MPImage image, ImageProcessingOptions imageProcessingOptions) {
    return (ImageClassificationResult) processImageData(image, imageProcessingOptions);
  }

  /**
   * Performs classification on the provided single image and region-of-interest. Only use this
   * method when the {@link ImageClassifier} is created with {@link RunningMode.IMAGE}.
   * Performs classification on the provided video frame with default image processing options, i.e.
   * using the whole image as region-of-interest and without any rotation applied. Only use this
   * method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   *
   * <p>{@link ImageClassifier} supports the following color space types:
   *

@@ -242,13 +267,12 @@ public final class ImageClassifier extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param roi a {@link RectF} specifying the region of interest on which to perform
   *     classification. Coordinates are expected to be specified as normalized values in [0,1].
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public ImageClassificationResult classify(MPImage inputImage, RectF roi) {
    return (ImageClassificationResult) processImageData(inputImage, roi);
  public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) {
    return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**

@@ -264,21 +288,26 @@ public final class ImageClassifier extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public ImageClassificationResult classifyForVideo(MPImage inputImage, long inputTimestampMs) {
    return (ImageClassificationResult)
        processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
  public ImageClassificationResult classifyForVideo(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs);
  }

  /**
   * Performs classification on the provided video frame with additional region-of-interest. Only
   * use this method when the {@link ImageClassifier} is created with {@link RunningMode.VIDEO}.
   * Sends live image data to perform classification with default image processing options, i.e.
   * using the whole image as region-of-interest and without any rotation applied, and the results
   * will be available via the {@link ResultListener} provided in the {@link
   * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
   * {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the image classifier. The input timestamps must be monotonically increasing.
   *
   * <p>{@link ImageClassifier} supports the following color space types:
   *

@@ -286,15 +315,12 @@ public final class ImageClassifier extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param roi a {@link RectF} specifying the region of interest on which to perform
   *     classification. Coordinates are expected to be specified as normalized values in [0,1].
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public ImageClassificationResult classifyForVideo(
      MPImage inputImage, RectF roi, long inputTimestampMs) {
    return (ImageClassificationResult) processVideoData(inputImage, roi, inputTimestampMs);
  public void classifyAsync(MPImage image, long timestampMs) {
    classifyAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**

@@ -311,37 +337,15 @@ public final class ImageClassifier extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void classifyAsync(MPImage inputImage, long inputTimestampMs) {
    sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
  }

  /**
   * Sends live image data and additional region-of-interest to perform classification, and the
   * results will be available via the {@link ResultListener} provided in the {@link
   * ImageClassifierOptions}. Only use this method when the {@link ImageClassifier} is created with
   * {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the image classifier. The input timestamps must be monotonically increasing.
   *
   * <p>{@link ImageClassifier} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param roi a {@link RectF} specifying the region of interest on which to perform
   *     classification. Coordinates are expected to be specified as normalized values in [0,1].
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void classifyAsync(MPImage inputImage, RectF roi, long inputTimestampMs) {
    sendLiveStreamData(inputImage, roi, inputTimestampMs);
  public void classifyAsync(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    sendLiveStreamData(image, imageProcessingOptions, timestampMs);
  }

  /** Options for setting up an {@link ImageClassifier}. */

@@ -447,9 +451,4 @@ public final class ImageClassifier extends BaseVisionTaskApi {
        .build();
    }
  }

  /** Creates a RectF covering the full image. */
  private static RectF buildFullImageRectF() {
    return new RectF(0, 0, 1, 1);
  }
}
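A caller-side sketch of region-of-interest classification, which this task does support (the wrapper class is hypothetical; the model file and ROI values mirror the tests later in this change):

    import android.content.Context;
    import android.graphics.RectF;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.tasks.core.BaseOptions;
    import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
    import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassificationResult;
    import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier;
    import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;

    final class ClassifyWithRoiSketch {
      // Classifies only the region around an object of interest.
      static ImageClassificationResult classifyRegion(Context context, MPImage image) {
        ImageClassifierOptions options =
            ImageClassifierOptions.builder()
                .setBaseOptions(
                    BaseOptions.builder().setModelAssetPath("mobilenet_v2_1.0_224.tflite").build())
                .build();
        ImageClassifier classifier = ImageClassifier.createFromOptions(context, options);
        // Normalized [0,1] coordinates, e.g. a box around a soccer ball.
        RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
        ImageProcessingOptions imageProcessingOptions =
            ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
        return classifier.classify(image, imageProcessingOptions);
      }
    }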
@@ -32,6 +32,7 @@ import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;

@@ -96,8 +97,10 @@ import java.util.Optional;
public final class ObjectDetector extends BaseVisionTaskApi {
  private static final String TAG = ObjectDetector.class.getSimpleName();
  private static final String IMAGE_IN_STREAM_NAME = "image_in";
  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
  private static final List<String> INPUT_STREAMS =
      Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
      Collections.unmodifiableList(
          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
  private static final List<String> OUTPUT_STREAMS =
      Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
  private static final int DETECTIONS_OUT_STREAM_INDEX = 0;

@@ -204,7 +207,25 @@ public final class ObjectDetector extends BaseVisionTaskApi {
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   */
  private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
  }

  /**
   * Performs object detection on the provided single image with default image processing options,
   * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is
   * created with {@link RunningMode.IMAGE}.
   *
   * <p>{@link ObjectDetector} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detect(MPImage image) {
    return detect(image, ImageProcessingOptions.builder().build());
  }

  /**

@@ -217,11 +238,41 @@ public final class ObjectDetector extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detect(MPImage inputImage) {
    return (ObjectDetectionResult) processImageData(inputImage);
  public ObjectDetectionResult detect(
      MPImage image, ImageProcessingOptions imageProcessingOptions) {
    validateImageProcessingOptions(imageProcessingOptions);
    return (ObjectDetectionResult) processImageData(image, imageProcessingOptions);
  }

  /**
   * Performs object detection on the provided video frame with default image processing options,
   * i.e. without any rotation applied. Only use this method when the {@link ObjectDetector} is
   * created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   *
   * <p>{@link ObjectDetector} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detectForVideo(MPImage image, long timestampMs) {
    return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**

@@ -237,12 +288,43 @@ public final class ObjectDetector extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detectForVideo(MPImage inputImage, long inputTimestampMs) {
    return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs);
  public ObjectDetectionResult detectForVideo(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    validateImageProcessingOptions(imageProcessingOptions);
    return (ObjectDetectionResult) processVideoData(image, imageProcessingOptions, timestampMs);
  }

  /**
   * Sends live image data to perform object detection with default image processing options, i.e.
   * without any rotation applied, and the results will be available via the {@link ResultListener}
   * provided in the {@link ObjectDetectorOptions}. Only use this method when the {@link
   * ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the object detector. The input timestamps must be monotonically increasing.
   *
   * <p>{@link ObjectDetector} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void detectAsync(MPImage image, long timestampMs) {
    detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**

@@ -259,12 +341,20 @@ public final class ObjectDetector extends BaseVisionTaskApi {
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link MPImage} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public void detectAsync(MPImage inputImage, long inputTimestampMs) {
    sendLiveStreamData(inputImage, inputTimestampMs);
  public void detectAsync(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    validateImageProcessingOptions(imageProcessingOptions);
    sendLiveStreamData(image, imageProcessingOptions, timestampMs);
  }

  /** Options for setting up an {@link ObjectDetector}. */

@@ -415,4 +505,15 @@ public final class ObjectDetector extends BaseVisionTaskApi {
        .build();
    }
  }

  /**
   * Validates that the provided {@link ImageProcessingOptions} doesn't contain a
   * region-of-interest.
   */
  private static void validateImageProcessingOptions(
      ImageProcessingOptions imageProcessingOptions) {
    if (imageProcessingOptions.regionOfInterest().isPresent()) {
      throw new IllegalArgumentException("ObjectDetector doesn't support region-of-interest.");
    }
  }
}
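A caller-side sketch for the detector (the wrapper class is hypothetical, and it assumes ObjectDetector exposes the same createFromOptions factory as the other tasks in this change). Rotation is supported, while a region-of-interest would be rejected by validateImageProcessingOptions():

    import android.content.Context;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.tasks.core.BaseOptions;
    import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
    import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;
    import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector;
    import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;

    final class DetectRotatedSketch {
      // Detects objects in an image that was captured upside down; passing a
      // region-of-interest instead would throw IllegalArgumentException.
      static ObjectDetectionResult detectUpsideDown(
          Context context, MPImage image, String modelPath) {
        ObjectDetectorOptions options =
            ObjectDetectorOptions.builder()
                .setBaseOptions(BaseOptions.builder().setModelAssetPath(modelPath).build())
                .build();
        ObjectDetector detector = ObjectDetector.createFromOptions(context, options);
        ImageProcessingOptions imageProcessingOptions =
            ImageProcessingOptions.builder().setRotationDegrees(180).build();
        return detector.detect(image, imageProcessingOptions);
      }
    }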
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.coretest"
    android:versionCode="1"
    android:versionName="1.0" >

  <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
  <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

  <application
      android:label="coretest"
      android:name="android.support.multidex.MultiDexApplication"
      android:taskAffinity="">
    <uses-library android:name="android.test.runner" />
  </application>

  <instrumentation
      android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
      android:targetPackage="com.google.mediapipe.tasks.vision.coretest" />

</manifest>
@@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# TODO: Enable this in OSS
@@ -0,0 +1,70 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.core;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import android.graphics.RectF;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import org.junit.Test;
import org.junit.runner.RunWith;

/** Test for {@link ImageProcessingOptions}. */
@RunWith(AndroidJUnit4.class)
public final class ImageProcessingOptionsTest {

  @Test
  public void succeedsWithValidInputs() throws Exception {
    ImageProcessingOptions options =
        ImageProcessingOptions.builder()
            .setRegionOfInterest(new RectF(0.0f, 0.1f, 1.0f, 0.9f))
            .setRotationDegrees(270)
            .build();
  }

  @Test
  public void failsWithLeftHigherThanRight() {
    IllegalArgumentException exception =
        assertThrows(
            IllegalArgumentException.class,
            () ->
                ImageProcessingOptions.builder()
                    .setRegionOfInterest(new RectF(0.9f, 0.0f, 0.1f, 1.0f))
                    .build());
    assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom");
  }

  @Test
  public void failsWithBottomHigherThanTop() {
    IllegalArgumentException exception =
        assertThrows(
            IllegalArgumentException.class,
            () ->
                ImageProcessingOptions.builder()
                    .setRegionOfInterest(new RectF(0.0f, 0.9f, 1.0f, 0.1f))
                    .build());
    assertThat(exception).hasMessageThat().contains("Expected left < right and top < bottom");
  }

  @Test
  public void failsWithInvalidRotation() {
    IllegalArgumentException exception =
        assertThrows(
            IllegalArgumentException.class,
            () -> ImageProcessingOptions.builder().setRotationDegrees(1).build());
    assertThat(exception).hasMessageThat().contains("Expected rotation to be a multiple of 90°");
  }
}
@@ -19,6 +19,7 @@ import static org.junit.Assert.assertThrows;

import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.truth.Correspondence;

@@ -30,6 +31,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Landmark;
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.gesturerecognizer.GestureRecognizer.GestureRecognizerOptions;
import java.io.InputStream;

@@ -46,11 +48,14 @@ public class GestureRecognizerTest {
  private static final String GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE = "gesture_recognizer.task";
  private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
  private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
  private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
  private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
  private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
  private static final String TAG = "Gesture Recognizer Test";
  private static final String THUMB_UP_LABEL = "Thumb_Up";
  private static final int THUMB_UP_INDEX = 5;
  private static final String POINTING_UP_LABEL = "Pointing_Up";
  private static final int POINTING_UP_INDEX = 3;
  private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
  private static final int IMAGE_WIDTH = 382;
  private static final int IMAGE_HEIGHT = 406;

@@ -135,6 +140,53 @@ public class GestureRecognizerTest {
        gestureRecognizer.recognize(getImageFromAsset(TWO_HANDS_IMAGE));
    assertThat(actualResult.handednesses()).hasSize(2);
  }

  @Test
  public void recognize_successWithRotation() throws Exception {
    GestureRecognizerOptions options =
        GestureRecognizerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder()
                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
                    .build())
            .setNumHands(1)
            .build();
    GestureRecognizer gestureRecognizer =
        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ImageProcessingOptions imageProcessingOptions =
        ImageProcessingOptions.builder().setRotationDegrees(-90).build();
    GestureRecognitionResult actualResult =
        gestureRecognizer.recognize(
            getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
    assertThat(actualResult.gestures()).hasSize(1);
    assertThat(actualResult.gestures().get(0).get(0).index()).isEqualTo(POINTING_UP_INDEX);
    assertThat(actualResult.gestures().get(0).get(0).categoryName()).isEqualTo(POINTING_UP_LABEL);
  }

  @Test
  public void recognize_failsWithRegionOfInterest() throws Exception {
    GestureRecognizerOptions options =
        GestureRecognizerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder()
                    .setModelAssetPath(GESTURE_RECOGNIZER_BUNDLE_ASSET_FILE)
                    .build())
            .setNumHands(1)
            .build();
    GestureRecognizer gestureRecognizer =
        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ImageProcessingOptions imageProcessingOptions =
        ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
    IllegalArgumentException exception =
        assertThrows(
            IllegalArgumentException.class,
            () ->
                gestureRecognizer.recognize(
                    getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions));
    assertThat(exception)
        .hasMessageThat()
        .contains("GestureRecognizer doesn't support region-of-interest");
  }
}

@RunWith(AndroidJUnit4.class)

@@ -195,12 +247,16 @@ public class GestureRecognizerTest {
    MediaPipeException exception =
        assertThrows(
            MediaPipeException.class,
            () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
            () ->
                gestureRecognizer.recognizeForVideo(
                    getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
    exception =
        assertThrows(
            MediaPipeException.class,
            () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
            () ->
                gestureRecognizer.recognizeAsync(
                    getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
  }

@@ -225,7 +281,9 @@ public class GestureRecognizerTest {
    exception =
        assertThrows(
            MediaPipeException.class,
            () -> gestureRecognizer.recognizeAsync(getImageFromAsset(THUMB_UP_IMAGE), 0));
            () ->
                gestureRecognizer.recognizeAsync(
                    getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
  }

@@ -251,7 +309,9 @@ public class GestureRecognizerTest {
    exception =
        assertThrows(
            MediaPipeException.class,
            () -> gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), 0));
            () ->
                gestureRecognizer.recognizeForVideo(
                    getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
  }

@@ -291,7 +351,8 @@ public class GestureRecognizerTest {
        getExpectedGestureRecognitionResult(THUMB_UP_LANDMARKS, THUMB_UP_LABEL, THUMB_UP_INDEX);
    for (int i = 0; i < 3; i++) {
      GestureRecognitionResult actualResult =
          gestureRecognizer.recognizeForVideo(getImageFromAsset(THUMB_UP_IMAGE), i);
          gestureRecognizer.recognizeForVideo(
              getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ i);
      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
    }
  }

@@ -317,9 +378,11 @@ public class GestureRecognizerTest {
        .build();
    try (GestureRecognizer gestureRecognizer =
        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
      gestureRecognizer.recognizeAsync(image, 1);
      gestureRecognizer.recognizeAsync(image, /*timestampMs=*/ 1);
      MediaPipeException exception =
          assertThrows(MediaPipeException.class, () -> gestureRecognizer.recognizeAsync(image, 0));
          assertThrows(
              MediaPipeException.class,
              () -> gestureRecognizer.recognizeAsync(image, /*timestampMs=*/ 0));
      assertThat(exception)
          .hasMessageThat()
          .contains("having a smaller timestamp than the processed timestamp");

@@ -348,7 +411,7 @@ public class GestureRecognizerTest {
    try (GestureRecognizer gestureRecognizer =
        GestureRecognizer.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
      for (int i = 0; i < 3; i++) {
        gestureRecognizer.recognizeAsync(image, i);
        gestureRecognizer.recognizeAsync(image, /*timestampMs=*/ i);
      }
    }
  }
@@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageclassifier.ImageClassifier.ImageClassifierOptions;
import java.io.InputStream;

@@ -47,7 +48,9 @@ public class ImageClassifierTest {
  private static final String FLOAT_MODEL_FILE = "mobilenet_v2_1.0_224.tflite";
  private static final String QUANTIZED_MODEL_FILE = "mobilenet_v1_0.25_224_quant.tflite";
  private static final String BURGER_IMAGE = "burger.jpg";
  private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg";
  private static final String MULTI_OBJECTS_IMAGE = "multi_objects.jpg";
  private static final String MULTI_OBJECTS_ROTATED_IMAGE = "multi_objects_rotated.jpg";

  @RunWith(AndroidJUnit4.class)
  public static final class General extends ImageClassifierTest {

@@ -209,13 +212,60 @@ public class ImageClassifierTest {
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      // RectF around the soccer ball.
      RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
      ImageProcessingOptions imageProcessingOptions =
          ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
      ImageClassificationResult results =
          imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), roi);
          imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);

      assertHasOneHeadAndOneTimestamp(results, 0);
      assertCategoriesAre(
          results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
    }

    @Test
    public void classify_succeedsWithRotation() throws Exception {
      ImageClassifierOptions options =
          ImageClassifierOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build())
              .build();
      ImageClassifier imageClassifier =
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      ImageProcessingOptions imageProcessingOptions =
          ImageProcessingOptions.builder().setRotationDegrees(-90).build();
      ImageClassificationResult results =
          imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);

      assertHasOneHeadAndOneTimestamp(results, 0);
      assertCategoriesAre(
          results,
          Arrays.asList(
              Category.create(0.6390683f, 934, "cheeseburger", ""),
              Category.create(0.0495407f, 963, "meat loaf", ""),
              Category.create(0.0469720f, 925, "guacamole", "")));
    }

    @Test
    public void classify_succeedsWithRegionOfInterestAndRotation() throws Exception {
      ImageClassifierOptions options =
          ImageClassifierOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build())
              .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build())
              .build();
      ImageClassifier imageClassifier =
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      // RectF around the chair.
      RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
      ImageProcessingOptions imageProcessingOptions =
          ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
      ImageClassificationResult results =
          imageClassifier.classify(
              getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);

      assertHasOneHeadAndOneTimestamp(results, 0);
      assertCategoriesAre(
          results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
    }
  }

  @RunWith(AndroidJUnit4.class)

@@ -269,12 +319,16 @@ public class ImageClassifierTest {
      MediaPipeException exception =
          assertThrows(
              MediaPipeException.class,
              () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
              () ->
                  imageClassifier.classifyForVideo(
                      getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
      exception =
          assertThrows(
              MediaPipeException.class,
              () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
              () ->
                  imageClassifier.classifyAsync(
                      getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
    }

@@ -296,7 +350,9 @@ public class ImageClassifierTest {
      exception =
          assertThrows(
              MediaPipeException.class,
              () -> imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 0));
              () ->
                  imageClassifier.classifyAsync(
                      getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
    }

@@ -320,7 +376,9 @@ public class ImageClassifierTest {
      exception =
          assertThrows(
              MediaPipeException.class,
              () -> imageClassifier.classifyForVideo(getImageFromAsset(BURGER_IMAGE), 0));
              () ->
                  imageClassifier.classifyForVideo(
                      getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
    }

@@ -352,7 +410,8 @@ public class ImageClassifierTest {
      ImageClassifier imageClassifier =
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      for (int i = 0; i < 3; i++) {
        ImageClassificationResult results = imageClassifier.classifyForVideo(image, i);
        ImageClassificationResult results =
            imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
        assertHasOneHeadAndOneTimestamp(results, i);
        assertCategoriesAre(
            results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));

@@ -377,9 +436,11 @@ public class ImageClassifierTest {
              .build();
      try (ImageClassifier imageClassifier =
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
        imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), 1);
        imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1);
        MediaPipeException exception =
            assertThrows(MediaPipeException.class, () -> imageClassifier.classifyAsync(image, 0));
            assertThrows(
                MediaPipeException.class,
                () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0));
        assertThat(exception)
            .hasMessageThat()
            .contains("having a smaller timestamp than the processed timestamp");

@@ -405,7 +466,7 @@ public class ImageClassifierTest {
      try (ImageClassifier imageClassifier =
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
        for (int i = 0; i < 3; ++i) {
          imageClassifier.classifyAsync(image, i);
          imageClassifier.classifyAsync(image, /*timestampMs=*/ i);
        }
      }
    }
@@ -29,6 +29,7 @@ import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
import java.io.InputStream;

@@ -45,10 +46,11 @@ import org.junit.runners.Suite.SuiteClasses;
public class ObjectDetectorTest {
  private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
  private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
  private static final String CAT_AND_DOG_ROTATED_IMAGE = "cats_and_dogs_rotated.jpg";
  private static final int IMAGE_WIDTH = 1200;
  private static final int IMAGE_HEIGHT = 600;
  private static final float CAT_SCORE = 0.69f;
  private static final RectF catBoundingBox = new RectF(611, 164, 986, 596);
  private static final RectF CAT_BOUNDING_BOX = new RectF(611, 164, 986, 596);
  // TODO: Figure out why android_x86 and android_arm tests have slightly different
  // scores (0.6875 vs 0.69921875).
  private static final float SCORE_DIFF_TOLERANCE = 0.01f;

@@ -67,7 +69,7 @@ public class ObjectDetectorTest {
    ObjectDetector objectDetector =
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
    assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
    assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
  }

  @Test

@@ -104,7 +106,7 @@ public class ObjectDetectorTest {
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
    // The score threshold should block all other objects except the cat.
    assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
    assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
  }

  @Test

@@ -175,7 +177,7 @@ public class ObjectDetectorTest {
    ObjectDetector objectDetector =
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
    assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
    assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
  }

  @Test

@@ -228,6 +230,46 @@ public class ObjectDetectorTest {
        .contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
  }

  @Test
  public void detect_succeedsWithRotation() throws Exception {
    ObjectDetectorOptions options =
        ObjectDetectorOptions.builder()
            .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
            .setMaxResults(1)
            .setCategoryAllowlist(Arrays.asList("cat"))
            .build();
    ObjectDetector objectDetector =
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ImageProcessingOptions imageProcessingOptions =
        ImageProcessingOptions.builder().setRotationDegrees(-90).build();
    ObjectDetectionResult results =
        objectDetector.detect(
            getImageFromAsset(CAT_AND_DOG_ROTATED_IMAGE), imageProcessingOptions);

    assertContainsOnlyCat(results, new RectF(22.0f, 611.0f, 452.0f, 890.0f), 0.7109375f);
  }

  @Test
  public void detect_failsWithRegionOfInterest() throws Exception {
    ObjectDetectorOptions options =
        ObjectDetectorOptions.builder()
            .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
            .build();
    ObjectDetector objectDetector =
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ImageProcessingOptions imageProcessingOptions =
        ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
    IllegalArgumentException exception =
        assertThrows(
            IllegalArgumentException.class,
            () ->
                objectDetector.detect(
                    getImageFromAsset(CAT_AND_DOG_IMAGE), imageProcessingOptions));
    assertThat(exception)
        .hasMessageThat()
        .contains("ObjectDetector doesn't support region-of-interest");
  }

  // TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
  // detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
  // detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.

@@ -282,12 +324,16 @@ public class ObjectDetectorTest {
    MediaPipeException exception =
        assertThrows(
            MediaPipeException.class,
            () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
            () ->
                objectDetector.detectForVideo(
                    getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
    exception =
        assertThrows(
            MediaPipeException.class,
            () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
            () ->
                objectDetector.detectAsync(
                    getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
  }

@@ -309,7 +355,9 @@ public class ObjectDetectorTest {
    exception =
        assertThrows(
            MediaPipeException.class,
            () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
            () ->
                objectDetector.detectAsync(
                    getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
  }

@@ -333,7 +381,9 @@ public class ObjectDetectorTest {
    exception =
        assertThrows(
            MediaPipeException.class,
            () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
            () ->
                objectDetector.detectForVideo(
                    getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
  }

@@ -348,7 +398,7 @@ public class ObjectDetectorTest {
    ObjectDetector objectDetector =
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
    assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
    assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
  }

  @Test

@@ -363,8 +413,9 @@ public class ObjectDetectorTest {
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    for (int i = 0; i < 3; i++) {
      ObjectDetectionResult results =
          objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i);
      assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
          objectDetector.detectForVideo(
              getImageFromAsset(CAT_AND_DOG_IMAGE), /*timestampsMs=*/ i);
      assertContainsOnlyCat(results, CAT_BOUNDING_BOX, CAT_SCORE);
    }
  }

@@ -377,16 +428,18 @@ public class ObjectDetectorTest {
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener(
                (objectDetectionResult, inputImage) -> {
                  assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
                  assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
                  assertImageSizeIsExpected(inputImage);
                })
            .setMaxResults(1)
            .build();
    try (ObjectDetector objectDetector =
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
      objectDetector.detectAsync(image, 1);
      objectDetector.detectAsync(image, /*timestampsMs=*/ 1);
      MediaPipeException exception =
          assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0));
          assertThrows(
              MediaPipeException.class,
              () -> objectDetector.detectAsync(image, /*timestampsMs=*/ 0));
      assertThat(exception)
          .hasMessageThat()
          .contains("having a smaller timestamp than the processed timestamp");

@@ -402,7 +455,7 @@ public class ObjectDetectorTest {
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener(
                (objectDetectionResult, inputImage) -> {
                  assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
                  assertContainsOnlyCat(objectDetectionResult, CAT_BOUNDING_BOX, CAT_SCORE);
                  assertImageSizeIsExpected(inputImage);
                })
            .setMaxResults(1)

@@ -410,7 +463,7 @@ public class ObjectDetectorTest {
    try (ObjectDetector objectDetector =
        ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
      for (int i = 0; i < 3; i++) {
        objectDetector.detectAsync(image, i);
        objectDetector.detectAsync(image, /*timestampsMs=*/ i);
      }
    }
  }
@@ -86,3 +86,13 @@ py_library(
        "//mediapipe/tasks/python/core:optional_dependencies",
    ],
)

py_library(
    name = "classifications",
    srcs = ["classifications.py"],
    deps = [
        ":category",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
        "//mediapipe/tasks/python/core:optional_dependencies",
    ],
)
mediapipe/tasks/python/components/containers/classifications.py (new file, 168 lines)

@@ -0,0 +1,168 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifications data class."""

import dataclasses
from typing import Any, List, Optional

from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

_ClassificationEntryProto = classifications_pb2.ClassificationEntry
_ClassificationsProto = classifications_pb2.Classifications
_ClassificationResultProto = classifications_pb2.ClassificationResult


@dataclasses.dataclass
class ClassificationEntry:
  """List of predicted classes (aka labels) for a given classifier head.

  Attributes:
    categories: The array of predicted categories, usually sorted by descending
      scores (e.g. from high to low probability).
    timestamp_ms: The optional timestamp (in milliseconds) associated to the
      classification entry. This is useful for time series use cases, e.g.,
      audio classification.
  """

  categories: List[category_module.Category]
  timestamp_ms: Optional[int] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ClassificationEntryProto:
    """Generates a ClassificationEntry protobuf object."""
    return _ClassificationEntryProto(
        categories=[category.to_pb2() for category in self.categories],
        timestamp_ms=self.timestamp_ms)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _ClassificationEntryProto) -> 'ClassificationEntry':
    """Creates a `ClassificationEntry` object from the given protobuf object."""
    return ClassificationEntry(
        categories=[
            category_module.Category.create_from_pb2(category)
            for category in pb2_obj.categories
        ],
        timestamp_ms=pb2_obj.timestamp_ms)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, ClassificationEntry):
      return False

    return self.to_pb2().__eq__(other.to_pb2())


@dataclasses.dataclass
class Classifications:
  """Represents the classifications for a given classifier head.

  Attributes:
    entries: A list of `ClassificationEntry` objects.
    head_index: The index of the classifier head these categories refer to. This
      is useful for multi-head models.
    head_name: The name of the classifier head, which is the corresponding
      tensor metadata name.
  """

  entries: List[ClassificationEntry]
  head_index: int
  head_name: str

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ClassificationsProto:
    """Generates a Classifications protobuf object."""
    return _ClassificationsProto(
        entries=[entry.to_pb2() for entry in self.entries],
        head_index=self.head_index,
        head_name=self.head_name)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
    """Creates a `Classifications` object from the given protobuf object."""
    return Classifications(
        entries=[
            ClassificationEntry.create_from_pb2(entry)
            for entry in pb2_obj.entries
        ],
        head_index=pb2_obj.head_index,
        head_name=pb2_obj.head_name)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, Classifications):
      return False

    return self.to_pb2().__eq__(other.to_pb2())


@dataclasses.dataclass
class ClassificationResult:
  """Contains one set of results per classifier head.

  Attributes:
    classifications: A list of `Classifications` objects.
  """

  classifications: List[Classifications]

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ClassificationResultProto:
    """Generates a ClassificationResult protobuf object."""
    return _ClassificationResultProto(classifications=[
        classification.to_pb2() for classification in self.classifications
    ])

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls, pb2_obj: _ClassificationResultProto) -> 'ClassificationResult':
    """Creates a `ClassificationResult` object from the given protobuf object.
    """
    return ClassificationResult(classifications=[
        Classifications.create_from_pb2(classification)
        for classification in pb2_obj.classifications
    ])

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, ClassificationResult):
      return False

    return self.to_pb2().__eq__(other.to_pb2())
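For orientation, a minimal sketch of how the new container dataclasses above compose and round-trip through their protos. The field values are made up for illustration; the module paths follow the BUILD targets in this commit:

# Illustrative usage of the new containers; values are hypothetical.
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications

entry = classifications.ClassificationEntry(
    categories=[
        category.Category(
            index=934,
            score=0.79,
            display_name='',
            category_name='cheeseburger'),
    ],
    timestamp_ms=0)
result = classifications.ClassificationResult(classifications=[
    classifications.Classifications(
        entries=[entry], head_index=0, head_name='probability')
])

# to_pb2() and create_from_pb2() are inverses, and __eq__ compares the
# serialized protos, so a round trip preserves equality.
proto = result.to_pb2()
assert classifications.ClassificationResult.create_from_pb2(proto) == result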
@@ -26,15 +26,13 @@ _NormalizedRectProto = rect_pb2.NormalizedRect
@dataclasses.dataclass
class Rect:
  """A rectangle with rotation in image coordinates.

  Attributes:
    x_center : The X coordinate of the top-left corner, in pixels.
    y_center : The Y coordinate of the top-left corner, in pixels.
  Attributes: x_center : The X coordinate of the top-left corner, in pixels.
    y_center : The Y coordinate of the top-left corner, in pixels.
    width: The width of the rectangle, in pixels.
    height: The height of the rectangle, in pixels.
    rotation: Rotation angle is clockwise in radians.
    rect_id: Optional unique id to help associate different rectangles to each
      other.
      other.
  """

  x_center: int

@@ -81,17 +79,16 @@ class Rect:

@dataclasses.dataclass
class NormalizedRect:
  """A rectangle with rotation in normalized coordinates. The values of box
  """A rectangle with rotation in normalized coordinates.

  The values of box
  center location and size are within [0, 1].

  Attributes:
    x_center : The X normalized coordinate of the top-left corner.
    y_center : The Y normalized coordinate of the top-left corner.
  Attributes: x_center : The X normalized coordinate of the top-left corner.
    y_center : The Y normalized coordinate of the top-left corner.
    width: The width of the rectangle.
    height: The height of the rectangle.
    rotation: Rotation angle is clockwise in radians.
    rect_id: Optional unique id to help associate different rectangles to each
      other.
      other.
  """

  x_center: float

@@ -110,8 +107,7 @@ class NormalizedRect:
        width=self.width,
        height=self.height,
        rotation=self.rotation,
        rect_id=self.rect_id
    )
        rect_id=self.rect_id)

  @classmethod
  @doc_controls.do_not_generate_docs

@@ -123,8 +119,7 @@ class NormalizedRect:
        width=pb2_obj.width,
        height=pb2_obj.height,
        rotation=pb2_obj.rotation,
        rect_id=pb2_obj.rect_id
    )
        rect_id=pb2_obj.rect_id)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.
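NormalizedRect is what the Python `classify` calls later in this commit take as an optional region of interest. A minimal sketch, assuming the rotation and rect_id fields have defaults (the Python tests below construct the rect without them):

# Sketch: a center-based region of interest in [0, 1] coordinates; the box
# values are the soccer-ball ROI used by the tests in this commit.
from mediapipe.tasks.python.components.containers import rect

roi = rect.NormalizedRect(
    x_center=0.532, y_center=0.521, width=0.164, height=0.427)
roi_proto = roi.to_pb2()  # assumes rotation/rect_id default in the dataclass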
@@ -14,6 +8,8 @@

# Placeholder for internal Python strict library compatibility macro.

# Placeholder for internal Python strict library and test compatibility macro.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])
@@ -61,19 +61,13 @@ class ClassifierOptions:

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(
      cls,
      pb2_obj: _ClassifierOptionsProto
  ) -> 'ClassifierOptions':
  def create_from_pb2(cls,
                      pb2_obj: _ClassifierOptionsProto) -> 'ClassifierOptions':
    """Creates a `ClassifierOptions` object from the given protobuf object."""
    return ClassifierOptions(
        score_threshold=pb2_obj.score_threshold,
        category_allowlist=[
            str(name) for name in pb2_obj.class_name_allowlist
        ],
        category_denylist=[
            str(name) for name in pb2_obj.class_name_denylist
        ],
        category_allowlist=[str(name) for name in pb2_obj.category_allowlist],
        category_denylist=[str(name) for name in pb2_obj.category_denylist],
        display_names_locale=pb2_obj.display_names_locale,
        max_results=pb2_obj.max_results)
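The fix above makes `create_from_pb2` read the proto's `category_allowlist`/`category_denylist` fields instead of the stale `class_name_*` names. A minimal round-trip sketch, assuming `ClassifierOptions` also has the `to_pb2()` method that the image classifier task below relies on:

# Sketch of the corrected round trip; values are illustrative.
from mediapipe.tasks.python.components.processors import classifier_options

options = classifier_options.ClassifierOptions(
    max_results=3, category_allowlist=['cheeseburger', 'guacamole'])
proto = options.to_pb2()
restored = classifier_options.ClassifierOptions.create_from_pb2(proto)
# Only the list fields are checked here; scalar fields pick up proto defaults.
assert restored.category_allowlist == ['cheeseburger', 'guacamole']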
@@ -37,6 +37,26 @@ py_test(
    ],
)

py_test(
    name = "image_classifier_test",
    srcs = ["image_classifier_test.py"],
    data = [
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_models",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/tasks/python/components/containers:category",
        "//mediapipe/tasks/python/components/containers:classifications",
        "//mediapipe/tasks/python/components/containers:rect",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
        "//mediapipe/tasks/python/vision:image_classifier",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
    ],
)

py_test(
    name = "gesture_recognizer_test",
    srcs = ["gesture_recognizer_test.py"],
mediapipe/tasks/python/test/vision/image_classifier_test.py (new file, 515 lines)

@@ -0,0 +1,515 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for image classifier."""

import enum
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from mediapipe.python._framework_bindings import image
from mediapipe.tasks.python.components.containers import category
from mediapipe.tasks.python.components.containers import classifications as classifications_module
from mediapipe.tasks.python.components.containers import rect
from mediapipe.tasks.python.components.processors import classifier_options
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_classifier
from mediapipe.tasks.python.vision.core import vision_task_running_mode

_NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions
_ClassifierOptions = classifier_options.ClassifierOptions
_Category = category.Category
_ClassificationEntry = classifications_module.ClassificationEntry
_Classifications = classifications_module.Classifications
_ClassificationResult = classifications_module.ClassificationResult
_Image = image.Image
_ImageClassifier = image_classifier.ImageClassifier
_ImageClassifierOptions = image_classifier.ImageClassifierOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode

_MODEL_FILE = 'mobilenet_v2_1.0_224.tflite'
_IMAGE_FILE = 'burger.jpg'
_ALLOW_LIST = ['cheeseburger', 'guacamole']
_DENY_LIST = ['cheeseburger']
_SCORE_THRESHOLD = 0.5
_MAX_RESULTS = 3


# TODO: Port assertProtoEquals
def _assert_proto_equals(expected, actual):  # pylint: disable=unused-argument
  pass


def _generate_empty_results(timestamp_ms: int) -> _ClassificationResult:
  return _ClassificationResult(classifications=[
      _Classifications(
          entries=[
              _ClassificationEntry(categories=[], timestamp_ms=timestamp_ms)
          ],
          head_index=0,
          head_name='probability')
  ])


def _generate_burger_results(timestamp_ms: int) -> _ClassificationResult:
  return _ClassificationResult(classifications=[
      _Classifications(
          entries=[
              _ClassificationEntry(
                  categories=[
                      _Category(
                          index=934,
                          score=0.7939587831497192,
                          display_name='',
                          category_name='cheeseburger'),
                      _Category(
                          index=932,
                          score=0.02739289402961731,
                          display_name='',
                          category_name='bagel'),
                      _Category(
                          index=925,
                          score=0.01934075355529785,
                          display_name='',
                          category_name='guacamole'),
                      _Category(
                          index=963,
                          score=0.006327860057353973,
                          display_name='',
                          category_name='meat loaf')
                  ],
                  timestamp_ms=timestamp_ms)
          ],
          head_index=0,
          head_name='probability')
  ])


def _generate_soccer_ball_results(timestamp_ms: int) -> _ClassificationResult:
  return _ClassificationResult(classifications=[
      _Classifications(
          entries=[
              _ClassificationEntry(
                  categories=[
                      _Category(
                          index=806,
                          score=0.9965274930000305,
                          display_name='',
                          category_name='soccer ball')
                  ],
                  timestamp_ms=timestamp_ms)
          ],
          head_index=0,
          head_name='probability')
  ])


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class ImageClassifierTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(_IMAGE_FILE))
    self.model_path = test_utils.get_test_data_path(_MODEL_FILE)

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _ImageClassifier.create_from_model_path(self.model_path) as classifier:
      self.assertIsInstance(classifier, _ImageClassifier)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _ImageClassifierOptions(base_options=base_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      self.assertIsInstance(classifier, _ImageClassifier)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(model_asset_path='')
      options = _ImageClassifierOptions(base_options=base_options)
      _ImageClassifier.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _ImageClassifierOptions(base_options=base_options)
      classifier = _ImageClassifier.create_from_options(options)
      self.assertIsInstance(classifier, _ImageClassifier)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
  def test_classify(self, model_file_type, max_results,
                    expected_classification_result):
    # Creates classifier.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    custom_classifier_options = _ClassifierOptions(max_results=max_results)
    options = _ImageClassifierOptions(
        base_options=base_options, classifier_options=custom_classifier_options)
    classifier = _ImageClassifier.create_from_options(options)

    # Performs image classification on the input.
    image_result = classifier.classify(self.test_image)
    # Comparing results.
    _assert_proto_equals(image_result.to_pb2(),
                         expected_classification_result.to_pb2())
    # Closes the classifier explicitly when the classifier is not used in
    # a context.
    classifier.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, 4, _generate_burger_results(0)),
      (ModelFileType.FILE_CONTENT, 4, _generate_burger_results(0)))
  def test_classify_in_context(self, model_file_type, max_results,
                               expected_classification_result):
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    custom_classifier_options = _ClassifierOptions(max_results=max_results)
    options = _ImageClassifierOptions(
        base_options=base_options, classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Performs image classification on the input.
      image_result = classifier.classify(self.test_image)
      # Comparing results.
      _assert_proto_equals(image_result.to_pb2(),
                           expected_classification_result.to_pb2())

  def test_classify_succeeds_with_region_of_interest(self):
    base_options = _BaseOptions(model_asset_path=self.model_path)
    custom_classifier_options = _ClassifierOptions(max_results=1)
    options = _ImageClassifierOptions(
        base_options=base_options, classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path('multi_objects.jpg'))
      # NormalizedRect around the soccer ball.
      roi = _NormalizedRect(
          x_center=0.532, y_center=0.521, width=0.164, height=0.427)
      # Performs image classification on the input.
      image_result = classifier.classify(test_image, roi)
      # Comparing results.
      _assert_proto_equals(image_result.to_pb2(),
                           _generate_soccer_ball_results(0).to_pb2())

  def test_score_threshold_option(self):
    custom_classifier_options = _ClassifierOptions(
        score_threshold=_SCORE_THRESHOLD)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Performs image classification on the input.
      image_result = classifier.classify(self.test_image)
      classifications = image_result.classifications

      for classification in classifications:
        for entry in classification.entries:
          score = entry.categories[0].score
          self.assertGreaterEqual(
              score, _SCORE_THRESHOLD,
              f'Classification with score lower than threshold found. '
              f'{classification}')

  def test_max_results_option(self):
    custom_classifier_options = _ClassifierOptions(
        score_threshold=_SCORE_THRESHOLD)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Performs image classification on the input.
      image_result = classifier.classify(self.test_image)
      categories = image_result.classifications[0].entries[0].categories

      self.assertLessEqual(
          len(categories), _MAX_RESULTS, 'Too many results returned.')

  def test_allow_list_option(self):
    custom_classifier_options = _ClassifierOptions(
        category_allowlist=_ALLOW_LIST)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Performs image classification on the input.
      image_result = classifier.classify(self.test_image)
      classifications = image_result.classifications

      for classification in classifications:
        for entry in classification.entries:
          label = entry.categories[0].category_name
          self.assertIn(label, _ALLOW_LIST,
                        f'Label {label} found but not in label allow list')

  def test_deny_list_option(self):
    custom_classifier_options = _ClassifierOptions(category_denylist=_DENY_LIST)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Performs image classification on the input.
      image_result = classifier.classify(self.test_image)
      classifications = image_result.classifications

      for classification in classifications:
        for entry in classification.entries:
          label = entry.categories[0].category_name
          self.assertNotIn(label, _DENY_LIST,
                           f'Label {label} found but in deny list.')

  def test_combined_allowlist_and_denylist(self):
    # Fails with combined allowlist and denylist
    with self.assertRaisesRegex(
        ValueError,
        r'`category_allowlist` and `category_denylist` are mutually '
        r'exclusive options.'):
      custom_classifier_options = _ClassifierOptions(
          category_allowlist=['foo'], category_denylist=['bar'])
      options = _ImageClassifierOptions(
          base_options=_BaseOptions(model_asset_path=self.model_path),
          classifier_options=custom_classifier_options)
      with _ImageClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_empty_classification_outputs(self):
    custom_classifier_options = _ClassifierOptions(score_threshold=1)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Performs image classification on the input.
      image_result = classifier.classify(self.test_image)
      self.assertEmpty(image_result.classifications[0].entries[0].categories)

  def test_missing_result_callback(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM)
    with self.assertRaisesRegex(ValueError,
                                r'result callback must be provided'):
      with _ImageClassifier.create_from_options(options) as unused_classifier:
        pass

  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
  def test_illegal_result_callback(self, running_mode):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=running_mode,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'result callback should not be provided'):
      with _ImageClassifier.create_from_options(options) as unused_classifier:
        pass

  def test_calling_classify_for_video_in_image_mode(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _ImageClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        classifier.classify_for_video(self.test_image, 0)

  def test_calling_classify_async_in_image_mode(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _ImageClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        classifier.classify_async(self.test_image, 0)

  def test_calling_classify_in_video_mode(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _ImageClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        classifier.classify(self.test_image)

  def test_calling_classify_async_in_video_mode(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _ImageClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        classifier.classify_async(self.test_image, 0)

  def test_classify_for_video_with_out_of_order_timestamp(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _ImageClassifier.create_from_options(options) as classifier:
      unused_result = classifier.classify_for_video(self.test_image, 1)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        classifier.classify_for_video(self.test_image, 0)

  def test_classify_for_video(self):
    custom_classifier_options = _ClassifierOptions(max_results=4)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO,
        classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      for timestamp in range(0, 300, 30):
        classification_result = classifier.classify_for_video(
            self.test_image, timestamp)
        _assert_proto_equals(classification_result.to_pb2(),
                             _generate_burger_results(timestamp).to_pb2())

  def test_classify_for_video_succeeds_with_region_of_interest(self):
    custom_classifier_options = _ClassifierOptions(max_results=1)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO,
        classifier_options=custom_classifier_options)
    with _ImageClassifier.create_from_options(options) as classifier:
      # Load the test image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path('multi_objects.jpg'))
      # NormalizedRect around the soccer ball.
      roi = _NormalizedRect(
          x_center=0.532, y_center=0.521, width=0.164, height=0.427)
      for timestamp in range(0, 300, 30):
        classification_result = classifier.classify_for_video(
            test_image, timestamp, roi)
        self.assertEqual(classification_result,
                         _generate_soccer_ball_results(timestamp))

  def test_calling_classify_in_live_stream_mode(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _ImageClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        classifier.classify(self.test_image)

  def test_calling_classify_for_video_in_live_stream_mode(self):
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _ImageClassifier.create_from_options(options) as classifier:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        classifier.classify_for_video(self.test_image, 0)

  def test_classify_async_calls_with_illegal_timestamp(self):
    custom_classifier_options = _ClassifierOptions(max_results=4)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        classifier_options=custom_classifier_options,
        result_callback=mock.MagicMock())
    with _ImageClassifier.create_from_options(options) as classifier:
      classifier.classify_async(self.test_image, 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        classifier.classify_async(self.test_image, 0)

  @parameterized.parameters((0, _generate_burger_results),
                            (1, _generate_empty_results))
  def test_classify_async_calls(self, threshold, expected_result_fn):
    observed_timestamp_ms = -1

    def check_result(result: _ClassificationResult, output_image: _Image,
                     timestamp_ms: int):
      _assert_proto_equals(result.to_pb2(),
                           expected_result_fn(timestamp_ms).to_pb2())
      self.assertTrue(
          np.array_equal(output_image.numpy_view(),
                         self.test_image.numpy_view()))
      self.assertLess(observed_timestamp_ms, timestamp_ms)
      self.observed_timestamp_ms = timestamp_ms

    custom_classifier_options = _ClassifierOptions(
        max_results=4, score_threshold=threshold)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        classifier_options=custom_classifier_options,
        result_callback=check_result)
    with _ImageClassifier.create_from_options(options) as classifier:
      for timestamp in range(0, 300, 30):
        classifier.classify_async(self.test_image, timestamp)

  def test_classify_async_succeeds_with_region_of_interest(self):
    # Load the test image.
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path('multi_objects.jpg'))
    # NormalizedRect around the soccer ball.
    roi = _NormalizedRect(
        x_center=0.532, y_center=0.521, width=0.164, height=0.427)
    observed_timestamp_ms = -1

    def check_result(result: _ClassificationResult, output_image: _Image,
                     timestamp_ms: int):
      _assert_proto_equals(result.to_pb2(),
                           _generate_soccer_ball_results(timestamp_ms).to_pb2())
      self.assertEqual(output_image.width, test_image.width)
      self.assertEqual(output_image.height, test_image.height)
      self.assertLess(observed_timestamp_ms, timestamp_ms)
      self.observed_timestamp_ms = timestamp_ms

    custom_classifier_options = _ClassifierOptions(max_results=1)
    options = _ImageClassifierOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        classifier_options=custom_classifier_options,
        result_callback=check_result)
    with _ImageClassifier.create_from_options(options) as classifier:
      for timestamp in range(0, 300, 30):
        classifier.classify_async(test_image, timestamp, roi)


if __name__ == '__main__':
  absltest.main()
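The `_assert_proto_equals` helper above is still a stub (see the `TODO: Port assertProtoEquals`), so the proto comparisons in these tests currently pass vacuously. A minimal stand-in, assuming only `text_format` from the protobuf runtime, could normalize both messages to text form and compare:

# Hypothetical replacement for the _assert_proto_equals stub above.
from google.protobuf import text_format


def _assert_proto_equals(expected, actual):
  # Both arguments are protobuf messages (e.g. from to_pb2()). Serializing to
  # text form makes any mismatch easy to read in the test log.
  expected_text = text_format.MessageToString(expected)
  actual_text = text_format.MessageToString(actual)
  if expected_text != actual_text:
    raise AssertionError(
        f'Protos differ.\nExpected:\n{expected_text}\nActual:\n{actual_text}')

Note this is exact comparison; the ported assertProtoEquals may additionally need a float tolerance, since scores can differ slightly across platforms (as the Java TODO about android_x86 vs android_arm scores suggests).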
@@ -37,6 +37,28 @@ py_library(
    ],
)

py_library(
    name = "image_classifier",
    srcs = [
        "image_classifier.py",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_py_pb2",
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_py_pb2",
        "//mediapipe/tasks/python/components/containers:classifications",
        "//mediapipe/tasks/python/components/containers:rect",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
    ],
)

py_library(
    name = "gesture_recognizer",
    srcs = [
294
mediapipe/tasks/python/vision/image_classifier.py
Normal file
294
mediapipe/tasks/python/vision/image_classifier.py
Normal file
|
@ -0,0 +1,294 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""MediaPipe image classifier task."""
|
||||
|
||||
import dataclasses
|
||||
from typing import Callable, Mapping, Optional
|
||||
|
||||
from mediapipe.python import packet_creator
|
||||
from mediapipe.python import packet_getter
|
||||
# TODO: Import MPImage directly one we have an alias
|
||||
from mediapipe.python._framework_bindings import image as image_module
|
||||
from mediapipe.python._framework_bindings import packet
|
||||
from mediapipe.python._framework_bindings import task_runner
|
||||
from mediapipe.tasks.cc.components.containers.proto import classifications_pb2
|
||||
from mediapipe.tasks.cc.vision.image_classifier.proto import image_classifier_graph_options_pb2
|
||||
from mediapipe.tasks.python.components.containers import classifications
|
||||
from mediapipe.tasks.python.components.containers import rect
|
||||
from mediapipe.tasks.python.components.processors import classifier_options
|
||||
from mediapipe.tasks.python.core import base_options as base_options_module
|
||||
from mediapipe.tasks.python.core import task_info as task_info_module
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
|
||||
_NormalizedRect = rect.NormalizedRect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_ImageClassifierGraphOptionsProto = image_classifier_graph_options_pb2.ImageClassifierGraphOptions
|
||||
_ClassifierOptions = classifier_options.ClassifierOptions
|
||||
_RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
_TaskRunner = task_runner.TaskRunner
|
||||
|
||||
_CLASSIFICATION_RESULT_OUT_STREAM_NAME = 'classification_result_out'
|
||||
_CLASSIFICATION_RESULT_TAG = 'CLASSIFICATION_RESULT'
|
||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
_NORM_RECT_NAME = 'norm_rect_in'
|
||||
_NORM_RECT_TAG = 'NORM_RECT'
|
||||
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'
|
||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
|
||||
|
||||
def _build_full_image_norm_rect() -> _NormalizedRect:
|
||||
# Builds a NormalizedRect covering the entire image.
|
||||
return _NormalizedRect(x_center=0.5, y_center=0.5, width=1, height=1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class ImageClassifierOptions:
  """Options for the image classifier task.

  Attributes:
    base_options: Base options for the image classifier task.
    running_mode: The running mode of the task. Defaults to the image mode.
      The image classifier task has three running modes: 1) The image mode
      for classifying objects on single image inputs. 2) The video mode for
      classifying objects on the decoded frames of a video. 3) The live
      stream mode for classifying objects on a live stream of input data,
      such as from a camera.
    classifier_options: Options for the image classification task.
    result_callback: The user-defined result callback for processing live
      stream data. The result callback should only be specified when the
      running mode is set to the live stream mode.
  """
  base_options: _BaseOptions
  running_mode: _RunningMode = _RunningMode.IMAGE
  # Use a default factory so that option instances don't share one mutable
  # default _ClassifierOptions object.
  classifier_options: _ClassifierOptions = dataclasses.field(
      default_factory=_ClassifierOptions)
  result_callback: Optional[
      Callable[[classifications.ClassificationResult, image_module.Image, int],
               None]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _ImageClassifierGraphOptionsProto:
    """Generates an ImageClassifierOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    # Stream mode is enabled for the video and live stream running modes.
    base_options_proto.use_stream_mode = self.running_mode != _RunningMode.IMAGE
    classifier_options_proto = self.classifier_options.to_pb2()

    return _ImageClassifierGraphOptionsProto(
        base_options=base_options_proto,
        classifier_options=classifier_options_proto)

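For reference, a minimal sketch of wiring these options for the live stream mode; the model path and the callback body are illustrative, not part of this change:

# Sketch only: configuring ImageClassifierOptions for live-stream use.
# 'classifier.tflite' is a hypothetical local model file.
def print_result(result: classifications.ClassificationResult,
                 output_image: image_module.Image, timestamp_ms: int) -> None:
  # The callback receives the result, the input image, and its timestamp.
  print(f'{timestamp_ms} ms: {len(result.classifications)} head(s) classified')

options = ImageClassifierOptions(
    base_options=_BaseOptions(model_asset_path='classifier.tflite'),
    running_mode=_RunningMode.LIVE_STREAM,
    result_callback=print_result)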
class ImageClassifier(base_vision_task_api.BaseVisionTaskApi):
  """Class that performs image classification on images."""

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'ImageClassifier':
    """Creates an `ImageClassifier` object from a TensorFlow Lite model and the default `ImageClassifierOptions`.

    Note that the created `ImageClassifier` instance is in image mode, for
    classifying objects on single image inputs.

    Args:
      model_path: Path to the model.

    Returns:
      `ImageClassifier` object that's created from the model file and the
      default `ImageClassifierOptions`.

    Raises:
      ValueError: If the `ImageClassifier` object cannot be created from the
        provided file, e.g. due to an invalid file path.
      RuntimeError: If other types of error occur.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = ImageClassifierOptions(
        base_options=base_options, running_mode=_RunningMode.IMAGE)
    return cls.create_from_options(options)

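Under the same assumption of a local 'classifier.tflite' model file, creation in the default image mode is a one-liner:

# Sketch only: image-mode classifier from a model path.
classifier = ImageClassifier.create_from_model_path('classifier.tflite')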
  @classmethod
  def create_from_options(cls,
                          options: ImageClassifierOptions) -> 'ImageClassifier':
    """Creates the `ImageClassifier` object from image classifier options.

    Args:
      options: Options for the image classifier task.

    Returns:
      `ImageClassifier` object that's created from `options`.

    Raises:
      ValueError: If the `ImageClassifier` object cannot be created from
        `ImageClassifierOptions`, e.g. when the model is missing.
      RuntimeError: If other types of error occur.
    """

    def packets_callback(output_packets: Mapping[str, packet.Packet]):
      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
        return

      classification_result_proto = classifications_pb2.ClassificationResult()
      classification_result_proto.CopyFrom(
          packet_getter.get_proto(
              output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))

      classification_result = classifications.ClassificationResult([
          classifications.Classifications.create_from_pb2(classification)
          for classification in classification_result_proto.classifications
      ])
      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
      # Convert the packet timestamp from microseconds back to milliseconds.
      options.result_callback(classification_result, image,
                              timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)

    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]),
        ],
        output_streams=[
            ':'.join([
                _CLASSIFICATION_RESULT_TAG,
                _CLASSIFICATION_RESULT_OUT_STREAM_NAME
            ]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
        ],
        task_options=options)
    return cls(
        task_info.generate_graph_config(
            enable_flow_limiting=options.running_mode ==
            _RunningMode.LIVE_STREAM), options.running_mode,
        packets_callback if options.result_callback else None)

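A note on the stream wiring above: the ':'.join calls only build MediaPipe's 'TAG:stream_name' bindings. A quick sketch of the strings the task graph is configured with:

# Sketch only: the resulting stream bindings.
assert ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]) == 'IMAGE:image_in'
assert ':'.join([_NORM_RECT_TAG, _NORM_RECT_NAME]) == 'NORM_RECT:norm_rect_in'
assert (':'.join([_CLASSIFICATION_RESULT_TAG,
                  _CLASSIFICATION_RESULT_OUT_STREAM_NAME]) ==
        'CLASSIFICATION_RESULT:classification_result_out')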
  # TODO: Replace _NormalizedRect with ImageProcessingOption
  def classify(
      self,
      image: image_module.Image,
      roi: Optional[_NormalizedRect] = None
  ) -> classifications.ClassificationResult:
    """Performs image classification on the provided MediaPipe Image.

    Args:
      image: MediaPipe Image.
      roi: The region of interest.

    Returns:
      A classification result object that contains a list of classifications.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If image classification failed to run.
    """
    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image),
        _NORM_RECT_NAME: packet_creator.create_proto(norm_rect.to_pb2())
    })

    classification_result_proto = classifications_pb2.ClassificationResult()
    classification_result_proto.CopyFrom(
        packet_getter.get_proto(
            output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))

    return classifications.ClassificationResult([
        classifications.Classifications.create_from_pb2(classification)
        for classification in classification_result_proto.classifications
    ])

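A short sketch of image-mode classification with an optional ROI; `classifier` is the instance from the earlier sketch and `mp_image` stands for a MediaPipe Image obtained elsewhere (e.g. decoded from a file or camera frame):

# Sketch only: classify the center half of the image.
center_roi = _NormalizedRect(x_center=0.5, y_center=0.5, width=0.5, height=0.5)
result = classifier.classify(mp_image, roi=center_roi)
# One Classifications entry is expected per classifier head.
print(result.classifications)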
  def classify_for_video(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      roi: Optional[_NormalizedRect] = None
  ) -> classifications.ClassificationResult:
    """Performs image classification on the provided video frames.

    Only use this method when the ImageClassifier is created with the video
    running mode. It's required to provide the video frame's timestamp (in
    milliseconds) along with the video frame. The input timestamps should be
    monotonically increasing for adjacent calls of this method.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input video frame in milliseconds.
      roi: The region of interest.

    Returns:
      A classification result object that contains a list of classifications.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If image classification failed to run.
    """
    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_NAME:
            packet_creator.create_proto(norm_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })

    classification_result_proto = classifications_pb2.ClassificationResult()
    classification_result_proto.CopyFrom(
        packet_getter.get_proto(
            output_packets[_CLASSIFICATION_RESULT_OUT_STREAM_NAME]))

    return classifications.ClassificationResult([
        classifications.Classifications.create_from_pb2(classification)
        for classification in classification_result_proto.classifications
    ])

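A sketch of the video-mode contract, assuming a classifier created with `_RunningMode.VIDEO` and an iterable `frames` of MediaPipe Images supplied by an external decoder; the fixed frame rate is illustrative:

# Sketch only: monotonically increasing timestamps derived from frame index.
fps = 30
for frame_index, mp_frame in enumerate(frames):
  frame_timestamp_ms = int(frame_index * 1000 / fps)
  result = video_classifier.classify_for_video(mp_frame, frame_timestamp_ms)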
  def classify_async(self,
                     image: image_module.Image,
                     timestamp_ms: int,
                     roi: Optional[_NormalizedRect] = None) -> None:
    """Sends live image data (an Image with a unique timestamp) to perform image classification.

    Only use this method when the ImageClassifier is created with the live
    stream running mode. The input timestamps should be monotonically
    increasing for adjacent calls of this method. This method will return
    immediately after the input image is accepted. The results will be
    available via the `result_callback` provided in the
    `ImageClassifierOptions`. The `classify_async` method is designed to
    process live stream data such as camera input. To lower the overall
    latency, the image classifier may drop input images if needed. In other
    words, output is not guaranteed for every input image.

    The `result_callback` provides:
      - A classification result object that contains a list of
        classifications.
      - The input image that the image classifier runs on.
      - The input timestamp in milliseconds.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
      roi: The region of interest.

    Raises:
      ValueError: If the current input timestamp is smaller than what the
        image classifier has already processed.
    """
    norm_rect = roi if roi is not None else _build_full_image_norm_rect()
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_NAME:
            packet_creator.create_proto(norm_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })
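A sketch of the live-stream contract, assuming a classifier created with `_RunningMode.LIVE_STREAM` and a `result_callback` (as in the options sketch above); `get_camera_frame` and `streaming` are hypothetical:

# Sketch only: timestamps from a monotonic clock so they strictly increase.
import time

start = time.monotonic()
while streaming:
  mp_frame = get_camera_frame()
  live_classifier.classify_async(mp_frame,
                                 int((time.monotonic() - start) * 1000))
  # Results arrive asynchronously via options.result_callback; frames may be
  # dropped under load, so there is no guaranteed 1:1 input/output mapping.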
mediapipe/tasks/testdata/vision/BUILD (vendored, 4 changes)
@@ -28,6 +28,8 @@ mediapipe_files(srcs = [
     "burger_rotated.jpg",
     "cat.jpg",
     "cat_mask.jpg",
+    "cat_rotated.jpg",
+    "cat_rotated_mask.jpg",
     "cats_and_dogs.jpg",
     "cats_and_dogs_no_resizing.jpg",
     "cats_and_dogs_rotated.jpg",
@@ -84,6 +86,8 @@ filegroup(
         "burger_rotated.jpg",
         "cat.jpg",
         "cat_mask.jpg",
+        "cat_rotated.jpg",
+        "cat_rotated_mask.jpg",
         "cats_and_dogs.jpg",
         "cats_and_dogs_no_resizing.jpg",
         "cats_and_dogs_rotated.jpg",
@@ -82,7 +82,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) {
   // If that fails, assume it was a relative path, and try just the base name.
   {
     const size_t last_slash_idx = path.find_last_of("\\/");
-    CHECK_NE(last_slash_idx, std::string::npos);  // Make sure it's a path.
+    RET_CHECK(last_slash_idx != std::string::npos)
+        << path << " doesn't have a slash in it";  // Make sure it's a path.
     auto base_name = path.substr(last_slash_idx + 1);
     auto status_or_path = PathToResourceAsFileInternal(base_name);
     if (status_or_path.ok()) {
@@ -71,7 +71,8 @@ absl::StatusOr<std::string> PathToResourceAsFile(const std::string& path) {
   // If that fails, assume it was a relative path, and try just the base name.
   {
     const size_t last_slash_idx = path.find_last_of("\\/");
-    CHECK_NE(last_slash_idx, std::string::npos);  // Make sure it's a path.
+    RET_CHECK(last_slash_idx != std::string::npos)
+        << path << " doesn't have a slash in it";  // Make sure it's a path.
     auto base_name = path.substr(last_slash_idx + 1);
     auto status_or_path = PathToResourceAsFileInternal(base_name);
     if (status_or_path.ok()) {
third_party/external_files.bzl (vendored, 40 changes)
@@ -76,6 +76,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/cat_mask.jpg?generation=1661875677203533"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_cat_rotated_jpg",
+        sha256 = "b78cee5ad14c9f36b1c25d103db371d81ca74d99030063c46a38e80bb8f38649",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated.jpg?generation=1666304165042123"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_cat_rotated_mask_jpg",
+        sha256 = "f336973e7621d602f2ebc9a6ab1c62d8502272d391713f369d3b99541afda861",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/cat_rotated_mask.jpg?generation=1666304167148173"],
+    )
+
     http_file(
         name = "com_google_mediapipe_cats_and_dogs_jpg",
         sha256 = "a2eaa7ad3a1aae4e623dd362a5f737e8a88d122597ecd1a02b3e1444db56df9c",

@@ -162,8 +174,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
-        sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"],
+        sha256 = "c4dfdcc2e4cd366eb5f8ad227be94049eb593e3a528564611094687912463687",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666629474155924"],
     )
 
     http_file(

@@ -174,8 +186,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
-        sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"],
+        sha256 = "7fb2d33cf69d2da50952a45bad0c0618f30859e608958fee95948a6e0de63ccb",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666629476401757"],
     )
 
     http_file(

@@ -258,8 +270,8 @@ def external_files():
 
     http_file(
        name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
-        sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"],
+        sha256 = "555079c274ea91699757a0b9888c9993a8ab450069103b1bcd4ebb805a8e023c",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666629478777955"],
     )
 
     http_file(

@@ -456,14 +468,14 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_mobilenet_v2_1_0_224_json",
-        sha256 = "0eb285a857b4bb1815736d0902ace0af45ea62e90c1dac98844b9ca797cd0d7b",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1665988398778178"],
+        sha256 = "94613ea9539a20a3352604004be6d4d64d4d76250bc9042fcd8685c9a8498517",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224.json?generation=1666633416316646"],
     )
 
     http_file(
         name = "com_google_mediapipe_mobilenet_v2_1_0_224_quant_json",
-        sha256 = "932f345ebe3d98daf0dc4c88b0f9e694e450390fb394fc217e851338dfec43e6",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1665988401522527"],
+        sha256 = "3703eadcf838b65bbc2b2aa11dbb1f1bc654c7a09a7aba5ca75a26096484a8ac",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/mobilenet_v2_1.0_224_quant.json?generation=1666633418665507"],
     )
 
     http_file(

@@ -606,8 +618,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
-        sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"],
+        sha256 = "5ec37218d8b613436f5c10121dc689bf9ee69af0656a6ccf8c2e3e8b652e2ad6",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666629486774022"],
     )
 
     http_file(

@@ -798,8 +810,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
-        sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"],
+        sha256 = "6645bbd98ea7f90b3e1ba297e16ea5280847fc5bf5400726d98c282f6c597257",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666629489421733"],
     )
 
     http_file(