Add tests for C API containers
PiperOrigin-RevId: 569526282
This commit is contained in:
parent
d4561fb5c2
commit
6915a79e28
|
@ -31,6 +31,18 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "category_converter_test",
|
||||||
|
srcs = ["category_converter_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":category",
|
||||||
|
":category_converter",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:category",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "classification_result",
|
name = "classification_result",
|
||||||
hdrs = ["classification_result.h"],
|
hdrs = ["classification_result.h"],
|
||||||
|
@ -47,3 +59,15 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/components/containers:classification_result",
|
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "classification_result_converter_test",
|
||||||
|
srcs = ["classification_result_converter_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":classification_result",
|
||||||
|
":classification_result_converter",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -36,11 +36,11 @@ struct Category {
|
||||||
|
|
||||||
// The optional ID for the category, read from the label map packed in the
|
// The optional ID for the category, read from the label map packed in the
|
||||||
// TFLite Model Metadata if present. Not necessarily human-readable.
|
// TFLite Model Metadata if present. Not necessarily human-readable.
|
||||||
const char* category_name;
|
char* category_name;
|
||||||
|
|
||||||
// The optional human-readable name for the category, read from the label map
|
// The optional human-readable name for the category, read from the label map
|
||||||
// packed in the TFLite Model Metadata if present.
|
// packed in the TFLite Model Metadata if present.
|
||||||
const char* display_name;
|
char* display_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::containers {
|
||||||
|
|
||||||
|
TEST(CategoryConverterTest, ConvertsCategoryCustomValues) {
|
||||||
|
mediapipe::tasks::components::containers::Category cpp_category = {
|
||||||
|
/* index= */ 1,
|
||||||
|
/* score= */ 0.1,
|
||||||
|
/* category_name= */ "category_name",
|
||||||
|
/* display_name= */ "display_name",
|
||||||
|
};
|
||||||
|
|
||||||
|
Category c_category;
|
||||||
|
CppConvertToCategory(cpp_category, &c_category);
|
||||||
|
EXPECT_EQ(c_category.index, 1);
|
||||||
|
EXPECT_FLOAT_EQ(c_category.score, 0.1);
|
||||||
|
EXPECT_EQ(std::string{c_category.category_name}, "category_name");
|
||||||
|
EXPECT_EQ(std::string{c_category.display_name}, "display_name");
|
||||||
|
|
||||||
|
free(c_category.category_name);
|
||||||
|
free(c_category.display_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CategoryConverterTest, ConvertsCategoryDefaultValues) {
|
||||||
|
mediapipe::tasks::components::containers::Category cpp_category = {
|
||||||
|
/* index= */ 1,
|
||||||
|
/* score= */ 0.1,
|
||||||
|
/* category_name= */ std::nullopt,
|
||||||
|
/* display_name= */ std::nullopt,
|
||||||
|
};
|
||||||
|
|
||||||
|
Category c_category;
|
||||||
|
CppConvertToCategory(cpp_category, &c_category);
|
||||||
|
EXPECT_EQ(c_category.index, 1);
|
||||||
|
EXPECT_FLOAT_EQ(c_category.score, 0.1);
|
||||||
|
EXPECT_EQ(c_category.category_name, nullptr);
|
||||||
|
EXPECT_EQ(c_category.display_name, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::containers
|
|
@ -39,7 +39,7 @@ struct Classifications {
|
||||||
// Metadata [1] if present. This is useful for multi-head models.
|
// Metadata [1] if present. This is useful for multi-head models.
|
||||||
//
|
//
|
||||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||||
const char* head_name;
|
char* head_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Defines classification results of a model.
|
// Defines classification results of a model.
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/tasks/c/components/containers/category.h"
|
#include "mediapipe/tasks/c/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
#include "mediapipe/tasks/c/components/containers/category_converter.h"
|
||||||
|
#include "mediapipe/tasks/c/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::c::components::containers {
|
namespace mediapipe::tasks::c::components::containers {
|
||||||
|
@ -27,12 +28,12 @@ void CppConvertToClassificationResult(
|
||||||
const mediapipe::tasks::components::containers::ClassificationResult& in,
|
const mediapipe::tasks::components::containers::ClassificationResult& in,
|
||||||
ClassificationResult* out) {
|
ClassificationResult* out) {
|
||||||
out->has_timestamp_ms = in.timestamp_ms.has_value();
|
out->has_timestamp_ms = in.timestamp_ms.has_value();
|
||||||
if (out->has_timestamp_ms) {
|
out->timestamp_ms = out->has_timestamp_ms ? in.timestamp_ms.value() : 0;
|
||||||
out->timestamp_ms = in.timestamp_ms.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
out->classifications_count = in.classifications.size();
|
out->classifications_count = in.classifications.size();
|
||||||
out->classifications = new Classifications[out->classifications_count];
|
out->classifications = out->classifications_count
|
||||||
|
? new Classifications[out->classifications_count]
|
||||||
|
: nullptr;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < out->classifications_count; ++i) {
|
for (uint32_t i = 0; i < out->classifications_count; ++i) {
|
||||||
auto classification_in = in.classifications[i];
|
auto classification_in = in.classifications[i];
|
||||||
|
@ -40,7 +41,9 @@ void CppConvertToClassificationResult(
|
||||||
|
|
||||||
classification_out.categories_count = classification_in.categories.size();
|
classification_out.categories_count = classification_in.categories.size();
|
||||||
classification_out.categories =
|
classification_out.categories =
|
||||||
new Category[classification_out.categories_count];
|
classification_out.categories_count
|
||||||
|
? new Category[classification_out.categories_count]
|
||||||
|
: nullptr;
|
||||||
for (uint32_t j = 0; j < classification_out.categories_count; ++j) {
|
for (uint32_t j = 0; j < classification_out.categories_count; ++j) {
|
||||||
CppConvertToCategory(classification_in.categories[j],
|
CppConvertToCategory(classification_in.categories[j],
|
||||||
&(classification_out.categories[j]));
|
&(classification_out.categories[j]));
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/c/components/containers/classification_result_converter.h"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/c/components/containers/classification_result.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::c::components::containers {
|
||||||
|
|
||||||
|
TEST(ClassificationResultConverterTest,
|
||||||
|
ConvertsClassificationResulCustomCategory) {
|
||||||
|
mediapipe::tasks::components::containers::ClassificationResult
|
||||||
|
cpp_classification_result = {
|
||||||
|
/* classifications= */ {{/* categories= */ {{
|
||||||
|
/* index= */ 1,
|
||||||
|
/* score= */ 0.1,
|
||||||
|
/* category_name= */ std::nullopt,
|
||||||
|
/* display_name= */ std::nullopt,
|
||||||
|
}},
|
||||||
|
/* head_index= */ 0,
|
||||||
|
/* head_name= */ "foo"}},
|
||||||
|
/* timestamp_ms= */ 42,
|
||||||
|
};
|
||||||
|
|
||||||
|
ClassificationResult c_classification_result;
|
||||||
|
CppConvertToClassificationResult(cpp_classification_result,
|
||||||
|
&c_classification_result);
|
||||||
|
EXPECT_NE(c_classification_result.classifications, nullptr);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications_count, 1);
|
||||||
|
EXPECT_NE(c_classification_result.classifications[0].categories, nullptr);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications[0].categories_count, 1);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications[0].head_index, 0);
|
||||||
|
EXPECT_EQ(std::string(c_classification_result.classifications[0].head_name),
|
||||||
|
"foo");
|
||||||
|
EXPECT_EQ(c_classification_result.timestamp_ms, 42);
|
||||||
|
EXPECT_EQ(c_classification_result.has_timestamp_ms, true);
|
||||||
|
|
||||||
|
free(c_classification_result.classifications[0].categories);
|
||||||
|
free(c_classification_result.classifications[0].head_name);
|
||||||
|
free(c_classification_result.classifications);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ClassificationResultConverterTest,
|
||||||
|
ConvertsClassificationResulEmptyCategory) {
|
||||||
|
mediapipe::tasks::components::containers::ClassificationResult
|
||||||
|
cpp_classification_result = {
|
||||||
|
/* classifications= */ {{/* categories= */ {}, /* head_index= */ 0,
|
||||||
|
/* head_name= */ std::nullopt}},
|
||||||
|
/* timestamp_ms= */ std::nullopt,
|
||||||
|
};
|
||||||
|
|
||||||
|
ClassificationResult c_classification_result;
|
||||||
|
CppConvertToClassificationResult(cpp_classification_result,
|
||||||
|
&c_classification_result);
|
||||||
|
EXPECT_NE(c_classification_result.classifications, nullptr);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications_count, 1);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications[0].categories, nullptr);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications[0].categories_count, 0);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications[0].head_index, 0);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications[0].head_name, nullptr);
|
||||||
|
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
||||||
|
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
||||||
|
|
||||||
|
free(c_classification_result.classifications);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ClassificationResultConverterTest,
|
||||||
|
ConvertsClassificationResultNoCategory) {
|
||||||
|
mediapipe::tasks::components::containers::ClassificationResult
|
||||||
|
cpp_classification_result = {
|
||||||
|
/* classifications= */ {},
|
||||||
|
/* timestamp_ms= */ std::nullopt,
|
||||||
|
};
|
||||||
|
|
||||||
|
ClassificationResult c_classification_result;
|
||||||
|
CppConvertToClassificationResult(cpp_classification_result,
|
||||||
|
&c_classification_result);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications, nullptr);
|
||||||
|
EXPECT_EQ(c_classification_result.classifications_count, 0);
|
||||||
|
EXPECT_EQ(c_classification_result.timestamp_ms, 0);
|
||||||
|
EXPECT_EQ(c_classification_result.has_timestamp_ms, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::c::components::containers
|
Loading…
Reference in New Issue
Block a user