Add tests for C API containers

PiperOrigin-RevId: 569526282
This commit is contained in:
Sebastian Schmidt 2023-09-29 10:24:35 -07:00 committed by Copybara-Service
parent d4561fb5c2
commit 6915a79e28
6 changed files with 200 additions and 8 deletions

View File

@ -31,6 +31,18 @@ cc_library(
],
)
cc_test(
name = "category_converter_test",
srcs = ["category_converter_test.cc"],
deps = [
":category",
":category_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:category",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "classification_result",
hdrs = ["classification_result.h"],
@ -47,3 +59,15 @@ cc_library(
"//mediapipe/tasks/cc/components/containers:classification_result",
],
)
cc_test(
name = "classification_result_converter_test",
srcs = ["classification_result_converter_test.cc"],
deps = [
":classification_result",
":classification_result_converter",
"//mediapipe/framework/port:gtest",
"//mediapipe/tasks/cc/components/containers:classification_result",
"@com_google_googletest//:gtest_main",
],
)

View File

@ -36,11 +36,11 @@ struct Category {
// The optional ID for the category, read from the label map packed in the
// TFLite Model Metadata if present. Not necessarily human-readable.
const char* category_name;
char* category_name;
// The optional human-readable name for the category, read from the label map
// packed in the TFLite Model Metadata if present.
const char* display_name;
char* display_name;
};
#ifdef __cplusplus

View File

@ -0,0 +1,63 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/category_converter.h"
#include <cstdlib>
#include <optional>
#include <string>
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
namespace mediapipe::tasks::c::components::containers {
TEST(CategoryConverterTest, ConvertsCategoryCustomValues) {
mediapipe::tasks::components::containers::Category cpp_category = {
/* index= */ 1,
/* score= */ 0.1,
/* category_name= */ "category_name",
/* display_name= */ "display_name",
};
Category c_category;
CppConvertToCategory(cpp_category, &c_category);
EXPECT_EQ(c_category.index, 1);
EXPECT_FLOAT_EQ(c_category.score, 0.1);
EXPECT_EQ(std::string{c_category.category_name}, "category_name");
EXPECT_EQ(std::string{c_category.display_name}, "display_name");
free(c_category.category_name);
free(c_category.display_name);
}
TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
mediapipe::tasks::components::containers::Category cpp_category = {
/* index= */ 1,
/* score= */ 0.1,
/* category_name= */ std::nullopt,
/* display_name= */ std::nullopt,
};
Category c_category;
CppConvertToCategory(cpp_category, &c_category);
EXPECT_EQ(c_category.index, 1);
EXPECT_FLOAT_EQ(c_category.score, 0.1);
EXPECT_EQ(c_category.category_name, nullptr);
EXPECT_EQ(c_category.display_name, nullptr);
}
} // namespace mediapipe::tasks::c::components::containers

View File

@ -39,7 +39,7 @@ struct Classifications {
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
const char* head_name;
char* head_name;
};
// Defines classification results of a model.

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "mediapipe/tasks/c/components/containers/category.h"
#include "mediapipe/tasks/c/components/containers/category_converter.h"
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
namespace mediapipe::tasks::c::components::containers {
@ -27,12 +28,12 @@ void CppConvertToClassificationResult(
const mediapipe::tasks::components::containers::ClassificationResult& in,
ClassificationResult* out) {
out->has_timestamp_ms = in.timestamp_ms.has_value();
if (out->has_timestamp_ms) {
out->timestamp_ms = in.timestamp_ms.value();
}
out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0;
out->classifications_count = in.classifications.size();
out->classifications = new Classifications[out->classifications_count];
out->classifications = out->classifications_count
? new Classifications[out->classifications_count]
: nullptr;
for (uint32_t i = 0; i < out->classifications_count; ++i) {
auto classification_in = in.classifications[i];
@ -40,7 +41,9 @@ void CppConvertToClassificationResult(
classification_out.categories_count = classification_in.categories.size();
classification_out.categories =
new Category[classification_out.categories_count];
classification_out.categories_count
? new Category[classification_out.categories_count]
: nullptr;
for (uint32_t j = 0; j < classification_out.categories_count; ++j) {
CppConvertToCategory(classification_in.categories[j],
&(classification_out.categories[j]));

View File

@ -0,0 +1,102 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
#include <cstdlib>
#include <optional>
#include <string>
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
namespace mediapipe::tasks::c::components::containers {
TEST(ClassificationResultConverterTest,
ConvertsClassificationResulCustomCategory) {
mediapipe::tasks::components::containers::ClassificationResult
cpp_classification_result = {
/* classifications= */ {{/* categories= */ {{
/* index= */ 1,
/* score= */ 0.1,
/* category_name= */ std::nullopt,
/* display_name= */ std::nullopt,
}},
/* head_index= */ 0,
/* head_name= */ "foo"}},
/* timestamp_ms= */ 42,
};
ClassificationResult c_classification_result;
CppConvertToClassificationResult(cpp_classification_result,
&c_classification_result);
EXPECT_NE(c_classification_result.classifications, nullptr);
EXPECT_EQ(c_classification_result.classifications_count, 1);
EXPECT_NE(c_classification_result.classifications[0].categories, nullptr);
EXPECT_EQ(c_classification_result.classifications[0].categories_count, 1);
EXPECT_EQ(c_classification_result.classifications[0].head_index, 0);
EXPECT_EQ(std::string(c_classification_result.classifications[0].head_name),
"foo");
EXPECT_EQ(c_classification_result.timestamp_ms, 42);
EXPECT_EQ(c_classification_result.has_timestamp_ms, true);
free(c_classification_result.classifications[0].categories);
free(c_classification_result.classifications[0].head_name);
free(c_classification_result.classifications);
}
TEST(ClassificationResultConverterTest,
ConvertsClassificationResulEmptyCategory) {
mediapipe::tasks::components::containers::ClassificationResult
cpp_classification_result = {
/* classifications= */ {{/* categories= */ {}, /* head_index= */ 0,
/* head_name= */ std::nullopt}},
/* timestamp_ms= */ std::nullopt,
};
ClassificationResult c_classification_result;
CppConvertToClassificationResult(cpp_classification_result,
&c_classification_result);
EXPECT_NE(c_classification_result.classifications, nullptr);
EXPECT_EQ(c_classification_result.classifications_count, 1);
EXPECT_EQ(c_classification_result.classifications[0].categories, nullptr);
EXPECT_EQ(c_classification_result.classifications[0].categories_count, 0);
EXPECT_EQ(c_classification_result.classifications[0].head_index, 0);
EXPECT_EQ(c_classification_result.classifications[0].head_name, nullptr);
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
free(c_classification_result.classifications);
}
TEST(ClassificationResultConverterTest,
ConvertsClassificationResultNoCategory) {
mediapipe::tasks::components::containers::ClassificationResult
cpp_classification_result = {
/* classifications= */ {},
/* timestamp_ms= */ std::nullopt,
};
ClassificationResult c_classification_result;
CppConvertToClassificationResult(cpp_classification_result,
&c_classification_result);
EXPECT_EQ(c_classification_result.classifications, nullptr);
EXPECT_EQ(c_classification_result.classifications_count, 0);
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
}
} // namespace mediapipe::tasks::c::components::containers