Add implementation and tests for Image Classifier C API

PiperOrigin-RevId: 574679661
This commit is contained in:
MediaPipe Team 2023-10-18 18:54:41 -07:00 committed by Copybara-Service
parent 364048daca
commit 259fa86c62
8 changed files with 602 additions and 0 deletions

View File

@ -0,0 +1,70 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# C API wrapper library around the C++ ImageClassifier task.
cc_library(
name = "image_classifier_lib",
srcs = ["image_classifier.cc"],
hdrs = ["image_classifier.h"],
# Publicly visible, overriding the package-level default visibility.
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
# C-side containers/options and their converters to the C++ equivalents.
"//mediapipe/tasks/c/components/containers:classification_result",
"//mediapipe/tasks/c/components/containers:classification_result_converter",
"//mediapipe/tasks/c/components/processors:classifier_options",
"//mediapipe/tasks/c/components/processors:classifier_options_converter",
"//mediapipe/tasks/c/core:base_options",
"//mediapipe/tasks/c/core:base_options_converter",
# The C++ task being wrapped, plus the buffer-to-Image helper.
"//mediapipe/tasks/cc/vision/image_classifier",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
],
# NOTE(review): presumably set so the exported C symbols are retained even
# when nothing references them at link time - confirm.
alwayslink = 1,
)
# End-to-end tests for the image classifier C API.
cc_test(
name = "image_classifier_test",
srcs = ["image_classifier_test.cc"],
data = [
# NOTE(review): the first three entries are library targets that also
# appear in `deps`; listing them as runtime data looks unintentional -
# confirm.
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
# Test images and models loaded from disk by the tests.
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
],
linkstatic = 1,
deps = [
":image_classifier_lib",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/tasks/c/components/containers:category",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -0,0 +1,147 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"

#include <cstring>
#include <memory>
#include <utility>

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
#include "mediapipe/tasks/c/components/processors/classifier_options_converter.h"
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace mediapipe::tasks::c::vision::image_classifier {
namespace {
using ::mediapipe::tasks::c::components::containers::
CppCloseClassificationResult;
using ::mediapipe::tasks::c::components::containers::
CppConvertToClassificationResult;
using ::mediapipe::tasks::c::components::processors::
CppConvertToClassifierOptions;
using ::mediapipe::tasks::c::core::CppConvertToBaseOptions;
using ::mediapipe::tasks::vision::CreateImageFromBuffer;
using ::mediapipe::tasks::vision::image_classifier::ImageClassifier;
int CppProcessError(absl::Status status, char** error_msg) {
if (error_msg) {
*error_msg = strdup(status.ToString().c_str());
}
return status.raw_code();
}
} // namespace
// Builds a C++ ImageClassifier from the C `options` struct.
//
// Only `base_options` and `classifier_options` are converted and forwarded to
// the C++ options object.
//
// Returns an owning raw pointer on success. On failure the error is logged,
// reported through `error_msg` (when non-null) and `nullptr` is returned.
ImageClassifier* CppImageClassifierCreate(const ImageClassifierOptions& options,
                                          char** error_msg) {
  using CppImageClassifierOptions =
      ::mediapipe::tasks::vision::image_classifier::ImageClassifierOptions;

  auto converted_options = std::make_unique<CppImageClassifierOptions>();
  CppConvertToBaseOptions(options.base_options,
                          &converted_options->base_options);
  CppConvertToClassifierOptions(options.classifier_options,
                                &converted_options->classifier_options);

  auto created = ImageClassifier::Create(std::move(converted_options));
  if (created.ok()) {
    // Hand ownership to the C caller; it is reclaimed on close.
    return created->release();
  }

  ABSL_LOG(ERROR) << "Failed to create ImageClassifier: " << created.status();
  CppProcessError(created.status(), error_msg);
  return nullptr;
}
int CppImageClassifierClassify(void* classifier, const MpImage* image,
ImageClassifierResult* result,
char** error_msg) {
if (image->type == MpImage::GPU_BUFFER) {
absl::Status status =
absl::InvalidArgumentError("gpu buffer not supported yet");
ABSL_LOG(ERROR) << "Classification failed: " << status.message();
return CppProcessError(status, error_msg);
}
const auto img = CreateImageFromBuffer(
static_cast<ImageFormat::Format>(image->image_frame.format),
image->image_frame.image_buffer, image->image_frame.width,
image->image_frame.height);
if (!img.ok()) {
ABSL_LOG(ERROR) << "Failed to create Image: " << img.status();
return CppProcessError(img.status(), error_msg);
}
auto cpp_classifier = static_cast<ImageClassifier*>(classifier);
auto cpp_result = cpp_classifier->Classify(*img);
if (!cpp_result.ok()) {
ABSL_LOG(ERROR) << "Classification failed: " << cpp_result.status();
return CppProcessError(cpp_result.status(), error_msg);
}
CppConvertToClassificationResult(*cpp_result, result);
return 0;
}
// Releases the resources held inside `result` by delegating to the
// classification-result converter's close routine.
void CppImageClassifierCloseResult(ImageClassifierResult* result) {
  ImageClassifierResult* const to_close = result;
  CppCloseClassificationResult(to_close);
}
int CppImageClassifierClose(void* classifier, char** error_msg) {
auto cpp_classifier = static_cast<ImageClassifier*>(classifier);
auto result = cpp_classifier->Close();
if (!result.ok()) {
ABSL_LOG(ERROR) << "Failed to close ImageClassifier: " << result;
return CppProcessError(result, error_msg);
}
delete cpp_classifier;
return 0;
}
} // namespace mediapipe::tasks::c::vision::image_classifier
extern "C" {
void* image_classifier_create(struct ImageClassifierOptions* options,
char** error_msg) {
return mediapipe::tasks::c::vision::image_classifier::
CppImageClassifierCreate(*options, error_msg);
}
int image_classifier_classify_image(void* classifier, const MpImage* image,
ImageClassifierResult* result,
char** error_msg) {
return mediapipe::tasks::c::vision::image_classifier::
CppImageClassifierClassify(classifier, image, result, error_msg);
}
// C entry point: forwards to the C++ routine that releases the contents of
// `result`.
void image_classifier_close_result(ImageClassifierResult* result) {
  namespace impl = mediapipe::tasks::c::vision::image_classifier;
  impl::CppImageClassifierCloseResult(result);
}
// C entry point: forwards to the C++ close implementation and returns its
// error code (0 on success).
//
// The second parameter is spelled `error_msg` to match the declaration in the
// header (it was previously misspelled `error_ms`).
int image_classifier_close(void* classifier, char** error_msg) {
  return mediapipe::tasks::c::vision::image_classifier::CppImageClassifierClose(
      classifier, error_msg);
}
} // extern "C"

View File

@ -0,0 +1,132 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_
#include <cstdint>
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
#ifndef MP_EXPORT
#define MP_EXPORT __attribute__((visibility("default")))
#endif // MP_EXPORT
#ifdef __cplusplus
extern "C" {
#endif
// The image classifier reports its output using the shared
// ClassificationResult container.
typedef ClassificationResult ImageClassifierResult;
// Supported image formats.
// NOTE(review): the implementation casts these values directly to the C++
// image format enum, so the numeric values presumably must stay in sync with
// it - confirm before renumbering.
enum ImageFormat {
UNKNOWN = 0,
SRGB = 1,
SRGBA = 2,
GRAY8 = 3,
SBGRA = 11 // compatible with Flutter `bgra8888` format.
};
// Supported processing modes. Values start at 1, so a zero-initialized
// RunningMode does not correspond to any of the enumerators.
enum RunningMode {
IMAGE = 1,
VIDEO = 2,
LIVE_STREAM = 3,
};
// Structure describing an image frame held in caller-provided memory.
struct ImageFrame {
// Pixel format of `image_buffer`.
enum ImageFormat format;
// Pointer to the pixel data.
// NOTE(review): the row stride is derived from `width` and the format's
// channel count, so rows are presumably expected to be tightly packed -
// confirm.
const uint8_t* image_buffer;
// Image dimensions.
int width;
int height;
};
// Placeholder for GPU-backed images; currently has no members.
// TODO: Add GPU buffer declaration and processing logic for it.
struct GpuBuffer {};
// The object to contain an image, realizes `OneOf` concept: `type` tells which
// member of the union holds the image.
struct MpImage {
// Discriminator for the union below.
enum { IMAGE_FRAME, GPU_BUFFER } type;
union {
ImageFrame image_frame;
GpuBuffer gpu_buffer;
};
};
// The options for configuring a MediaPipe image classifier task.
struct ImageClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
struct BaseOptions base_options;
// The running mode of the task. Default to the image mode.
// Image classifier has three running modes:
// 1) The image mode for classifying image on single image inputs.
// 2) The video mode for classifying image on the decoded frames of a video.
// 3) The live stream mode for classifying image on the live stream of input
// data, such as from camera. In this mode, the "result_callback" below must
// be specified to receive the classification results asynchronously.
// NOTE(review): the create implementation visible alongside this header only
// converts `base_options` and `classifier_options`; `running_mode` and
// `result_callback` do not appear to be forwarded yet - confirm.
RunningMode running_mode;
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
struct ClassifierOptions classifier_options;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
typedef void (*result_callback_fn)(ImageClassifierResult*, const MpImage*,
int64_t);
result_callback_fn result_callback;
};
// Creates an ImageClassifier from the provided `options`.
// Returns a pointer to the image classifier on success.
// If an error occurs, returns `nullptr` and sets the error parameter to an
// error message (if `error_msg` is not nullptr). The message is heap-allocated
// and you must release it with `free()`.
MP_EXPORT void* image_classifier_create(struct ImageClassifierOptions* options,
char** error_msg = nullptr);
// Performs image classification on the input `image`. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an
// error message (if `error_msg` is not nullptr). The message is heap-allocated
// and you must release it with `free()`.
//
// On success, release the contents of `result` with
// `image_classifier_close_result`.
//
// TODO: Add API for video and live stream processing.
MP_EXPORT int image_classifier_classify_image(void* classifier,
const MpImage* image,
ImageClassifierResult* result,
char** error_msg = nullptr);
// Frees the memory allocated inside an ImageClassifierResult result.
// Does not free the result pointer itself.
MP_EXPORT void image_classifier_close_result(ImageClassifierResult* result);
// Closes and frees the image classifier. Returns `0` on success.
// If an error occurs, returns an error code and sets the error parameter to an
// error message (if `error_msg` is not nullptr). The message is heap-allocated
// and you must release it with `free()`.
MP_EXPORT int image_classifier_close(void* classifier,
char** error_msg = nullptr);
#ifdef __cplusplus
} // extern C
#endif
#endif // MEDIAPIPE_TASKS_C_VISION_IMAGE_CLASSIFIER_IMAGE_CLASSIFIER_H_

View File

@ -0,0 +1,132 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/vision/image_classifier/image_classifier.h"
#include <cstdlib>
#include <string>
#include "absl/flags/flag.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::vision::DecodeImageFromFile;
using testing::HasSubstr;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kModelName[] = "mobilenet_v2_1.0_224.tflite";
constexpr float kPrecision = 1e-4;
// Resolves `file_name` against the vision test data directory, relative to the
// current working directory.
std::string GetFullPath(absl::string_view file_name) {
  std::string resolved = JoinPath("./", kTestDataDirectory, file_name);
  return resolved;
}
// Classifies a burger image end to end through the C API and checks the top
// category.
TEST(ImageClassifierTest, SmokeTest) {
  const auto image = DecodeImageFromFile(GetFullPath("burger.jpg"));
  ASSERT_TRUE(image.ok());

  const std::string model_path = GetFullPath(kModelName);
  ImageClassifierOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::IMAGE,
      /* classifier_options= */
      {/* display_names_locale= */ nullptr,
       /* max_results= */ -1,
       /* score_threshold= */ 0.0,
       /* category_allowlist= */ nullptr,
       /* category_allowlist_count= */ 0,
       /* category_denylist= */ nullptr,
       /* category_denylist_count= */ 0},
  };

  void* classifier = image_classifier_create(&options);
  // Fatal: every call below dereferences the handle.
  ASSERT_NE(classifier, nullptr);

  const auto frame = image->GetImageFrameSharedPtr();
  const MpImage mp_image = {
      .type = MpImage::IMAGE_FRAME,
      .image_frame = {.format = static_cast<ImageFormat>(frame->Format()),
                      .image_buffer = frame->PixelData(),
                      .width = frame->Width(),
                      .height = frame->Height()}};

  ImageClassifierResult result;
  // `result` is only populated on success, so a failure must stop the test
  // before its fields are read.
  ASSERT_EQ(image_classifier_classify_image(classifier, &mp_image, &result),
            0);
  // Fatal as well: the arrays are indexed right after the count checks.
  ASSERT_EQ(result.classifications_count, 1);
  ASSERT_EQ(result.classifications[0].categories_count, 1001);
  EXPECT_EQ(std::string{result.classifications[0].categories[0].category_name},
            "cheeseburger");
  EXPECT_NEAR(result.classifications[0].categories[0].score, 0.7939f,
              kPrecision);

  image_classifier_close_result(&result);
  EXPECT_EQ(image_classifier_close(classifier), 0);
}
// Creation must fail, and report a message, when no model is specified.
TEST(ImageClassifierTest, InvalidArgumentHandling) {
  // It is an error to set neither the asset buffer nor the path.
  ImageClassifierOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ nullptr},
      // The second member of the struct is `running_mode`; it is now given an
      // explicit value so that the `classifier_options` label below is
      // attached to the right initializer.
      /* running_mode= */ RunningMode::IMAGE,
      /* classifier_options= */ {},
  };

  // Initialized so that an unexpected success cannot leave a garbage pointer
  // to be read and freed.
  char* error_msg = nullptr;
  void* classifier = image_classifier_create(&options, &error_msg);
  EXPECT_EQ(classifier, nullptr);
  ASSERT_NE(error_msg, nullptr);
  EXPECT_THAT(error_msg, HasSubstr("ExternalFile must specify"));
  free(error_msg);
}
// Classifying a GPU-buffer image must fail with a descriptive error message.
TEST(ImageClassifierTest, FailedClassificationHandling) {
  const std::string model_path = GetFullPath(kModelName);
  ImageClassifierOptions options = {
      /* base_options= */ {/* model_asset_buffer= */ nullptr,
                           /* model_asset_path= */ model_path.c_str()},
      /* running_mode= */ RunningMode::IMAGE,
      /* classifier_options= */
      {/* display_names_locale= */ nullptr,
       /* max_results= */ -1,
       /* score_threshold= */ 0.0,
       /* category_allowlist= */ nullptr,
       /* category_allowlist_count= */ 0,
       /* category_denylist= */ nullptr,
       /* category_denylist_count= */ 0},
  };

  void* classifier = image_classifier_create(&options);
  // Fatal: the handle is dereferenced by the calls below.
  ASSERT_NE(classifier, nullptr);

  const MpImage mp_image = {.type = MpImage::GPU_BUFFER, .gpu_buffer = {}};
  ImageClassifierResult result;
  // Initialized so that a missing message is detected instead of reading and
  // freeing an indeterminate pointer.
  char* error_msg = nullptr;
  const int error_code = image_classifier_classify_image(
      classifier, &mp_image, &result, &error_msg);
  EXPECT_NE(error_code, 0);
  EXPECT_NE(error_msg, nullptr);
  if (error_msg != nullptr) {
    EXPECT_THAT(error_msg, HasSubstr("gpu buffer not supported yet"));
    free(error_msg);
  }

  // Non-fatal checks above guarantee the classifier is always released.
  image_classifier_close(classifier);
}
} // namespace

View File

@ -86,6 +86,20 @@ cc_library(
],
)
# Unit tests for the helpers declared in image_utils.h.
cc_test(
name = "image_utils_test",
srcs = ["image_utils_test.cc"],
deps = [
":image_utils",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
# NOTE(review): two gtest_main targets are listed (this one and the
# googletest one below); presumably only one is needed - confirm.
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "landmarks_duplicates_finder",
hdrs = ["landmarks_duplicates_finder.h"],

View File

@ -14,14 +14,17 @@ limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/tensor.h"
#include "stb_image.h"
namespace mediapipe {
@ -64,6 +67,34 @@ absl::StatusOr<Image> DecodeImageFromFile(const std::string& path) {
return Image(std::move(image_frame));
}
// Wraps a caller-provided pixel buffer in an Image.
//
// `format` must be GRAY8, SRGB, SRGBA or SBGRA. `pixel_data` must be non-null
// and `width` / `height` must be positive. The row stride is computed as
// channel count * `width`, i.e. rows are treated as tightly packed.
//
// The ImageFrame is created with PixelDataDeleter::kNone, so the buffer is not
// released by the returned Image; the caller keeps ownership.
//
// Returns an InvalidArgument status for an unsupported format, a null buffer
// or non-positive dimensions.
absl::StatusOr<Image> CreateImageFromBuffer(ImageFormat::Format format,
                                            const uint8_t* pixel_data,
                                            int width, int height) {
  int channels = 0;
  switch (format) {
    case ImageFormat::GRAY8:
      channels = 1;
      break;
    case ImageFormat::SRGB:
      channels = 3;
      break;
    case ImageFormat::SRGBA:
    case ImageFormat::SBGRA:
      channels = 4;
      break;
    default:
      // The message text is kept unchanged because callers and tests match on
      // it.
      return absl::InvalidArgumentError(absl::StrFormat(
          "Expected image of SRGB, SRGBA or SBGRA format, but found %d.",
          format));
  }

  // The buffer comes straight from API callers; reject inputs that cannot
  // describe a valid image instead of wrapping them.
  if (pixel_data == nullptr) {
    return absl::InvalidArgumentError("Expected non-null pixel data.");
  }
  if (width <= 0 || height <= 0) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected positive image dimensions, but found %dx%d.", width,
        height));
  }

  const int width_step = channels * width;
  // The ImageFrame constructor takes a mutable pointer, hence the const_cast.
  ImageFrameSharedPtr image_frame = std::make_shared<ImageFrame>(
      format, width, height, width_step, const_cast<uint8_t*>(pixel_data),
      ImageFrame::PixelDataDeleter::kNone);
  return Image(std::move(image_frame));
}
absl::StatusOr<Shape> GetImageLikeTensorShape(const mediapipe::Tensor& tensor) {
int width = 0;
int height = 0;

View File

@ -16,11 +16,13 @@ limitations under the License.
#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_IMAGE_UTILS_H_
#define MEDIAPIPE_TASKS_CC_VISION_UTILS_IMAGE_UTILS_H_
#include <cstdint>
#include <memory>
#include <string>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
@ -43,6 +45,14 @@ struct Shape {
// outside of tests or simple CLI demo tools.
absl::StatusOr<mediapipe::Image> DecodeImageFromFile(const std::string& path);
// Creates an image from a caller-provided pixel buffer and returns it as a
// mediapipe::Image object.
//
// Supported formats are grayscale (1 channel), RGB (3 channels), RGBA (4
// channels) and BGRA (4 channels); any other format yields an error status.
// NOTE(review): the implementation passes `pixel_data` to the ImageFrame with
// a no-op deleter, so the buffer presumably must outlive the returned Image -
// confirm.
absl::StatusOr<Image> CreateImageFromBuffer(ImageFormat::Format format,
const uint8_t* pixel_data,
int width, int height);
// Get the shape of a image-like tensor.
//
// The tensor should have dimension 2, 3 or 4, representing `[height x width]`,

View File

@ -0,0 +1,66 @@
/* Copyright 2022 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include <cstdint>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe::tasks::vision::utils {
namespace {
// Test fixture parameterized over the image format handed to
// CreateImageFromBuffer.
class ImageUtilsTest : public ::testing::TestWithParam<ImageFormat::Format> {};
// An UNKNOWN format must be rejected with kInvalidArgument and the expected
// message.
TEST_F(ImageUtilsTest, FailedImageFromBuffer) {
  constexpr int kWidth = 1;
  constexpr int kHeight = 1;
  constexpr int kMaxChannels = 1;
  const std::vector<uint8_t> pixels(kWidth * kHeight * kMaxChannels, 0);

  const absl::StatusOr<Image> created = CreateImageFromBuffer(
      ImageFormat::UNKNOWN, pixels.data(), kWidth, kHeight);

  const absl::Status& failure = created.status();
  EXPECT_EQ(failure.code(), absl::StatusCode::kInvalidArgument);
  EXPECT_EQ(failure.message(),
            "Expected image of SRGB, SRGBA or SBGRA format, but found 0.");
}
// Every supported format must produce an Image that reports the requested
// format and dimensions.
TEST_P(ImageUtilsTest, SuccessfulImageFromBuffer) {
  constexpr int width = 4;
  constexpr int height = 4;
  // Sized for the widest format so one buffer serves every parameter.
  constexpr int max_channels = 4;
  const std::vector<uint8_t> buffer(width * height * max_channels, 0);
  const ImageFormat::Format format = GetParam();

  const absl::StatusOr<Image> image =
      CreateImageFromBuffer(format, buffer.data(), width, height);

  // Fatal: `image` is dereferenced below, which is invalid for an error
  // status.
  ASSERT_TRUE(image.ok());
  const auto frame = image->GetImageFrameSharedPtr();
  EXPECT_EQ(frame->Format(), format);
  EXPECT_EQ(frame->Width(), width);
  EXPECT_EQ(frame->Height(), height);
}
// Runs the parameterized tests for every format CreateImageFromBuffer
// accepts. GRAY8 is included because the implementation supports it but it
// was previously not exercised.
INSTANTIATE_TEST_SUITE_P(ImageUtilsTests, ImageUtilsTest,
                         testing::Values(ImageFormat::GRAY8, ImageFormat::SRGB,
                                         ImageFormat::SRGBA,
                                         ImageFormat::SBGRA));
} // namespace
} // namespace mediapipe::tasks::vision::utils