Add classification result structs and use them in ImageClassifier & TextClassifier.

PiperOrigin-RevId: 485567133
This commit is contained in:
MediaPipe Team 2022-11-02 05:11:16 -07:00 committed by Copybara-Service
parent aab5f84aae
commit d61ab92b90
15 changed files with 559 additions and 276 deletions

View File

@ -29,3 +29,23 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
# Plain C++ struct mirroring the mediapipe.Classification proto, plus a
# proto-to-struct conversion helper.
cc_library(
    name = "category",
    srcs = ["category.cc"],
    hdrs = ["category.h"],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
    ],
)
# Plain C++ structs mirroring the classification result protos, plus
# proto-to-struct conversion helpers.
cc_library(
    name = "classification_result",
    srcs = ["classification_result.cc"],
    hdrs = ["classification_result.h"],
    deps = [
        ":category",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
    ],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/category.h"
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
// Builds a Category struct from the corresponding mediapipe::Classification
// proto. The optional name fields are only populated when present in the
// proto, so absence is preserved as std::nullopt.
Category ConvertToCategory(const mediapipe::Classification& proto) {
  Category result;
  result.index = proto.index();
  result.score = proto.score();
  if (proto.has_label()) {
    result.category_name = proto.label();
  }
  if (proto.has_display_name()) {
    result.display_name = proto.display_name();
  }
  return result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#include <optional>
#include <string>
#include "mediapipe/framework/formats/classification.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
  // The index of the category in the classification model output.
  // Zero-initialized so a default-constructed Category has a well-defined
  // value (the original left this uninitialized).
  int index = 0;
  // The score for this category, e.g. (but not necessarily) a probability in
  // [0,1]. Zero-initialized for the same reason as `index`.
  float score = 0.0f;
  // The optional ID for the category, read from the label map packed in the
  // TFLite Model Metadata if present. Not necessarily human-readable.
  std::optional<std::string> category_name = std::nullopt;
  // The optional human-readable name for the category, read from the label map
  // packed in the TFLite Model Metadata if present.
  std::optional<std::string> display_name = std::nullopt;
};
// Utility function to convert from mediapipe::Classification proto to Category
// struct.
Category ConvertToCategory(const mediapipe::Classification& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include <optional>
#include <string>
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
// Builds a Classifications struct from the corresponding proto, converting
// each nested Classification proto into a Category struct.
Classifications ConvertToClassifications(const proto::Classifications& proto) {
  Classifications result;
  const auto& classification_list = proto.classification_list();
  // Reserve up front to avoid reallocations while converting entries.
  result.categories.reserve(classification_list.classification_size());
  for (const auto& entry : classification_list.classification()) {
    result.categories.push_back(ConvertToCategory(entry));
  }
  result.head_index = proto.head_index();
  if (proto.has_head_name()) {
    result.head_name = proto.head_name();
  }
  return result;
}
// Builds a ClassificationResult struct from the corresponding proto,
// converting each per-head Classifications entry in order. The timestamp is
// only populated when set in the proto.
ClassificationResult ConvertToClassificationResult(
    const proto::ClassificationResult& proto) {
  ClassificationResult result;
  result.classifications.reserve(proto.classifications_size());
  for (const auto& entry : proto.classifications()) {
    result.classifications.push_back(ConvertToClassifications(entry));
  }
  if (proto.has_timestamp_ms()) {
    result.timestamp_ms = proto.timestamp_ms();
  }
  return result;
}
} // namespace mediapipe::tasks::components::containers

View File

@ -0,0 +1,68 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe::tasks::components::containers {
// Defines classification results for a given classifier head.
struct Classifications {
  // The array of predicted categories, usually sorted by descending scores,
  // e.g. from high to low probability.
  std::vector<Category> categories;
  // The index of the classifier head (i.e. output tensor) these categories
  // refer to. This is useful for multi-head models. Zero-initialized so a
  // default-constructed instance has a well-defined value (the original left
  // this uninitialized).
  int head_index = 0;
  // The optional name of the classifier head, as provided in the TFLite Model
  // Metadata [1] if present. This is useful for multi-head models.
  //
  // [1]: https://www.tensorflow.org/lite/convert/metadata
  std::optional<std::string> head_name = std::nullopt;
};
// Defines classification results of a model.
struct ClassificationResult {
  // The classification results for each head of the model, one entry per
  // classifier head.
  std::vector<Classifications> classifications;
  // The optional timestamp (in milliseconds) of the start of the chunk of data
  // corresponding to these results. Left unset when the source proto carries
  // no timestamp.
  //
  // This is only used for classification on time series (e.g. audio
  // classification). In these use cases, the amount of data to process might
  // exceed the maximum size that the model can process: to solve this, the
  // input data is split into multiple chunks starting at different timestamps.
  std::optional<int64_t> timestamp_ms = std::nullopt;
};
// Utility function to convert from Classifications proto to
// Classifications struct.
Classifications ConvertToClassifications(const proto::Classifications& proto);
// Utility function to convert from ClassificationResult proto to
// ClassificationResult struct.
ClassificationResult ConvertToClassificationResult(
const proto::ClassificationResult& proto);
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

View File

@ -49,6 +49,8 @@ cc_library(
":text_classifier_graph",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:builder",
"//mediapipe/tasks/cc/components/containers:category",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
@ -76,7 +78,8 @@ cc_test(
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers:category",
"//mediapipe/tasks/cc/components/containers:classification_result",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
@ -37,12 +38,13 @@ namespace text_classifier {
namespace {
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kTextStreamName[] = "text_in";
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsStreamName[] = "classifications_out";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
@ -54,9 +56,8 @@ CalculatorGraphConfig CreateGraphConfig(
auto& subgraph = graph.AddNode(kSubgraphTypeName);
subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >>
graph.Out(kClassificationResultTag);
subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
graph.Out(kClassificationsTag);
return graph.GetConfig();
}
@ -88,14 +89,14 @@ absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
std::move(options->base_options.op_resolver));
}
absl::StatusOr<ClassificationResult> TextClassifier::Classify(
absl::StatusOr<TextClassifierResult> TextClassifier::Classify(
absl::string_view text) {
ASSIGN_OR_RETURN(
auto output_packets,
runner_->Process(
{{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
return output_packets[kClassificationResultStreamName]
.Get<ClassificationResult>();
return ConvertToClassificationResult(
output_packets[kClassificationsStreamName].Get<ClassificationResult>());
}
} // namespace text_classifier

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
@ -31,6 +31,10 @@ namespace tasks {
namespace text {
namespace text_classifier {
// Alias the shared ClassificationResult struct as result type.
using TextClassifierResult =
::mediapipe::tasks::components::containers::ClassificationResult;
// The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
@ -81,8 +85,7 @@ class TextClassifier : core::BaseTaskApi {
std::unique_ptr<TextClassifierOptions> options);
// Performs classification on the input `text`.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
absl::string_view text);
absl::StatusOr<TextClassifierResult> Classify(absl::string_view text);
// Shuts down the TextClassifier when all the work is done.
absl::Status Close() { return runner_->Close(); }

View File

@ -47,10 +47,18 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";
// TODO: remove once Java API migration is over.
// Struct holding the different output streams produced by the text classifier.
struct TextClassifierOutputStreams {
  // Aggregated result stream, exposed through the deprecated
  // CLASSIFICATION_RESULT output tag (kept until the Java API migration is
  // over).
  Source<ClassificationResult> classification_result;
  // Aggregated result stream, exposed through the CLASSIFICATIONS output tag.
  Source<ClassificationResult> classifications;
};
} // namespace
// A "TextClassifierGraph" performs Natural Language classification (including
@ -62,7 +70,10 @@ constexpr char kTensorsTag[] = "TENSORS";
// Input text to perform classification on.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by classifier head.
// TODO: remove once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result object that has 3 dimensions:
// (classification head, classification timestamp, classification category).
//
@ -70,7 +81,7 @@ constexpr char kTensorsTag[] = "TENSORS";
// node {
// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
// input_stream: "TEXT:text_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// output_stream: "CLASSIFICATIONS:classifications_out"
// options {
// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
// {
@ -91,12 +102,14 @@ class TextClassifierGraph : public core::ModelTaskGraph {
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
Source<ClassificationResult> classification_result_out,
auto output_streams,
BuildTextClassifierTask(
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
graph[Input<std::string>(kTextTag)], graph));
classification_result_out >>
output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
return graph.GetConfig();
}
@ -111,7 +124,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
// TextClassifier model file with model metadata.
// text_in: (std::string) stream to run text classification on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
const proto::TextClassifierGraphOptions& task_options,
const ModelResources& model_resources, Source<std::string> text_in,
Graph& graph) {
@ -148,8 +161,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {
// Outputs the aggregated classification result as the subgraph output
// stream.
return postprocessing[Output<ClassificationResult>(
kClassificationResultTag)];
return TextClassifierOutputStreams{
/*classification_result=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)],
/*classifications=*/postprocessing[Output<ClassificationResult>(
kClassificationsTag)]};
}
};

View File

@ -33,7 +33,8 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -43,14 +44,13 @@ namespace text {
namespace text_classifier {
namespace {
using ::mediapipe::EqualsProto;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::testing::HasSubstr;
using ::testing::Optional;
constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
@ -64,6 +64,30 @@ std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
// Checks that the two provided `TextClassifierResult` are equal, with a
// tolerance on floating-point scores to account for numerical instabilities.
// TODO: create shared matcher for ClassificationResult.
void ExpectApproximatelyEqual(const TextClassifierResult& actual,
                              const TextClassifierResult& expected) {
  constexpr float kPrecision = 1e-6f;
  ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
  for (size_t i = 0; i < actual.classifications.size(); ++i) {
    const Classifications& a = actual.classifications[i];
    const Classifications& b = expected.classifications[i];
    EXPECT_EQ(a.head_index, b.head_index);
    EXPECT_EQ(a.head_name, b.head_name);
    // Must be ASSERT (not EXPECT): the loop below indexes both vectors up to
    // a.categories.size(), so a size mismatch would otherwise read b out of
    // range.
    ASSERT_EQ(a.categories.size(), b.categories.size());
    for (size_t j = 0; j < a.categories.size(); ++j) {
      const Category& x = a.categories[j];
      const Category& y = b.categories[j];
      EXPECT_EQ(x.index, y.index);
      // Scores come from float model inference; compare with tolerance.
      EXPECT_NEAR(x.score, y.score, kPrecision);
      EXPECT_EQ(x.category_name, y.category_name);
      EXPECT_EQ(x.display_name, y.display_name);
    }
  }
}
class TextClassifierTest : public tflite_shims::testing::Test {};
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
@ -107,6 +131,92 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}
TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));

  // Negative input: the "negative" category should carry most of the score
  // mass. Expected scores are golden values — presumably captured from a
  // reference run of the model — compared with a small tolerance.
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult negative_result,
      classifier->Classify("unflinchingly bleak and desperate"));
  TextClassifierResult negative_expected;
  negative_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
          {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(negative_result, negative_expected);

  // Positive input: the "positive" category should dominate.
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult positive_result,
      classifier->Classify("it's a charming and often affecting journey"));
  TextClassifierResult positive_expected;
  positive_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
          {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(positive_result, positive_expected);

  MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));

  // Negative input: the "Negative" category should score highest. Expected
  // scores are golden values, compared with a small tolerance.
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult negative_result,
                          classifier->Classify("What a waste of my time."));
  TextClassifierResult negative_expected;
  negative_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"},
          {/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(negative_result, negative_expected);

  // Positive input: the "Positive" category should score highest.
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult positive_result,
      classifier->Classify("This is the best movie Ive seen in recent years."
                           "Strongly recommend it!"));
  TextClassifierResult positive_expected;
  positive_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"},
          {/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(positive_result, positive_expected);

  MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
  // This model presumably uses a custom op requiring the resolver from
  // text_classifier_test_utils — TODO confirm.
  options->base_options.op_resolver = CreateCustomResolver();
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
                          classifier->Classify("hello"));

  // Binary outputs cause flaky ordering, so we compare manually: the two
  // score-1 categories may appear in either order (indices 0 and 1), so only
  // their scores and index bounds are checked.
  ASSERT_EQ(result.classifications.size(), 1);
  ASSERT_EQ(result.classifications[0].head_index, 0);
  ASSERT_EQ(result.classifications[0].categories.size(), 3);
  ASSERT_EQ(result.classifications[0].categories[0].score, 1);
  ASSERT_LT(result.classifications[0].categories[0].index, 2);  // i.e. 0 or 1.
  ASSERT_EQ(result.classifications[0].categories[1].score, 1);
  ASSERT_LT(result.classifications[0].categories[1].index, 2);  // i.e. 0 or 1.
  ASSERT_EQ(result.classifications[0].categories[2].score, 0);
  ASSERT_EQ(result.classifications[0].categories[2].index, 2);
  MP_ASSERT_OK(classifier->Close());
}
} // namespace
} // namespace text_classifier
} // namespace text

View File

@ -50,6 +50,7 @@ cc_library(
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:classification_result",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
@ -46,8 +47,8 @@ namespace image_classifier {
namespace {
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsStreamName[] = "classifications_out";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";
@ -57,6 +58,7 @@ constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap;
@ -73,15 +75,13 @@ CalculatorGraphConfig CreateGraphConfig(
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
options_proto.get());
task_subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >>
graph.Out(kClassificationResultTag);
task_subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
graph.Out(kClassificationsTag);
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
graph.Out(kImageTag);
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
{kImageTag, kNormRectTag},
kClassificationResultTag);
return tasks::core::AddFlowLimiterCalculator(
graph, task_subgraph, {kImageTag, kNormRectTag}, kClassificationsTag);
}
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
@ -125,13 +125,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return;
}
Packet classification_result_packet =
status_or_packets.value()[kClassificationResultStreamName];
Packet classifications_packet =
status_or_packets.value()[kClassificationsStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(
classification_result_packet.Get<ClassificationResult>(),
ConvertToClassificationResult(
classifications_packet.Get<ClassificationResult>()),
image_packet.Get<Image>(),
classification_result_packet.Timestamp().Value() /
classifications_packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond);
};
}
@ -144,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
std::move(packets_callback));
}
absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
absl::StatusOr<ImageClassifierResult> ImageClassifier::Classify(
Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -160,11 +161,11 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
ProcessImageData(
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
return output_packets[kClassificationResultStreamName]
.Get<ClassificationResult>();
return ConvertToClassificationResult(
output_packets[kClassificationsStreamName].Get<ClassificationResult>());
}
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
absl::StatusOr<ImageClassifierResult> ImageClassifier::ClassifyForVideo(
Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) {
@ -184,8 +185,8 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
{kNormRectName,
MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return output_packets[kClassificationResultStreamName]
.Get<ClassificationResult>();
return ConvertToClassificationResult(
output_packets[kClassificationsStreamName].Get<ClassificationResult>());
}
absl::Status ImageClassifier::ClassifyAsync(

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
@ -34,6 +34,10 @@ namespace tasks {
namespace vision {
namespace image_classifier {
// Alias the shared ClassificationResult struct as result type.
using ImageClassifierResult =
::mediapipe::tasks::components::containers::ClassificationResult;
// The options for configuring a Mediapipe image classifier task.
struct ImageClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
@ -56,9 +60,8 @@ struct ImageClassifierOptions {
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(
absl::StatusOr<components::containers::proto::ClassificationResult>,
const Image&, int64)>
std::function<void(absl::StatusOr<ImageClassifierResult>, const Image&,
int64)>
result_callback = nullptr;
};
@ -122,7 +125,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA.
// TODO: describe exact preprocessing steps once
// YUVToImageCalculator is integrated.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
absl::StatusOr<ImageClassifierResult> Classify(
mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
@ -144,10 +147,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
absl::StatusOr<components::containers::proto::ClassificationResult>
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions>
image_processing_options = std::nullopt);
absl::StatusOr<ImageClassifierResult> ClassifyForVideo(
mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt);
// Sends live image data to image classification, and the results will be
// available via the "result_callback" provided in the ImageClassifierOptions.
@ -170,7 +173,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// increasing.
//
// The "result_callback" provides:
// - The classification results as a ClassificationResult object.
// - The classification results as an ImageClassifierResult object.
// - The const reference to the corresponding input image that the image
// classifier runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data

View File

@ -48,6 +48,7 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";
@ -56,6 +57,7 @@ constexpr char kTensorsTag[] = "TENSORS";
// subgraph.
struct ImageClassifierOutputStreams {
Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
Source<Image> image;
};
@ -71,17 +73,19 @@ struct ImageClassifierOutputStreams {
// Describes region of image to perform classification on.
// @Optional: rect covering the whole image is used if not specified.
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The aggregated classification result object has two dimensions:
// (classification head, classification category)
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by classifier head.
// IMAGE - Image
// The image that object detection runs on.
// TODO: remove this output once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
// input_stream: "IMAGE:image_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// output_stream: "CLASSIFICATIONS:classifications_out"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
@ -115,6 +119,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
@ -170,6 +176,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
return ImageClassifierOutputStreams{
/*classification_result=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)],
/*classifications=*/
postprocessing[Output<ClassificationResult>(kClassificationsTag)],
/*image=*/preprocessing[Output<Image>(kImageTag)]};
}
};

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -50,10 +50,9 @@ namespace image_classifier {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;
@ -65,83 +64,56 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] =
"mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
// Checks that the two provided `ClassificationResult` are equal, with a
// Checks that the two provided `ImageClassifierResult` are equal, with a
// tolerancy on floating-point score to account for numerical instabilities.
void ExpectApproximatelyEqual(const ClassificationResult& actual,
const ClassificationResult& expected) {
void ExpectApproximatelyEqual(const ImageClassifierResult& actual,
const ImageClassifierResult& expected) {
const float kPrecision = 1e-6;
ASSERT_EQ(actual.classifications_size(), expected.classifications_size());
for (int i = 0; i < actual.classifications_size(); ++i) {
const Classifications& a = actual.classifications(i);
const Classifications& b = expected.classifications(i);
EXPECT_EQ(a.head_index(), b.head_index());
EXPECT_EQ(a.head_name(), b.head_name());
EXPECT_EQ(a.entries_size(), b.entries_size());
for (int j = 0; j < a.entries_size(); ++j) {
const ClassificationEntry& x = a.entries(j);
const ClassificationEntry& y = b.entries(j);
EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms());
EXPECT_EQ(x.categories_size(), y.categories_size());
for (int k = 0; k < x.categories_size(); ++k) {
EXPECT_EQ(x.categories(k).index(), y.categories(k).index());
EXPECT_EQ(x.categories(k).category_name(),
y.categories(k).category_name());
EXPECT_EQ(x.categories(k).display_name(),
y.categories(k).display_name());
EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(),
kPrecision);
}
ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
for (int i = 0; i < actual.classifications.size(); ++i) {
const Classifications& a = actual.classifications[i];
const Classifications& b = expected.classifications[i];
EXPECT_EQ(a.head_index, b.head_index);
EXPECT_EQ(a.head_name, b.head_name);
EXPECT_EQ(a.categories.size(), b.categories.size());
for (int j = 0; j < a.categories.size(); ++j) {
const Category& x = a.categories[j];
const Category& y = b.categories[j];
EXPECT_EQ(x.index, y.index);
EXPECT_NEAR(x.score, y.score, kPrecision);
EXPECT_EQ(x.category_name, y.category_name);
EXPECT_EQ(x.display_name, y.display_name);
}
}
}
// Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
// with max_results set to 3.
ClassificationResult GenerateBurgerResults(int64 timestamp) {
return ParseTextProtoOrDie<ClassificationResult>(
absl::StrFormat(R"pb(classifications {
entries {
categories {
index: 934
score: 0.7939592
category_name: "cheeseburger"
}
categories {
index: 932
score: 0.027392805
category_name: "bagel"
}
categories {
index: 925
score: 0.019340655
category_name: "guacamole"
}
timestamp_ms: %d
}
head_index: 0
head_name: "probability"
})pb",
timestamp));
ImageClassifierResult GenerateBurgerResults() {
ImageClassifierResult result;
result.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/934, /*score=*/0.793959200,
/*category_name=*/"cheeseburger"},
{/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"},
{/*index=*/925, /*score=*/0.019340655,
/*category_name=*/"guacamole"}},
/*head_index=*/0,
/*head_name=*/"probability"});
return result;
}
// Generates expected results for "multi_objects.jpg" using
// kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
// box set around the soccer ball.
ClassificationResult GenerateSoccerBallResults(int64 timestamp) {
return ParseTextProtoOrDie<ClassificationResult>(
absl::StrFormat(R"pb(classifications {
entries {
categories {
index: 806
score: 0.996527493
category_name: "soccer ball"
}
timestamp_ms: %d
}
head_index: 0
head_name: "probability"
})pb",
timestamp));
ImageClassifierResult GenerateSoccerBallResults() {
ImageClassifierResult result;
result.classifications.emplace_back(
Classifications{/*categories=*/{{/*index=*/806, /*score=*/0.996527493,
/*category_name=*/"soccer ball"}},
/*head_index=*/0,
/*head_name=*/"probability"});
return result;
}
// A custom OpResolver only containing the Ops required by the test model.
@ -260,7 +232,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata);
options->running_mode = running_mode;
options->result_callback = [](absl::StatusOr<ClassificationResult>,
options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
const Image& image, int64 timestamp_ms) {};
auto image_classifier = ImageClassifier::Create(std::move(options));
@ -336,7 +308,7 @@ TEST_F(ImageModeTest, SucceedsWithFloatModel) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, GenerateBurgerResults(0));
ExpectApproximatelyEqual(results, GenerateBurgerResults());
}
TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
@ -355,19 +327,13 @@ TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.97265625
category_name: "cheeseburger"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(
Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.97265625,
/*category_name=*/"cheeseburger"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
@ -383,19 +349,13 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.7939592
category_name: "cheeseburger"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(
Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.7939592,
/*category_name=*/"cheeseburger"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
@ -411,24 +371,15 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.7939592
category_name: "cheeseburger"
}
categories {
index: 932
score: 0.027392805
category_name: "bagel"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/934, /*score=*/0.7939592,
/*category_name=*/"cheeseburger"},
{/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
@ -445,29 +396,17 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.7939592
category_name: "cheeseburger"
}
categories {
index: 925
score: 0.019340655
category_name: "guacamole"
}
categories {
index: 963
score: 0.0063278517
category_name: "meat loaf"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/934, /*score=*/0.7939592,
/*category_name=*/"cheeseburger"},
{/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
{/*index=*/963, /*score=*/0.0063278517,
/*category_name=*/"meat loaf"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
@ -484,29 +423,17 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.7939592
category_name: "cheeseburger"
}
categories {
index: 925
score: 0.019340655
category_name: "guacamole"
}
categories {
index: 963
score: 0.0063278517
category_name: "meat loaf"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/934, /*score=*/0.7939592,
/*category_name=*/"cheeseburger"},
{/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
{/*index=*/963, /*score=*/0.0063278517,
/*category_name=*/"meat loaf"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
@ -525,19 +452,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.725648628
category_name: "cheeseburger"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(
Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.725648628,
/*category_name=*/"cheeseburger"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
@ -557,7 +478,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
@ -581,29 +502,17 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
// Results differ slightly from the non-rotated image, but that's expected
// as models are very sensitive to the slightest numerical differences
// introduced by the rotation and JPG encoding.
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 934
score: 0.6371766
category_name: "cheeseburger"
}
categories {
index: 963
score: 0.049443405
category_name: "meat loaf"
}
categories {
index: 925
score: 0.047918003
category_name: "guacamole"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(Classifications{
/*categories=*/{
{/*index=*/934, /*score=*/0.6371766,
/*category_name=*/"cheeseburger"},
{/*index=*/963, /*score=*/0.049443405, /*category_name=*/"meat loaf"},
{/*index=*/925, /*score=*/0.047918003,
/*category_name=*/"guacamole"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
@ -624,20 +533,13 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
image, image_processing_options));
ExpectApproximatelyEqual(results,
ParseTextProtoOrDie<ClassificationResult>(
R"pb(classifications {
entries {
categories {
index: 560
score: 0.6522213
category_name: "folding chair"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
ImageClassifierResult expected;
expected.classifications.emplace_back(
Classifications{/*categories=*/{{/*index=*/560, /*score=*/0.6522213,
/*category_name=*/"folding chair"}},
/*head_index=*/0,
/*head_name=*/"probability"});
ExpectApproximatelyEqual(results, expected);
}
// Testing all these once with ImageClassifier.
@ -774,7 +676,7 @@ TEST_F(VideoModeTest, Succeeds) {
for (int i = 0; i < iterations; ++i) {
MP_ASSERT_OK_AND_ASSIGN(auto results,
image_classifier->ClassifyForVideo(image, i));
ExpectApproximatelyEqual(results, GenerateBurgerResults(i));
ExpectApproximatelyEqual(results, GenerateBurgerResults());
}
MP_ASSERT_OK(image_classifier->Close());
}
@ -800,7 +702,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
MP_ASSERT_OK_AND_ASSIGN(
auto results,
image_classifier->ClassifyForVideo(image, i, image_processing_options));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
}
MP_ASSERT_OK(image_classifier->Close());
}
@ -815,7 +717,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<ClassificationResult>,
options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
@ -846,7 +748,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = [](absl::StatusOr<ClassificationResult>,
options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
ImageClassifier::Create(std::move(options)));
@ -864,7 +766,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
}
struct LiveStreamModeResults {
ClassificationResult classification_result;
ImageClassifierResult classification_result;
std::pair<int, int> image_size;
int64 timestamp_ms;
};
@ -881,7 +783,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
options->running_mode = core::RunningMode::LIVE_STREAM;
options->classifier_options.max_results = 3;
options->result_callback =
[&results](absl::StatusOr<ClassificationResult> classification_result,
[&results](absl::StatusOr<ImageClassifierResult> classification_result,
const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(classification_result.status());
results.push_back(
@ -908,7 +810,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
EXPECT_EQ(result.image_size.first, image.width());
EXPECT_EQ(result.image_size.second, image.height());
ExpectApproximatelyEqual(result.classification_result,
GenerateBurgerResults(timestamp_ms));
GenerateBurgerResults());
}
}
@ -924,7 +826,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
options->running_mode = core::RunningMode::LIVE_STREAM;
options->classifier_options.max_results = 1;
options->result_callback =
[&results](absl::StatusOr<ClassificationResult> classification_result,
[&results](absl::StatusOr<ImageClassifierResult> classification_result,
const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(classification_result.status());
results.push_back(
@ -955,7 +857,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
EXPECT_EQ(result.image_size.first, image.width());
EXPECT_EQ(result.image_size.second, image.height());
ExpectApproximatelyEqual(result.classification_result,
GenerateSoccerBallResults(timestamp_ms));
GenerateSoccerBallResults());
}
}