Add classification result structs and use them in ImageClassifier & TextClassifier.

PiperOrigin-RevId: 485567133
MediaPipe Team 2022-11-02 05:11:16 -07:00 committed by Copybara-Service
parent aab5f84aae
commit d61ab92b90
15 changed files with 559 additions and 276 deletions

View File

@ -29,3 +29,23 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
], ],
) )
cc_library(
name = "category",
srcs = ["category.cc"],
hdrs = ["category.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
],
)
cc_library(
name = "classification_result",
srcs = ["classification_result.cc"],
hdrs = ["classification_result.h"],
deps = [
":category",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/category.h"
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
Category ConvertToCategory(const mediapipe::Classification& proto) {
Category category;
category.index = proto.index();
category.score = proto.score();
if (proto.has_label()) {
category.category_name = proto.label();
}
if (proto.has_display_name()) {
category.display_name = proto.display_name();
}
return category;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
// The index of the category in the classification model output.
int index;
// The score for this category, e.g. (but not necessarily) a probability in
// [0,1].
float score;
// The optional ID for the category, read from the label map packed in the
// TFLite Model Metadata if present. Not necessarily human-readable.
std::optional<std::string> category_name = std::nullopt;
// The optional human-readable name for the category, read from the label map
// packed in the TFLite Model Metadata if present.
std::optional<std::string> display_name = std::nullopt;
};
// Utility function to convert from mediapipe::Classification proto to Category
// struct.
Category ConvertToCategory(const mediapipe::Classification& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
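For reference, a minimal usage sketch of the new ConvertToCategory() utility; the proto setters are the standard protobuf-generated accessors for the Classification fields read in category.cc, and the values are illustrative only.
#include <iostream>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
// Illustrative only: converts a manually populated Classification proto.
void PrintExampleCategory() {
  mediapipe::Classification proto;
  proto.set_index(934);
  proto.set_score(0.79f);
  proto.set_label("cheeseburger");  // Becomes the optional category_name.
  // display_name is left unset, so the struct field stays std::nullopt.
  mediapipe::tasks::components::containers::Category category =
      mediapipe::tasks::components::containers::ConvertToCategory(proto);
  std::cout << category.index << " "
            << category.category_name.value_or("<no name>") << " "
            << category.display_name.value_or("<no display name>") << std::endl;
}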

View File

@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
Classifications ConvertToClassifications(const proto::Classifications& proto) {
Classifications classifications;
classifications.categories.reserve(
proto.classification_list().classification_size());
for (const auto& classification :
proto.classification_list().classification()) {
classifications.categories.push_back(ConvertToCategory(classification));
}
classifications.head_index = proto.head_index();
if (proto.has_head_name()) {
classifications.head_name = proto.head_name();
}
return classifications;
}
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto) {
ClassificationResult classification_result;
classification_result.classifications.reserve(proto.classifications_size());
for (const auto& classifications : proto.classifications()) {
classification_result.classifications.push_back(
ConvertToClassifications(classifications));
}
if (proto.has_timestamp_ms()) {
classification_result.timestamp_ms = proto.timestamp_ms();
}
return classification_result;
}
} // namespace mediapipe::tasks::components::containers
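As an illustration, a sketch of the proto-to-struct conversion implemented above; the proto mutators are assumed from standard protobuf codegen for the fields accessed in this file, and the values are placeholders.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace containers = mediapipe::tasks::components::containers;
// Illustrative only: builds a one-head, one-category proto and converts it.
containers::ClassificationResult MakeExampleResult() {
  containers::proto::ClassificationResult proto;
  proto.set_timestamp_ms(42);
  containers::proto::Classifications* classifications =
      proto.add_classifications();
  classifications->set_head_index(0);
  classifications->set_head_name("probability");
  mediapipe::Classification* classification =
      classifications->mutable_classification_list()->add_classification();
  classification->set_index(934);
  classification->set_score(0.79f);
  classification->set_label("cheeseburger");
  // Each proto::Classifications entry becomes a Classifications struct holding
  // a vector of Category values; timestamp_ms maps to std::optional<int64_t>.
  return containers::ConvertToClassificationResult(proto);
}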

View File

@ -0,0 +1,68 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines classification results for a given classifier head.
struct Classifications {
// The array of predicted categories, usually sorted by descending scores,
// e.g. from high to low probability.
std::vector<Category> categories;
// The index of the classifier head (i.e. output tensor) these categories
// refer to. This is useful for multi-head models.
int head_index;
// The optional name of the classifier head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
std::optional<std::string> head_name = std::nullopt;
};
// Defines classification results of a model.
struct ClassificationResult {
// The classification results for each head of the model.
std::vector<Classifications> classifications;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
std::optional<int64_t> timestamp_ms = std::nullopt;
};
// Utility function to convert from Classifications proto to
// Classifications struct.
Classifications ConvertToClassifications(const proto::Classifications& proto);
// Utility function to convert from ClassificationResult proto to
// ClassificationResult struct.
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
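For context, a minimal sketch of how a caller might walk the struct-based result; the helper below is hypothetical and not part of this change.
#include <iostream>
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
namespace containers = mediapipe::tasks::components::containers;
// Hypothetical helper: prints every category of every classifier head.
void PrintClassificationResult(const containers::ClassificationResult& result) {
  if (result.timestamp_ms.has_value()) {
    std::cout << "timestamp_ms: " << *result.timestamp_ms << std::endl;
  }
  for (const containers::Classifications& head : result.classifications) {
    std::cout << "head " << head.head_index << " ("
              << head.head_name.value_or("<unnamed>") << ")" << std::endl;
    for (const containers::Category& category : head.categories) {
      std::cout << "  " << category.category_name.value_or("<no name>")
                << ": " << category.score << std::endl;
    }
  }
}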

View File

@ -49,6 +49,8 @@ cc_library(
":text_classifier_graph", ":text_classifier_graph",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/tasks/cc/components/containers:category",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
@ -76,7 +78,8 @@ cc_test(
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers:category",
"//mediapipe/tasks/cc/components/containers:classification_result",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h" #include "mediapipe/tasks/cc/core/task_api_factory.h"
@ -37,12 +38,13 @@ namespace text_classifier {
namespace { namespace {
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kTextStreamName[] = "text_in"; constexpr char kTextStreamName[] = "text_in";
constexpr char kTextTag[] = "TEXT"; constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultStreamName[] = "classification_result_out"; constexpr char kClassificationsStreamName[] = "classifications_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.text.text_classifier.TextClassifierGraph"; "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
@ -54,9 +56,8 @@ CalculatorGraphConfig CreateGraphConfig(
auto& subgraph = graph.AddNode(kSubgraphTypeName); auto& subgraph = graph.AddNode(kSubgraphTypeName);
subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get()); subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
subgraph.Out(kClassificationResultTag) subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
.SetName(kClassificationResultStreamName) >> graph.Out(kClassificationsTag);
graph.Out(kClassificationResultTag);
return graph.GetConfig(); return graph.GetConfig();
} }
@ -88,14 +89,14 @@ absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
std::move(options->base_options.op_resolver)); std::move(options->base_options.op_resolver));
} }
absl::StatusOr<ClassificationResult> TextClassifier::Classify( absl::StatusOr<TextClassifierResult> TextClassifier::Classify(
absl::string_view text) { absl::string_view text) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_packets, auto output_packets,
runner_->Process( runner_->Process(
{{kTextStreamName, MakePacket<std::string>(std::string(text))}})); {{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
return output_packets[kClassificationResultStreamName] return ConvertToClassificationResult(
.Get<ClassificationResult>(); output_packets[kClassificationsStreamName].Get<ClassificationResult>());
} }
} // namespace text_classifier } // namespace text_classifier

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/base_task_api.h"
@ -31,6 +31,10 @@ namespace tasks {
namespace text { namespace text {
namespace text_classifier { namespace text_classifier {
// Alias the shared ClassificationResult struct as the result type.
using TextClassifierResult =
::mediapipe::tasks::components::containers::ClassificationResult;
// The options for configuring a MediaPipe text classifier task. // The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions { struct TextClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model // Base options for configuring MediaPipe Tasks, such as specifying the model
@ -81,8 +85,7 @@ class TextClassifier : core::BaseTaskApi {
std::unique_ptr<TextClassifierOptions> options); std::unique_ptr<TextClassifierOptions> options);
// Performs classification on the input `text`. // Performs classification on the input `text`.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify( absl::StatusOr<TextClassifierResult> Classify(absl::string_view text);
absl::string_view text);
// Shuts down the TextClassifier when all the work is done. // Shuts down the TextClassifier when all the work is done.
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
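To illustrate the API change, a hedged sketch of a caller consuming the new TextClassifierResult return type, mirroring the test usage later in this commit; the header path matches the file edited here and the model path is a placeholder.
#include <memory>
#include <utility>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
namespace tc = mediapipe::tasks::text::text_classifier;
// Illustrative only: classifies one string and walks the result struct.
absl::Status RunTextClassifier() {
  auto options = std::make_unique<tc::TextClassifierOptions>();
  options->base_options.model_asset_path = "bert_text_classifier.tflite";  // Placeholder.
  absl::StatusOr<std::unique_ptr<tc::TextClassifier>> classifier =
      tc::TextClassifier::Create(std::move(options));
  if (!classifier.ok()) return classifier.status();
  // Classify() now returns the TextClassifierResult struct instead of the
  // ClassificationResult proto.
  absl::StatusOr<tc::TextClassifierResult> result =
      (*classifier)->Classify("unflinchingly bleak and desperate");
  if (!result.ok()) return result.status();
  for (const auto& head : result->classifications) {
    for (const auto& category : head.categories) {
      (void)category.score;  // e.g. inspect score and category_name here.
    }
  }
  return (*classifier)->Close();
}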

View File

@ -47,10 +47,18 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTextTag[] = "TEXT"; constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
// TODO: remove once Java API migration is over.
// Struct holding the different output streams produced by the text classifier.
struct TextClassifierOutputStreams {
Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
};
} // namespace } // namespace
// A "TextClassifierGraph" performs Natural Language classification (including // A "TextClassifierGraph" performs Natural Language classification (including
@ -62,7 +70,10 @@ constexpr char kTensorsTag[] = "TENSORS";
// Input text to perform classification on. // Input text to perform classification on.
// //
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by classifier head.
// TODO: remove once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result object that has 3 dimensions: // The aggregated classification result object that has 3 dimensions:
// (classification head, classification timestamp, classification category). // (classification head, classification timestamp, classification category).
// //
@ -70,7 +81,7 @@ constexpr char kTensorsTag[] = "TENSORS";
// node { // node {
// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph" // calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
// input_stream: "TEXT:text_in" // input_stream: "TEXT:text_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "CLASSIFICATIONS:classifications_out"
// options { // options {
// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext] // [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
// { // {
@ -91,12 +102,14 @@ class TextClassifierGraph : public core::ModelTaskGraph {
CreateModelResources<proto::TextClassifierGraphOptions>(sc)); CreateModelResources<proto::TextClassifierGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
Source<ClassificationResult> classification_result_out, auto output_streams,
BuildTextClassifierTask( BuildTextClassifierTask(
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources, sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
graph[Input<std::string>(kTextTag)], graph)); graph[Input<std::string>(kTextTag)], graph));
classification_result_out >> output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)]; graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -111,7 +124,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
// TextClassifier model file with model metadata. // TextClassifier model file with model metadata.
// text_in: (std::string) stream to run text classification on. // text_in: (std::string) stream to run text classification on.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask( absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
const proto::TextClassifierGraphOptions& task_options, const proto::TextClassifierGraphOptions& task_options,
const ModelResources& model_resources, Source<std::string> text_in, const ModelResources& model_resources, Source<std::string> text_in,
Graph& graph) { Graph& graph) {
@ -148,8 +161,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {
// Outputs the aggregated classification result as the subgraph output // Outputs the aggregated classification result as the subgraph output
// stream. // stream.
return postprocessing[Output<ClassificationResult>( return TextClassifierOutputStreams{
kClassificationResultTag)]; /*classification_result=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)],
/*classifications=*/postprocessing[Output<ClassificationResult>(
kClassificationsTag)]};
} }
}; };
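For reference, a hedged sketch of wiring a client graph to the new CLASSIFICATIONS output, mirroring the CreateGraphConfig helper updated earlier in this commit; tags and stream names are taken from the code above.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"
// Illustrative only: connects the subgraph through the new output tag.
mediapipe::CalculatorGraphConfig BuildClientGraphConfig() {
  mediapipe::api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(
      "mediapipe.tasks.text.text_classifier.TextClassifierGraph");
  graph.In("TEXT").SetName("text_in") >> subgraph.In("TEXT");
  // Prefer the CLASSIFICATIONS tag; CLASSIFICATION_RESULT remains only until
  // the Java API migration is over.
  subgraph.Out("CLASSIFICATIONS").SetName("classifications_out") >>
      graph.Out("CLASSIFICATIONS");
  return graph.GetConfig();
}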

View File

@ -33,7 +33,8 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -43,14 +44,13 @@ namespace text {
namespace text_classifier { namespace text_classifier {
namespace { namespace {
using ::mediapipe::EqualsProto;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload; using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128; constexpr int kMaxSeqLen = 128;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
@ -64,6 +64,30 @@ std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name); return JoinPath("./", kTestDataDirectory, file_name);
} }
// Checks that the two provided `TextClassifierResult` objects are equal, with a
// tolerance on floating-point scores to account for numerical instabilities.
// TODO: create shared matcher for ClassificationResult.
void ExpectApproximatelyEqual(const TextClassifierResult& actual,
const TextClassifierResult& expected) {
const float kPrecision = 1e-6;
ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
for (int i = 0; i < actual.classifications.size(); ++i) {
const Classifications& a = actual.classifications[i];
const Classifications& b = expected.classifications[i];
EXPECT_EQ(a.head_index, b.head_index);
EXPECT_EQ(a.head_name, b.head_name);
EXPECT_EQ(a.categories.size(), b.categories.size());
for (int j = 0; j < a.categories.size(); ++j) {
const Category& x = a.categories[j];
const Category& y = b.categories[j];
EXPECT_EQ(x.index, y.index);
EXPECT_NEAR(x.score, y.score, kPrecision);
EXPECT_EQ(x.category_name, y.category_name);
EXPECT_EQ(x.display_name, y.display_name);
}
}
}
class TextClassifierTest : public tflite_shims::testing::Test {}; class TextClassifierTest : public tflite_shims::testing::Test {};
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
@ -107,6 +131,92 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
MP_ASSERT_OK(TextClassifier::Create(std::move(options))); MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
} }
TEST_F(TextClassifierTest, TextClassifierWithBert) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextClassifierResult negative_result,
classifier->Classify("unflinchingly bleak and desperate"));
TextClassifierResult negative_expected;
negative_expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
{/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(negative_result, negative_expected);
MP_ASSERT_OK_AND_ASSIGN(
TextClassifierResult positive_result,
classifier->Classify("it's a charming and often affecting journey"));
TextClassifierResult positive_expected;
positive_expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
{/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(positive_result, positive_expected);
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult negative_result,
classifier->Classify("What a waste of my time."));
TextClassifierResult negative_expected;
negative_expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"},
{/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(negative_result, negative_expected);
MP_ASSERT_OK_AND_ASSIGN(
TextClassifierResult positive_result,
classifier->Classify("This is the best movie Ive seen in recent years."
"Strongly recommend it!"));
TextClassifierResult positive_expected;
positive_expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"},
{/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(positive_result, positive_expected);
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
options->base_options.op_resolver = CreateCustomResolver();
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
classifier->Classify("hello"));
// Binary outputs cause flaky ordering, so we compare manually.
ASSERT_EQ(result.classifications.size(), 1);
ASSERT_EQ(result.classifications[0].head_index, 0);
ASSERT_EQ(result.classifications[0].categories.size(), 3);
ASSERT_EQ(result.classifications[0].categories[0].score, 1);
ASSERT_LT(result.classifications[0].categories[0].index, 2); // i.e. 0 or 1.
ASSERT_EQ(result.classifications[0].categories[1].score, 1);
ASSERT_LT(result.classifications[0].categories[1].index, 2); // i.e. 0 or 1.
ASSERT_EQ(result.classifications[0].categories[2].score, 0);
ASSERT_EQ(result.classifications[0].categories[2].index, 2);
MP_ASSERT_OK(classifier->Close());
}
} // namespace } // namespace
} // namespace text_classifier } // namespace text_classifier
} // namespace text } // namespace text

View File

@ -50,6 +50,7 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@ -46,8 +47,8 @@ namespace image_classifier {
namespace { namespace {
constexpr char kClassificationResultStreamName[] = "classification_result_out"; constexpr char kClassificationsStreamName[] = "classifications_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
@ -57,6 +58,7 @@ constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
@ -73,15 +75,13 @@ CalculatorGraphConfig CreateGraphConfig(
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap( task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
options_proto.get()); options_proto.get());
task_subgraph.Out(kClassificationResultTag) task_subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
.SetName(kClassificationResultStreamName) >> graph.Out(kClassificationsTag);
graph.Out(kClassificationResultTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag); graph.Out(kImageTag);
if (enable_flow_limiting) { if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, return tasks::core::AddFlowLimiterCalculator(
{kImageTag, kNormRectTag}, graph, task_subgraph, {kImageTag, kNormRectTag}, kClassificationsTag);
kClassificationResultTag);
} }
graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
@ -125,13 +125,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return; return;
} }
Packet classification_result_packet = Packet classifications_packet =
status_or_packets.value()[kClassificationResultStreamName]; status_or_packets.value()[kClassificationsStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback( result_callback(
classification_result_packet.Get<ClassificationResult>(), ConvertToClassificationResult(
classifications_packet.Get<ClassificationResult>()),
image_packet.Get<Image>(), image_packet.Get<Image>(),
classification_result_packet.Timestamp().Value() / classifications_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond); kMicroSecondsPerMilliSecond);
}; };
} }
@ -144,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<ClassificationResult> ImageClassifier::Classify( absl::StatusOr<ImageClassifierResult> ImageClassifier::Classify(
Image image, Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -160,11 +161,11 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
ProcessImageData( ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))}, {{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}})); {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kClassificationResultStreamName] return ConvertToClassificationResult(
.Get<ClassificationResult>(); output_packets[kClassificationsStreamName].Get<ClassificationResult>());
} }
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo( absl::StatusOr<ImageClassifierResult> ImageClassifier::ClassifyForVideo(
Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -184,8 +185,8 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
{kNormRectName, {kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kClassificationResultStreamName] return ConvertToClassificationResult(
.Get<ClassificationResult>(); output_packets[kClassificationsStreamName].Get<ClassificationResult>());
} }
absl::Status ImageClassifier::ClassifyAsync( absl::Status ImageClassifier::ClassifyAsync(

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
@ -34,6 +34,10 @@ namespace tasks {
namespace vision { namespace vision {
namespace image_classifier { namespace image_classifier {
// Alias the shared ClassificationResult struct as the result type.
using ImageClassifierResult =
::mediapipe::tasks::components::containers::ClassificationResult;
// The options for configuring a Mediapipe image classifier task. // The options for configuring a Mediapipe image classifier task.
struct ImageClassifierOptions { struct ImageClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model // Base options for configuring MediaPipe Tasks, such as specifying the model
@ -56,9 +60,8 @@ struct ImageClassifierOptions {
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void( std::function<void(absl::StatusOr<ImageClassifierResult>, const Image&,
absl::StatusOr<components::containers::proto::ClassificationResult>, int64)>
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
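As a usage note, a hedged sketch of configuring the live-stream callback with the new ImageClassifierResult alias; the model path is a placeholder and the LIVE_STREAM running mode is assumed from the MediaPipe vision core API.
#include <memory>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
namespace ic = mediapipe::tasks::vision::image_classifier;
// Illustrative only: LIVE_STREAM options using the struct-based callback.
std::unique_ptr<ic::ImageClassifierOptions> MakeLiveStreamOptions() {
  auto options = std::make_unique<ic::ImageClassifierOptions>();
  options->base_options.model_asset_path = "mobilenet_v2.tflite";  // Placeholder.
  options->running_mode = mediapipe::tasks::vision::core::RunningMode::LIVE_STREAM;
  // The callback now receives the ImageClassifierResult struct instead of the
  // ClassificationResult proto.
  options->result_callback = [](absl::StatusOr<ic::ImageClassifierResult> result,
                                const mediapipe::Image& image, int64 timestamp_ms) {
    if (!result.ok()) return;
    for (const auto& head : result->classifications) {
      // e.g. inspect head.head_index and head.categories here.
    }
  };
  return options;
}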
@ -122,7 +125,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// TODO: describe exact preprocessing steps once // TODO: describe exact preprocessing steps once
// YUVToImageCalculator is integrated. // YUVToImageCalculator is integrated.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify( absl::StatusOr<ImageClassifierResult> Classify(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -144,10 +147,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::proto::ClassificationResult> absl::StatusOr<ImageClassifierResult> ClassifyForVideo(
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> std::optional<core::ImageProcessingOptions> image_processing_options =
image_processing_options = std::nullopt); std::nullopt);
// Sends live image data to image classification, and the results will be // Sends live image data to image classification, and the results will be
// available via the "result_callback" provided in the ImageClassifierOptions. // available via the "result_callback" provided in the ImageClassifierOptions.
@ -170,7 +173,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// increasing. // increasing.
// //
// The "result_callback" provides: // The "result_callback" provides:
// - The classification results as a ClassificationResult object. // - The classification results as an ImageClassifierResult object.
// - The const reference to the corresponding input image that the image // - The const reference to the corresponding input image that the image
// classifier runs on. Note that the const reference to the image will no // classifier runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data // longer be valid when the callback returns. To access the image data

View File

@ -48,6 +48,7 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
@ -56,6 +57,7 @@ constexpr char kTensorsTag[] = "TENSORS";
// subgraph. // subgraph.
struct ImageClassifierOutputStreams { struct ImageClassifierOutputStreams {
Source<ClassificationResult> classification_result; Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
Source<Image> image; Source<Image> image;
}; };
@ -71,17 +73,19 @@ struct ImageClassifierOutputStreams {
// Describes region of image to perform classification on. // Describes region of image to perform classification on.
// @Optional: rect covering the whole image is used if not specified. // @Optional: rect covering the whole image is used if not specified.
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATIONS - ClassificationResult @Optional
// The aggregated classification result object has two dimensions: // The classification results aggregated by classifier head.
// (classification head, classification category)
// IMAGE - Image // IMAGE - Image
// The image that image classification runs on. // The image that image classification runs on.
// TODO: remove this output once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
// //
// Example: // Example:
// node { // node {
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph" // calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
// input_stream: "IMAGE:image_in" // input_stream: "IMAGE:image_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "CLASSIFICATIONS:classifications_out"
// output_stream: "IMAGE:image_out" // output_stream: "IMAGE:image_out"
// options { // options {
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext] // [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
@ -115,6 +119,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.classification_result >> output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)]; graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)]; output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -170,6 +176,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
return ImageClassifierOutputStreams{ return ImageClassifierOutputStreams{
/*classification_result=*/postprocessing[Output<ClassificationResult>( /*classification_result=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)], kClassificationResultTag)],
/*classifications=*/
postprocessing[Output<ClassificationResult>(kClassificationsTag)],
/*image=*/preprocessing[Output<Image>(kImageTag)]}; /*image=*/preprocessing[Output<Image>(kImageTag)]};
} }
}; };

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -50,10 +50,9 @@ namespace image_classifier {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -65,83 +64,56 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] = constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] =
"mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite"; "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
// Checks that the two provided `ClassificationResult` objects are equal, with a // Checks that the two provided `ImageClassifierResult` objects are equal, with a
// tolerance on floating-point scores to account for numerical instabilities. // tolerance on floating-point scores to account for numerical instabilities.
void ExpectApproximatelyEqual(const ClassificationResult& actual, void ExpectApproximatelyEqual(const ImageClassifierResult& actual,
const ClassificationResult& expected) { const ImageClassifierResult& expected) {
const float kPrecision = 1e-6; const float kPrecision = 1e-6;
ASSERT_EQ(actual.classifications_size(), expected.classifications_size()); ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
for (int i = 0; i < actual.classifications_size(); ++i) { for (int i = 0; i < actual.classifications.size(); ++i) {
const Classifications& a = actual.classifications(i); const Classifications& a = actual.classifications[i];
const Classifications& b = expected.classifications(i); const Classifications& b = expected.classifications[i];
EXPECT_EQ(a.head_index(), b.head_index()); EXPECT_EQ(a.head_index, b.head_index);
EXPECT_EQ(a.head_name(), b.head_name()); EXPECT_EQ(a.head_name, b.head_name);
EXPECT_EQ(a.entries_size(), b.entries_size()); EXPECT_EQ(a.categories.size(), b.categories.size());
for (int j = 0; j < a.entries_size(); ++j) { for (int j = 0; j < a.categories.size(); ++j) {
const ClassificationEntry& x = a.entries(j); const Category& x = a.categories[j];
const ClassificationEntry& y = b.entries(j); const Category& y = b.categories[j];
EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms()); EXPECT_EQ(x.index, y.index);
EXPECT_EQ(x.categories_size(), y.categories_size()); EXPECT_NEAR(x.score, y.score, kPrecision);
for (int k = 0; k < x.categories_size(); ++k) { EXPECT_EQ(x.category_name, y.category_name);
EXPECT_EQ(x.categories(k).index(), y.categories(k).index()); EXPECT_EQ(x.display_name, y.display_name);
EXPECT_EQ(x.categories(k).category_name(),
y.categories(k).category_name());
EXPECT_EQ(x.categories(k).display_name(),
y.categories(k).display_name());
EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(),
kPrecision);
}
} }
} }
} }
// Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata // Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
// with max_results set to 3. // with max_results set to 3.
ClassificationResult GenerateBurgerResults(int64 timestamp) { ImageClassifierResult GenerateBurgerResults() {
return ParseTextProtoOrDie<ClassificationResult>( ImageClassifierResult result;
absl::StrFormat(R"pb(classifications { result.classifications.emplace_back(Classifications{
entries { /*categories=*/{
categories { {/*index=*/934, /*score=*/0.793959200,
index: 934 /*category_name=*/"cheeseburger"},
score: 0.7939592 {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"},
category_name: "cheeseburger" {/*index=*/925, /*score=*/0.019340655,
} /*category_name=*/"guacamole"}},
categories { /*head_index=*/0,
index: 932 /*head_name=*/"probability"});
score: 0.027392805 return result;
category_name: "bagel"
}
categories {
index: 925
score: 0.019340655
category_name: "guacamole"
}
timestamp_ms: %d
}
head_index: 0
head_name: "probability"
})pb",
timestamp));
} }
// Generates expected results for "multi_objects.jpg" using // Generates expected results for "multi_objects.jpg" using
// kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding // kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
// box set around the soccer ball. // box set around the soccer ball.
ClassificationResult GenerateSoccerBallResults(int64 timestamp) { ImageClassifierResult GenerateSoccerBallResults() {
return ParseTextProtoOrDie<ClassificationResult>( ImageClassifierResult result;
absl::StrFormat(R"pb(classifications { result.classifications.emplace_back(
entries { Classifications{/*categories=*/{{/*index=*/806, /*score=*/0.996527493,
categories { /*category_name=*/"soccer ball"}},
index: 806 /*head_index=*/0,
score: 0.996527493 /*head_name=*/"probability"});
category_name: "soccer ball" return result;
}
timestamp_ms: %d
}
head_index: 0
head_name: "probability"
})pb",
timestamp));
} }
// A custom OpResolver only containing the Ops required by the test model. // A custom OpResolver only containing the Ops required by the test model.
@ -260,7 +232,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata);
options->running_mode = running_mode; options->running_mode = running_mode;
options->result_callback = [](absl::StatusOr<ClassificationResult>, options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
const Image& image, int64 timestamp_ms) {}; const Image& image, int64 timestamp_ms) {};
auto image_classifier = ImageClassifier::Create(std::move(options)); auto image_classifier = ImageClassifier::Create(std::move(options));
@ -336,7 +308,7 @@ TEST_F(ImageModeTest, SucceedsWithFloatModel) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, GenerateBurgerResults(0)); ExpectApproximatelyEqual(results, GenerateBurgerResults());
} }
TEST_F(ImageModeTest, SucceedsWithQuantizedModel) { TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
@ -355,19 +327,13 @@ TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( ImageClassifierResult expected;
R"pb(classifications { expected.classifications.emplace_back(
entries { Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.97265625,
categories { /*category_name=*/"cheeseburger"}},
index: 934 /*head_index=*/0,
score: 0.97265625 /*head_name=*/"probability"});
category_name: "cheeseburger" ExpectApproximatelyEqual(results, expected);
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
} }
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
@ -383,19 +349,13 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( ImageClassifierResult expected;
R"pb(classifications { expected.classifications.emplace_back(
entries { Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.7939592,
categories { /*category_name=*/"cheeseburger"}},
index: 934 /*head_index=*/0,
score: 0.7939592 /*head_name=*/"probability"});
category_name: "cheeseburger" ExpectApproximatelyEqual(results, expected);
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
} }
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
@ -411,24 +371,15 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( ImageClassifierResult expected;
R"pb(classifications { expected.classifications.emplace_back(Classifications{
entries { /*categories=*/{
categories { {/*index=*/934, /*score=*/0.7939592,
index: 934 /*category_name=*/"cheeseburger"},
score: 0.7939592 {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}},
category_name: "cheeseburger" /*head_index=*/0,
} /*head_name=*/"probability"});
categories { ExpectApproximatelyEqual(results, expected);
index: 932
score: 0.027392805
category_name: "bagel"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
} }
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
@ -445,29 +396,17 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( ImageClassifierResult expected;
R"pb(classifications { expected.classifications.emplace_back(Classifications{
entries { /*categories=*/{
categories { {/*index=*/934, /*score=*/0.7939592,
index: 934 /*category_name=*/"cheeseburger"},
score: 0.7939592 {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
category_name: "cheeseburger" {/*index=*/963, /*score=*/0.0063278517,
} /*category_name=*/"meat loaf"}},
categories { /*head_index=*/0,
index: 925 /*head_name=*/"probability"});
score: 0.019340655 ExpectApproximatelyEqual(results, expected);
category_name: "guacamole"
}
categories {
index: 963
score: 0.0063278517
category_name: "meat loaf"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
} }
TEST_F(ImageModeTest, SucceedsWithDenylistOption) { TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
@ -484,29 +423,17 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( ImageClassifierResult expected;
R"pb(classifications { expected.classifications.emplace_back(Classifications{
entries { /*categories=*/{
categories { {/*index=*/934, /*score=*/0.7939592,
index: 934 /*category_name=*/"cheeseburger"},
score: 0.7939592 {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
category_name: "cheeseburger" {/*index=*/963, /*score=*/0.0063278517,
} /*category_name=*/"meat loaf"}},
categories { /*head_index=*/0,
index: 925 /*head_name=*/"probability"});
score: 0.019340655 ExpectApproximatelyEqual(results, expected);
category_name: "guacamole"
}
categories {
index: 963
score: 0.0063278517
category_name: "meat loaf"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
} }
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
@@ -525,19 +452,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
-  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
-                                        R"pb(classifications {
-                                               entries {
-                                                 categories {
-                                                   index: 934
-                                                   score: 0.725648628
-                                                   category_name: "cheeseburger"
-                                                 }
-                                                 timestamp_ms: 0
-                                               }
-                                               head_index: 0
-                                               head_name: "probability"
-                                             })pb"));
+  ImageClassifierResult expected;
+  expected.classifications.emplace_back(
+      Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.725648628,
+                                       /*category_name=*/"cheeseburger"}},
+                      /*head_index=*/0,
+                      /*head_name=*/"probability"});
+  ExpectApproximatelyEqual(results, expected);
 }
 
 TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
@@ -557,7 +478,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
                                             image, image_processing_options));
-  ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
+  ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
 }
 
 TEST_F(ImageModeTest, SucceedsWithRotation) {
@@ -581,29 +502,17 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
   // Results differ slightly from the non-rotated image, but that's expected
   // as models are very sensitive to the slightest numerical differences
   // introduced by the rotation and JPG encoding.
-  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
-                                        R"pb(classifications {
-                                               entries {
-                                                 categories {
-                                                   index: 934
-                                                   score: 0.6371766
-                                                   category_name: "cheeseburger"
-                                                 }
-                                                 categories {
-                                                   index: 963
-                                                   score: 0.049443405
-                                                   category_name: "meat loaf"
-                                                 }
-                                                 categories {
-                                                   index: 925
-                                                   score: 0.047918003
-                                                   category_name: "guacamole"
-                                                 }
-                                                 timestamp_ms: 0
-                                               }
-                                               head_index: 0
-                                               head_name: "probability"
-                                             })pb"));
+  ImageClassifierResult expected;
+  expected.classifications.emplace_back(Classifications{
+      /*categories=*/{
+          {/*index=*/934, /*score=*/0.6371766,
+           /*category_name=*/"cheeseburger"},
+          {/*index=*/963, /*score=*/0.049443405, /*category_name=*/"meat loaf"},
+          {/*index=*/925, /*score=*/0.047918003,
+           /*category_name=*/"guacamole"}},
+      /*head_index=*/0,
+      /*head_name=*/"probability"});
+  ExpectApproximatelyEqual(results, expected);
 }
 
 TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
@@ -624,20 +533,13 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
                                             image, image_processing_options));
-  ExpectApproximatelyEqual(results,
-                           ParseTextProtoOrDie<ClassificationResult>(
-                               R"pb(classifications {
-                                      entries {
-                                        categories {
-                                          index: 560
-                                          score: 0.6522213
-                                          category_name: "folding chair"
-                                        }
-                                        timestamp_ms: 0
-                                      }
-                                      head_index: 0
-                                      head_name: "probability"
-                                    })pb"));
+  ImageClassifierResult expected;
+  expected.classifications.emplace_back(
+      Classifications{/*categories=*/{{/*index=*/560, /*score=*/0.6522213,
+                                       /*category_name=*/"folding chair"}},
+                      /*head_index=*/0,
+                      /*head_name=*/"probability"});
+  ExpectApproximatelyEqual(results, expected);
 }
 
 // Testing all these once with ImageClassifier.
@@ -774,7 +676,7 @@ TEST_F(VideoModeTest, Succeeds) {
   for (int i = 0; i < iterations; ++i) {
     MP_ASSERT_OK_AND_ASSIGN(auto results,
                             image_classifier->ClassifyForVideo(image, i));
-    ExpectApproximatelyEqual(results, GenerateBurgerResults(i));
+    ExpectApproximatelyEqual(results, GenerateBurgerResults());
   }
   MP_ASSERT_OK(image_classifier->Close());
 }
@@ -800,7 +702,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
     MP_ASSERT_OK_AND_ASSIGN(
         auto results,
        image_classifier->ClassifyForVideo(image, i, image_processing_options));
-    ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
+    ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
   }
   MP_ASSERT_OK(image_classifier->Close());
 }
@@ -815,7 +717,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback = [](absl::StatusOr<ClassificationResult>,
+  options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
                                 const Image& image, int64 timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
@@ -846,7 +748,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
   options->base_options.model_asset_path =
       JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
   options->running_mode = core::RunningMode::LIVE_STREAM;
-  options->result_callback = [](absl::StatusOr<ClassificationResult>,
+  options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
                                 const Image& image, int64 timestamp_ms) {};
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                           ImageClassifier::Create(std::move(options)));
@@ -864,7 +766,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
 }
 
 struct LiveStreamModeResults {
-  ClassificationResult classification_result;
+  ImageClassifierResult classification_result;
   std::pair<int, int> image_size;
   int64 timestamp_ms;
 };
@@ -881,7 +783,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
   options->running_mode = core::RunningMode::LIVE_STREAM;
   options->classifier_options.max_results = 3;
   options->result_callback =
-      [&results](absl::StatusOr<ClassificationResult> classification_result,
+      [&results](absl::StatusOr<ImageClassifierResult> classification_result,
                  const Image& image, int64 timestamp_ms) {
        MP_ASSERT_OK(classification_result.status());
        results.push_back(
@@ -908,7 +810,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
    EXPECT_EQ(result.image_size.first, image.width());
    EXPECT_EQ(result.image_size.second, image.height());
    ExpectApproximatelyEqual(result.classification_result,
-                            GenerateBurgerResults(timestamp_ms));
+                            GenerateBurgerResults());
  }
 }
@@ -924,7 +826,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
   options->running_mode = core::RunningMode::LIVE_STREAM;
   options->classifier_options.max_results = 1;
   options->result_callback =
-      [&results](absl::StatusOr<ClassificationResult> classification_result,
+      [&results](absl::StatusOr<ImageClassifierResult> classification_result,
                  const Image& image, int64 timestamp_ms) {
        MP_ASSERT_OK(classification_result.status());
        results.push_back(
@@ -955,7 +857,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
    EXPECT_EQ(result.image_size.first, image.width());
    EXPECT_EQ(result.image_size.second, image.height());
    ExpectApproximatelyEqual(result.classification_result,
-                            GenerateSoccerBallResults(timestamp_ms));
+                            GenerateSoccerBallResults());
  }
 }