Pass Model Asset Buffer as byte array + length

PiperOrigin-RevId: 579944283
This commit is contained in:
Sebastian Schmidt 2023-11-06 13:39:48 -08:00 committed by Copybara-Service
parent 5f0d24d741
commit 077b52250d
8 changed files with 22 additions and 6 deletions

View File

@ -22,9 +22,12 @@ extern "C" {
// Base options for MediaPipe C Tasks.
struct BaseOptions {
// The model asset file contents as a string.
// The model asset file contents as bytes.
const char* model_asset_buffer;
// The size of the model assets buffer (or `0` if not set).
const unsigned int model_asset_buffer_count;
// The path to the model asset to open and mmap in memory.
const char* model_asset_path;
};

View File

@ -27,7 +27,9 @@ void CppConvertToBaseOptions(const BaseOptions& in,
mediapipe::tasks::core::BaseOptions* out) {
out->model_asset_buffer =
in.model_asset_buffer
? std::make_unique<std::string>(in.model_asset_buffer)
? std::make_unique<std::string>(
in.model_asset_buffer,
in.model_asset_buffer + in.model_asset_buffer_count)
: nullptr;
out->model_asset_path =
in.model_asset_path ? std::string(in.model_asset_path) : "";

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "mediapipe/tasks/c/core/base_options_converter.h"
#include <cstring>
#include <string>
#include "mediapipe/framework/port/gtest.h"
@ -28,6 +29,8 @@ constexpr char kModelAssetPath[] = "abc.tflite";
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
BaseOptions c_base_options = {/* model_asset_buffer= */ kAssetBuffer,
/* model_asset_buffer_count= */
static_cast<unsigned int>(strlen(kAssetBuffer)),
/* model_asset_path= */ nullptr};
mediapipe::tasks::core::BaseOptions cpp_base_options = {};
@ -39,6 +42,7 @@ TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetBuffer) {
TEST(BaseOptionsConverterTest, ConvertsBaseOptionsAssetPath) {
BaseOptions c_base_options = {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ kModelAssetPath};
mediapipe::tasks::core::BaseOptions cpp_base_options = {};

View File

@ -44,6 +44,7 @@ TEST(LanguageDetectorTest, SmokeTest) {
std::string model_path = GetFullPath(kTestLanguageDetectorModelPath);
LanguageDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* classifier_options= */
{/* display_names_locale= */ nullptr,
@ -71,6 +72,7 @@ TEST(LanguageDetectorTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path.
LanguageDetectorOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr},
/* classifier_options= */ {},
};

View File

@ -43,6 +43,7 @@ TEST(TextClassifierTest, SmokeTest) {
std::string model_path = GetFullPath(kTestBertModelPath);
TextClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* classifier_options= */
{/* display_names_locale= */ nullptr,
@ -74,6 +75,7 @@ TEST(TextClassifierTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path.
TextClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr},
/* classifier_options= */ {},
};

View File

@ -42,6 +42,7 @@ TEST(TextEmbedderTest, SmokeTest) {
std::string model_path = GetFullPath(kTestBertModelPath);
TextEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* embedder_options= */
{/* l2_normalize= */ false, /* quantize= */ true},
@ -63,6 +64,7 @@ TEST(TextEmbedderTest, ErrorHandling) {
// It is an error to set neither the asset buffer nor the path.
TextEmbedderOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr},
/* embedder_options= */ {},
};

View File

@ -55,11 +55,7 @@ cc_test(
":image_classifier_lib",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/tasks/c/components/containers:category",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/flags:flag",

View File

@ -50,6 +50,7 @@ TEST(ImageClassifierTest, ImageModeTest) {
const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* classifier_options= */
@ -92,6 +93,7 @@ TEST(ImageClassifierTest, VideoModeTest) {
const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::VIDEO,
/* classifier_options= */
@ -164,6 +166,7 @@ TEST(ImageClassifierTest, LiveStreamModeTest) {
ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::LIVE_STREAM,
/* classifier_options= */
@ -203,6 +206,7 @@ TEST(ImageClassifierTest, InvalidArgumentHandling) {
// It is an error to set neither the asset buffer nor the path.
ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ nullptr},
/* classifier_options= */ {},
};
@ -220,6 +224,7 @@ TEST(ImageClassifierTest, FailedClassificationHandling) {
const std::string model_path = GetFullPath(kModelName);
ImageClassifierOptions options = {
/* base_options= */ {/* model_asset_buffer= */ nullptr,
/* model_asset_buffer_count= */ 0,
/* model_asset_path= */ model_path.c_str()},
/* running_mode= */ RunningMode::IMAGE,
/* classifier_options= */