Add classification result structs and use them in ImageClassifier & TextClassifier.
PiperOrigin-RevId: 485567133
Parent: aab5f84aae
Commit: d61ab92b90
mediapipe/tasks/cc/components/containers/BUILD
@@ -29,3 +29,23 @@ cc_library(
        "//mediapipe/framework/formats:landmark_cc_proto",
    ],
)

cc_library(
    name = "category",
    srcs = ["category.cc"],
    hdrs = ["category.h"],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
    ],
)

cc_library(
    name = "classification_result",
    srcs = ["classification_result.cc"],
    hdrs = ["classification_result.h"],
    deps = [
        ":category",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
    ],
)
mediapipe/tasks/cc/components/containers/category.cc (new file)
@@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/components/containers/category.h"

#include <optional>
#include <string>

#include "mediapipe/framework/formats/classification.pb.h"

namespace mediapipe::tasks::components::containers {

Category ConvertToCategory(const mediapipe::Classification& proto) {
  Category category;
  category.index = proto.index();
  category.score = proto.score();
  if (proto.has_label()) {
    category.category_name = proto.label();
  }
  if (proto.has_display_name()) {
    category.display_name = proto.display_name();
  }
  return category;
}

}  // namespace mediapipe::tasks::components::containers
mediapipe/tasks/cc/components/containers/category.h (new file)
@@ -0,0 +1,52 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_

#include <optional>
#include <string>

#include "mediapipe/framework/formats/classification.pb.h"

namespace mediapipe::tasks::components::containers {

// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
  // The index of the category in the classification model output.
  int index;
  // The score for this category, e.g. (but not necessarily) a probability in
  // [0,1].
  float score;
  // The optional ID for the category, read from the label map packed in the
  // TFLite Model Metadata if present. Not necessarily human-readable.
  std::optional<std::string> category_name = std::nullopt;
  // The optional human-readable name for the category, read from the label map
  // packed in the TFLite Model Metadata if present.
  std::optional<std::string> display_name = std::nullopt;
};

// Utility function to convert from mediapipe::Classification proto to Category
// struct.
Category ConvertToCategory(const mediapipe::Classification& proto);

}  // namespace mediapipe::tasks::components::containers

#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
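Usage sketch (not part of this diff): converting a hand-built Classification proto into the new Category struct. The field values and the standalone main() are invented for illustration; only ConvertToCategory and the proto accessors come from the commit.

#include <iostream>

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"

int main() {
  mediapipe::Classification proto;
  proto.set_index(934);
  proto.set_score(0.79f);
  proto.set_label("cheeseburger");  // Populates Category::category_name.
  // display_name is left unset, so the corresponding optional stays nullopt.

  auto category =
      mediapipe::tasks::components::containers::ConvertToCategory(proto);
  std::cout << category.index << " "
            << category.category_name.value_or("<none>") << " "
            << category.display_name.value_or("<none>") << "\n";
  return 0;
}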
mediapipe/tasks/cc/components/containers/classification_result.cc (new file)
@@ -0,0 +1,57 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/components/containers/classification_result.h"

#include <optional>
#include <string>
#include <vector>

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

namespace mediapipe::tasks::components::containers {

Classifications ConvertToClassifications(const proto::Classifications& proto) {
  Classifications classifications;
  classifications.categories.reserve(
      proto.classification_list().classification_size());
  for (const auto& classification :
       proto.classification_list().classification()) {
    classifications.categories.push_back(ConvertToCategory(classification));
  }
  classifications.head_index = proto.head_index();
  if (proto.has_head_name()) {
    classifications.head_name = proto.head_name();
  }
  return classifications;
}

ClassificationResult ConvertToClassificationResult(
    const proto::ClassificationResult& proto) {
  ClassificationResult classification_result;
  classification_result.classifications.reserve(proto.classifications_size());
  for (const auto& classifications : proto.classifications()) {
    classification_result.classifications.push_back(
        ConvertToClassifications(classifications));
  }
  if (proto.has_timestamp_ms()) {
    classification_result.timestamp_ms = proto.timestamp_ms();
  }
  return classification_result;
}

}  // namespace mediapipe::tasks::components::containers
mediapipe/tasks/cc/components/containers/classification_result.h (new file)
@@ -0,0 +1,68 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

#include <optional>
#include <string>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

namespace mediapipe::tasks::components::containers {

// Defines classification results for a given classifier head.
struct Classifications {
  // The array of predicted categories, usually sorted by descending scores,
  // e.g. from high to low probability.
  std::vector<Category> categories;
  // The index of the classifier head (i.e. output tensor) these categories
  // refer to. This is useful for multi-head models.
  int head_index;
  // The optional name of the classifier head, as provided in the TFLite Model
  // Metadata [1] if present. This is useful for multi-head models.
  //
  // [1]: https://www.tensorflow.org/lite/convert/metadata
  std::optional<std::string> head_name = std::nullopt;
};

// Defines classification results of a model.
struct ClassificationResult {
  // The classification results for each head of the model.
  std::vector<Classifications> classifications;
  // The optional timestamp (in milliseconds) of the start of the chunk of data
  // corresponding to these results.
  //
  // This is only used for classification on time series (e.g. audio
  // classification). In these use cases, the amount of data to process might
  // exceed the maximum size that the model can process: to solve this, the
  // input data is split into multiple chunks starting at different timestamps.
  std::optional<int64_t> timestamp_ms = std::nullopt;
};

// Utility function to convert from Classifications proto to
// Classifications struct.
Classifications ConvertToClassifications(const proto::Classifications& proto);

// Utility function to convert from ClassificationResult proto to
// ClassificationResult struct.
ClassificationResult ConvertToClassificationResult(
    const proto::ClassificationResult& proto);

}  // namespace mediapipe::tasks::components::containers

#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
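Usage sketch (not part of this diff): walking the nested structs of a ClassificationResult returned by one of the task APIs. The PrintResult helper is hypothetical; only the struct fields come from the header above.

#include <cstdio>

#include "mediapipe/tasks/cc/components/containers/classification_result.h"

namespace containers = mediapipe::tasks::components::containers;

void PrintResult(const containers::ClassificationResult& result) {
  if (result.timestamp_ms.has_value()) {
    std::printf("timestamp: %lld ms\n",
                static_cast<long long>(*result.timestamp_ms));
  }
  for (const containers::Classifications& head : result.classifications) {
    // head_name is optional; fall back to a placeholder when unset.
    std::printf("head %d (%s):\n", head.head_index,
                head.head_name.value_or("<unnamed>").c_str());
    for (const containers::Category& category : head.categories) {
      std::printf("  #%d %s score=%.3f\n", category.index,
                  category.category_name.value_or("<no name>").c_str(),
                  category.score);
    }
  }
}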
mediapipe/tasks/cc/text/text_classifier/BUILD
@@ -49,6 +49,8 @@ cc_library(
        ":text_classifier_graph",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/tasks/cc/components/containers:category",
        "//mediapipe/tasks/cc/components/containers:classification_result",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",

@@ -76,7 +78,8 @@ cc_test(
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/containers:category",
        "//mediapipe/tasks/cc/components/containers:classification_result",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
mediapipe/tasks/cc/text/text_classifier/text_classifier.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"

@@ -37,12 +38,13 @@ namespace text_classifier {

namespace {

using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

constexpr char kTextStreamName[] = "text_in";
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsStreamName[] = "classifications_out";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";

@@ -54,9 +56,8 @@ CalculatorGraphConfig CreateGraphConfig(
  auto& subgraph = graph.AddNode(kSubgraphTypeName);
  subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
  graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
  subgraph.Out(kClassificationResultTag)
          .SetName(kClassificationResultStreamName) >>
      graph.Out(kClassificationResultTag);
  subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
      graph.Out(kClassificationsTag);
  return graph.GetConfig();
}

@@ -88,14 +89,14 @@ absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
      std::move(options->base_options.op_resolver));
}

absl::StatusOr<ClassificationResult> TextClassifier::Classify(
absl::StatusOr<TextClassifierResult> TextClassifier::Classify(
    absl::string_view text) {
  ASSIGN_OR_RETURN(
      auto output_packets,
      runner_->Process(
          {{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
  return output_packets[kClassificationResultStreamName]
      .Get<ClassificationResult>();
  return ConvertToClassificationResult(
      output_packets[kClassificationsStreamName].Get<ClassificationResult>());
}

}  // namespace text_classifier
mediapipe/tasks/cc/text/text_classifier/text_classifier.h
@@ -21,7 +21,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"

@@ -31,6 +31,10 @@ namespace tasks {
namespace text {
namespace text_classifier {

// Alias the shared ClassificationResult struct as result type.
using TextClassifierResult =
    ::mediapipe::tasks::components::containers::ClassificationResult;

// The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model

@@ -81,8 +85,7 @@ class TextClassifier : core::BaseTaskApi {
      std::unique_ptr<TextClassifierOptions> options);

  // Performs classification on the input `text`.
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
      absl::string_view text);
  absl::StatusOr<TextClassifierResult> Classify(absl::string_view text);

  // Shuts down the TextClassifier when all the work is done.
  absl::Status Close() { return runner_->Close(); }
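Usage sketch (not part of this diff): calling the updated Classify() API, which now returns the struct-based TextClassifierResult instead of the ClassificationResult proto. The RunOnce helper, the model_path argument, and the sample sentence are placeholders.

#include <memory>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

namespace tc = mediapipe::tasks::text::text_classifier;

absl::Status RunOnce(const std::string& model_path) {
  auto options = std::make_unique<tc::TextClassifierOptions>();
  options->base_options.model_asset_path = model_path;
  auto classifier_or = tc::TextClassifier::Create(std::move(options));
  if (!classifier_or.ok()) return classifier_or.status();
  std::unique_ptr<tc::TextClassifier> classifier =
      std::move(classifier_or).value();

  auto result_or = classifier->Classify("unflinchingly bleak and desperate");
  if (!result_or.ok()) return result_or.status();
  const tc::TextClassifierResult& result = result_or.value();
  for (const auto& head : result.classifications) {
    for (const auto& category : head.categories) {
      (void)category;  // Consume index / score / category_name here.
    }
  }
  return classifier->Close();
}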
mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc
@@ -47,10 +47,18 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;

constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";

// TODO: remove once Java API migration is over.
// Struct holding the different output streams produced by the text classifier.
struct TextClassifierOutputStreams {
  Source<ClassificationResult> classification_result;
  Source<ClassificationResult> classifications;
};

}  // namespace

// A "TextClassifierGraph" performs Natural Language classification (including

@@ -62,7 +70,10 @@ constexpr char kTensorsTag[] = "TENSORS";
//   Input text to perform classification on.
//
// Outputs:
//   CLASSIFICATION_RESULT - ClassificationResult
//   CLASSIFICATIONS - ClassificationResult @Optional
//     The classification results aggregated by classifier head.
//   TODO: remove once Java API migration is over.
//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
//     The aggregated classification result object that has 3 dimensions:
//     (classification head, classification timestamp, classification category).
//

@@ -70,7 +81,7 @@ constexpr char kTensorsTag[] = "TENSORS";
// node {
//   calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
//   input_stream: "TEXT:text_in"
//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
//   output_stream: "CLASSIFICATIONS:classifications_out"
//   options {
//     [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
//     {

@@ -91,12 +102,14 @@ class TextClassifierGraph : public core::ModelTaskGraph {
        CreateModelResources<proto::TextClassifierGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        Source<ClassificationResult> classification_result_out,
        auto output_streams,
        BuildTextClassifierTask(
            sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
            graph[Input<std::string>(kTextTag)], graph));
    classification_result_out >>
    output_streams.classification_result >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    output_streams.classifications >>
        graph[Output<ClassificationResult>(kClassificationsTag)];
    return graph.GetConfig();
  }

@@ -111,7 +124,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
  //   TextClassifier model file with model metadata.
  // text_in: (std::string) stream to run text classification on.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
  absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
      const proto::TextClassifierGraphOptions& task_options,
      const ModelResources& model_resources, Source<std::string> text_in,
      Graph& graph) {

@@ -148,8 +161,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {

    // Outputs the aggregated classification result as the subgraph output
    // stream.
    return postprocessing[Output<ClassificationResult>(
        kClassificationResultTag)];
    return TextClassifierOutputStreams{
        /*classification_result=*/postprocessing[Output<ClassificationResult>(
            kClassificationResultTag)],
        /*classifications=*/postprocessing[Output<ClassificationResult>(
            kClassificationsTag)]};
  }
};

mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
@@ -33,7 +33,8 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

@@ -43,14 +44,13 @@ namespace text {
namespace text_classifier {
namespace {

using ::mediapipe::EqualsProto;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::testing::HasSubstr;
using ::testing::Optional;

constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";

@@ -64,6 +64,30 @@ std::string GetFullPath(absl::string_view file_name) {
  return JoinPath("./", kTestDataDirectory, file_name);
}

// Checks that the two provided `TextClassifierResult` are equal, with a
// tolerancy on floating-point score to account for numerical instabilities.
// TODO: create shared matcher for ClassificationResult.
void ExpectApproximatelyEqual(const TextClassifierResult& actual,
                              const TextClassifierResult& expected) {
  const float kPrecision = 1e-6;
  ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
  for (int i = 0; i < actual.classifications.size(); ++i) {
    const Classifications& a = actual.classifications[i];
    const Classifications& b = expected.classifications[i];
    EXPECT_EQ(a.head_index, b.head_index);
    EXPECT_EQ(a.head_name, b.head_name);
    EXPECT_EQ(a.categories.size(), b.categories.size());
    for (int j = 0; j < a.categories.size(); ++j) {
      const Category& x = a.categories[j];
      const Category& y = b.categories[j];
      EXPECT_EQ(x.index, y.index);
      EXPECT_NEAR(x.score, y.score, kPrecision);
      EXPECT_EQ(x.category_name, y.category_name);
      EXPECT_EQ(x.display_name, y.display_name);
    }
  }
}

class TextClassifierTest : public tflite_shims::testing::Test {};

TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {

@@ -107,6 +131,92 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}

TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult negative_result,
      classifier->Classify("unflinchingly bleak and desperate"));
  TextClassifierResult negative_expected;
  negative_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
          {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(negative_result, negative_expected);

  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult positive_result,
      classifier->Classify("it's a charming and often affecting journey"));
  TextClassifierResult positive_expected;
  positive_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
          {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(positive_result, positive_expected);

  MP_ASSERT_OK(classifier->Close());
}

TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult negative_result,
                          classifier->Classify("What a waste of my time."));
  TextClassifierResult negative_expected;
  negative_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"},
          {/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(negative_result, negative_expected);

  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult positive_result,
      classifier->Classify("This is the best movie I’ve seen in recent years."
                           "Strongly recommend it!"));
  TextClassifierResult positive_expected;
  positive_expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"},
          {/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(positive_result, positive_expected);

  MP_ASSERT_OK(classifier->Close());
}

TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
  options->base_options.op_resolver = CreateCustomResolver();
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
                          classifier->Classify("hello"));

  // Binary outputs causes flaky ordering, so we compare manually.
  ASSERT_EQ(result.classifications.size(), 1);
  ASSERT_EQ(result.classifications[0].head_index, 0);
  ASSERT_EQ(result.classifications[0].categories.size(), 3);
  ASSERT_EQ(result.classifications[0].categories[0].score, 1);
  ASSERT_LT(result.classifications[0].categories[0].index, 2);  // i.e 0 or 1.
  ASSERT_EQ(result.classifications[0].categories[1].score, 1);
  ASSERT_LT(result.classifications[0].categories[1].index, 2);  // i.e 0 or 1.
  ASSERT_EQ(result.classifications[0].categories[2].score, 0);
  ASSERT_EQ(result.classifications[0].categories[2].index, 2);
  MP_ASSERT_OK(classifier->Close());
}

}  // namespace
}  // namespace text_classifier
}  // namespace text
mediapipe/tasks/cc/vision/image_classifier/BUILD
@@ -50,6 +50,7 @@ cc_library(
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc/components/containers:classification_result",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"

@@ -46,8 +47,8 @@ namespace image_classifier {

namespace {

constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsStreamName[] = "classifications_out";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kImageTag[] = "IMAGE";

@@ -57,6 +58,7 @@ constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;

using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap;

@@ -73,15 +75,13 @@ CalculatorGraphConfig CreateGraphConfig(
  auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
  task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
      options_proto.get());
  task_subgraph.Out(kClassificationResultTag)
          .SetName(kClassificationResultStreamName) >>
      graph.Out(kClassificationResultTag);
  task_subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
      graph.Out(kClassificationsTag);
  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
      graph.Out(kImageTag);
  if (enable_flow_limiting) {
    return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
                                                 {kImageTag, kNormRectTag},
                                                 kClassificationResultTag);
    return tasks::core::AddFlowLimiterCalculator(
        graph, task_subgraph, {kImageTag, kNormRectTag}, kClassificationsTag);
  }
  graph.In(kImageTag) >> task_subgraph.In(kImageTag);
  graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);

@@ -125,13 +125,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
      if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
        return;
      }
      Packet classification_result_packet =
          status_or_packets.value()[kClassificationResultStreamName];
      Packet classifications_packet =
          status_or_packets.value()[kClassificationsStreamName];
      Packet image_packet = status_or_packets.value()[kImageOutStreamName];
      result_callback(
          classification_result_packet.Get<ClassificationResult>(),
          ConvertToClassificationResult(
              classifications_packet.Get<ClassificationResult>()),
          image_packet.Get<Image>(),
          classification_result_packet.Timestamp().Value() /
          classifications_packet.Timestamp().Value() /
              kMicroSecondsPerMilliSecond);
    };
  }

@@ -144,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
      std::move(packets_callback));
}

absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
absl::StatusOr<ImageClassifierResult> ImageClassifier::Classify(
    Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {

@@ -160,11 +161,11 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
      ProcessImageData(
          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
           {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  return output_packets[kClassificationResultStreamName]
      .Get<ClassificationResult>();
  return ConvertToClassificationResult(
      output_packets[kClassificationsStreamName].Get<ClassificationResult>());
}

absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
absl::StatusOr<ImageClassifierResult> ImageClassifier::ClassifyForVideo(
    Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {

@@ -184,8 +185,8 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
           {kNormRectName,
            MakePacket<NormalizedRect>(std::move(norm_rect))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  return output_packets[kClassificationResultStreamName]
      .Get<ClassificationResult>();
  return ConvertToClassificationResult(
      output_packets[kClassificationsStreamName].Get<ClassificationResult>());
}

absl::Status ImageClassifier::ClassifyAsync(

mediapipe/tasks/cc/vision/image_classifier/image_classifier.h
@@ -22,7 +22,7 @@ limitations under the License.

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"

@@ -34,6 +34,10 @@ namespace tasks {
namespace vision {
namespace image_classifier {

// Alias the shared ClassificationResult struct as result type.
using ImageClassifierResult =
    ::mediapipe::tasks::components::containers::ClassificationResult;

// The options for configuring a Mediapipe image classifier task.
struct ImageClassifierOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model

@@ -56,9 +60,8 @@ struct ImageClassifierOptions {
  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM.
  std::function<void(
      absl::StatusOr<components::containers::proto::ClassificationResult>,
      const Image&, int64)>
  std::function<void(absl::StatusOr<ImageClassifierResult>, const Image&,
                     int64)>
      result_callback = nullptr;
};

@@ -122,7 +125,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  // The image can be of any size with format RGB or RGBA.
  // TODO: describe exact preprocessing steps once
  // YUVToImageCalculator is integrated.
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
  absl::StatusOr<ImageClassifierResult> Classify(
      mediapipe::Image image,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

@@ -144,10 +147,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide the video frame's timestamp (in milliseconds). The input timestamps
  // must be monotonically increasing.
  absl::StatusOr<components::containers::proto::ClassificationResult>
  ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
                   std::optional<core::ImageProcessingOptions>
                       image_processing_options = std::nullopt);
  absl::StatusOr<ImageClassifierResult> ClassifyForVideo(
      mediapipe::Image image, int64 timestamp_ms,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Sends live image data to image classification, and the results will be
  // available via the "result_callback" provided in the ImageClassifierOptions.

@@ -170,7 +173,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
  // increasing.
  //
  // The "result_callback" provides:
  //   - The classification results as a ClassificationResult object.
  //   - The classification results as an ImageClassifierResult object.
  //   - The const reference to the corresponding input image that the image
  //     classifier runs on. Note that the const reference to the image will no
  //     longer be valid when the callback returns. To access the image data

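Usage sketch (not part of this diff): wiring the new ImageClassifierResult type into the live-stream result callback. The MakeLiveStreamOptions helper and the model_path argument are hypothetical; the option fields and callback signature come from the header and tests above.

#include <cstdint>
#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"

namespace ic = mediapipe::tasks::vision::image_classifier;

std::unique_ptr<ic::ImageClassifierOptions> MakeLiveStreamOptions(
    const std::string& model_path) {
  auto options = std::make_unique<ic::ImageClassifierOptions>();
  options->base_options.model_asset_path = model_path;
  options->running_mode =
      mediapipe::tasks::vision::core::RunningMode::LIVE_STREAM;
  // The callback now receives an ImageClassifierResult (struct) rather than
  // the ClassificationResult proto.
  options->result_callback =
      [](absl::StatusOr<ic::ImageClassifierResult> result,
         const mediapipe::Image& image, int64_t timestamp_ms) {
        if (!result.ok()) return;
        for (const auto& head : result->classifications) {
          (void)head;  // Inspect head.categories here.
        }
      };
  return options;
}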
mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc
@@ -48,6 +48,7 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();

constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kTensorsTag[] = "TENSORS";

@@ -56,6 +57,7 @@ constexpr char kTensorsTag[] = "TENSORS";
// subgraph.
struct ImageClassifierOutputStreams {
  Source<ClassificationResult> classification_result;
  Source<ClassificationResult> classifications;
  Source<Image> image;
};

@@ -71,17 +73,19 @@ struct ImageClassifierOutputStreams {
//     Describes region of image to perform classification on.
//     @Optional: rect covering the whole image is used if not specified.
// Outputs:
//   CLASSIFICATION_RESULT - ClassificationResult
//     The aggregated classification result object has two dimensions:
//     (classification head, classification category)
//   CLASSIFICATIONS - ClassificationResult @Optional
//     The classification results aggregated by classifier head.
//   IMAGE - Image
//     The image that object detection runs on.
//   TODO: remove this output once Java API migration is over.
//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
//     The aggregated classification result.
//
// Example:
// node {
//   calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
//   input_stream: "IMAGE:image_in"
//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
//   output_stream: "CLASSIFICATIONS:classifications_out"
//   output_stream: "IMAGE:image_out"
//   options {
//     [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]

@@ -115,6 +119,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
    output_streams.classification_result >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    output_streams.classifications >>
        graph[Output<ClassificationResult>(kClassificationsTag)];
    output_streams.image >> graph[Output<Image>(kImageTag)];
    return graph.GetConfig();
  }

@@ -170,6 +176,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
    return ImageClassifierOutputStreams{
        /*classification_result=*/postprocessing[Output<ClassificationResult>(
            kClassificationResultTag)],
        /*classifications=*/
        postprocessing[Output<ClassificationResult>(kClassificationsTag)],
        /*image=*/preprocessing[Output<Image>(kImageTag)]};
  }
};

mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc
@@ -32,8 +32,8 @@ limitations under the License.
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.h"
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"

@@ -50,10 +50,9 @@ namespace image_classifier {
namespace {

using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::Category;
using ::mediapipe::tasks::components::containers::Classifications;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::testing::HasSubstr;
using ::testing::Optional;

@@ -65,83 +64,56 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] =
    "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";

// Checks that the two provided `ClassificationResult` are equal, with a
// Checks that the two provided `ImageClassifierResult` are equal, with a
// tolerancy on floating-point score to account for numerical instabilities.
void ExpectApproximatelyEqual(const ClassificationResult& actual,
                              const ClassificationResult& expected) {
void ExpectApproximatelyEqual(const ImageClassifierResult& actual,
                              const ImageClassifierResult& expected) {
  const float kPrecision = 1e-6;
  ASSERT_EQ(actual.classifications_size(), expected.classifications_size());
  for (int i = 0; i < actual.classifications_size(); ++i) {
    const Classifications& a = actual.classifications(i);
    const Classifications& b = expected.classifications(i);
    EXPECT_EQ(a.head_index(), b.head_index());
    EXPECT_EQ(a.head_name(), b.head_name());
    EXPECT_EQ(a.entries_size(), b.entries_size());
    for (int j = 0; j < a.entries_size(); ++j) {
      const ClassificationEntry& x = a.entries(j);
      const ClassificationEntry& y = b.entries(j);
      EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms());
      EXPECT_EQ(x.categories_size(), y.categories_size());
      for (int k = 0; k < x.categories_size(); ++k) {
        EXPECT_EQ(x.categories(k).index(), y.categories(k).index());
        EXPECT_EQ(x.categories(k).category_name(),
                  y.categories(k).category_name());
        EXPECT_EQ(x.categories(k).display_name(),
                  y.categories(k).display_name());
        EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(),
                    kPrecision);
      }
  ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
  for (int i = 0; i < actual.classifications.size(); ++i) {
    const Classifications& a = actual.classifications[i];
    const Classifications& b = expected.classifications[i];
    EXPECT_EQ(a.head_index, b.head_index);
    EXPECT_EQ(a.head_name, b.head_name);
    EXPECT_EQ(a.categories.size(), b.categories.size());
    for (int j = 0; j < a.categories.size(); ++j) {
      const Category& x = a.categories[j];
      const Category& y = b.categories[j];
      EXPECT_EQ(x.index, y.index);
      EXPECT_NEAR(x.score, y.score, kPrecision);
      EXPECT_EQ(x.category_name, y.category_name);
      EXPECT_EQ(x.display_name, y.display_name);
    }
  }
}

// Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
// with max_results set to 3.
ClassificationResult GenerateBurgerResults(int64 timestamp) {
  return ParseTextProtoOrDie<ClassificationResult>(
      absl::StrFormat(R"pb(classifications {
                             entries {
                               categories {
                                 index: 934
                                 score: 0.7939592
                                 category_name: "cheeseburger"
                               }
                               categories {
                                 index: 932
                                 score: 0.027392805
                                 category_name: "bagel"
                               }
                               categories {
                                 index: 925
                                 score: 0.019340655
                                 category_name: "guacamole"
                               }
                               timestamp_ms: %d
                             }
                             head_index: 0
                             head_name: "probability"
                           })pb",
                      timestamp));
ImageClassifierResult GenerateBurgerResults() {
  ImageClassifierResult result;
  result.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/934, /*score=*/0.793959200,
           /*category_name=*/"cheeseburger"},
          {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"},
          {/*index=*/925, /*score=*/0.019340655,
           /*category_name=*/"guacamole"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  return result;
}

// Generates expected results for "multi_objects.jpg" using
// kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
// box set around the soccer ball.
ClassificationResult GenerateSoccerBallResults(int64 timestamp) {
  return ParseTextProtoOrDie<ClassificationResult>(
      absl::StrFormat(R"pb(classifications {
                             entries {
                               categories {
                                 index: 806
                                 score: 0.996527493
                                 category_name: "soccer ball"
                               }
                               timestamp_ms: %d
                             }
                             head_index: 0
                             head_name: "probability"
                           })pb",
                      timestamp));
ImageClassifierResult GenerateSoccerBallResults() {
  ImageClassifierResult result;
  result.classifications.emplace_back(
      Classifications{/*categories=*/{{/*index=*/806, /*score=*/0.996527493,
                                       /*category_name=*/"soccer ball"}},
                      /*head_index=*/0,
                      /*head_name=*/"probability"});
  return result;
}

// A custom OpResolver only containing the Ops required by the test model.

@@ -260,7 +232,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
    options->base_options.model_asset_path =
        JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata);
    options->running_mode = running_mode;
    options->result_callback = [](absl::StatusOr<ClassificationResult>,
    options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
                                  const Image& image, int64 timestamp_ms) {};

    auto image_classifier = ImageClassifier::Create(std::move(options));

@@ -336,7 +308,7 @@ TEST_F(ImageModeTest, SucceedsWithFloatModel) {

  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));

  ExpectApproximatelyEqual(results, GenerateBurgerResults(0));
  ExpectApproximatelyEqual(results, GenerateBurgerResults());
}

TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {

@@ -355,19 +327,13 @@ TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {

  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));

  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
                                        R"pb(classifications {
                                               entries {
                                                 categories {
                                                   index: 934
                                                   score: 0.97265625
                                                   category_name: "cheeseburger"
                                                 }
                                                 timestamp_ms: 0
                                               }
                                               head_index: 0
                                               head_name: "probability"
                                             })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(
      Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.97265625,
                                       /*category_name=*/"cheeseburger"}},
                      /*head_index=*/0,
                      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {

@@ -383,19 +349,13 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {

  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));

  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
                                        R"pb(classifications {
                                               entries {
                                                 categories {
                                                   index: 934
                                                   score: 0.7939592
                                                   category_name: "cheeseburger"
                                                 }
                                                 timestamp_ms: 0
                                               }
                                               head_index: 0
                                               head_name: "probability"
                                             })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(
      Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.7939592,
                                       /*category_name=*/"cheeseburger"}},
                      /*head_index=*/0,
                      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {

@@ -411,24 +371,15 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {

  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));

  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
                                        R"pb(classifications {
                                               entries {
                                                 categories {
                                                   index: 934
                                                   score: 0.7939592
                                                   category_name: "cheeseburger"
                                                 }
                                                 categories {
                                                   index: 932
                                                   score: 0.027392805
                                                   category_name: "bagel"
                                                 }
                                                 timestamp_ms: 0
                                               }
                                               head_index: 0
                                               head_name: "probability"
                                             })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/934, /*score=*/0.7939592,
           /*category_name=*/"cheeseburger"},
          {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {

@@ -445,29 +396,17 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {

  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));

  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
                                        R"pb(classifications {
                                               entries {
                                                 categories {
                                                   index: 934
                                                   score: 0.7939592
                                                   category_name: "cheeseburger"
                                                 }
                                                 categories {
                                                   index: 925
                                                   score: 0.019340655
                                                   category_name: "guacamole"
                                                 }
                                                 categories {
                                                   index: 963
                                                   score: 0.0063278517
                                                   category_name: "meat loaf"
                                                 }
                                                 timestamp_ms: 0
                                               }
                                               head_index: 0
                                               head_name: "probability"
                                             })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/934, /*score=*/0.7939592,
           /*category_name=*/"cheeseburger"},
          {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
          {/*index=*/963, /*score=*/0.0063278517,
           /*category_name=*/"meat loaf"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

TEST_F(ImageModeTest, SucceedsWithDenylistOption) {

@@ -484,29 +423,17 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {

  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));

  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
                                        R"pb(classifications {
                                               entries {
                                                 categories {
                                                   index: 934
                                                   score: 0.7939592
                                                   category_name: "cheeseburger"
                                                 }
                                                 categories {
                                                   index: 925
                                                   score: 0.019340655
                                                   category_name: "guacamole"
                                                 }
                                                 categories {
                                                   index: 963
                                                   score: 0.0063278517
                                                   category_name: "meat loaf"
                                                 }
                                                 timestamp_ms: 0
                                               }
                                               head_index: 0
                                               head_name: "probability"
                                             })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/934, /*score=*/0.7939592,
           /*category_name=*/"cheeseburger"},
          {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
          {/*index=*/963, /*score=*/0.0063278517,
           /*category_name=*/"meat loaf"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {

@@ -525,19 +452,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {

  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));

  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
                                        R"pb(classifications {
                                               entries {
                                                 categories {
                                                   index: 934
                                                   score: 0.725648628
                                                   category_name: "cheeseburger"
                                                 }
                                                 timestamp_ms: 0
                                               }
                                               head_index: 0
                                               head_name: "probability"
                                             })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(
      Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.725648628,
                                       /*category_name=*/"cheeseburger"}},
                      /*head_index=*/0,
                      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {

@@ -557,7 +478,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
                                            image, image_processing_options));

  ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
  ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
}

TEST_F(ImageModeTest, SucceedsWithRotation) {

@@ -581,29 +502,17 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
  // Results differ slightly from the non-rotated image, but that's expected
  // as models are very sensitive to the slightest numerical differences
  // introduced by the rotation and JPG encoding.
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
                                        R"pb(classifications {
                                               entries {
                                                 categories {
                                                   index: 934
                                                   score: 0.6371766
                                                   category_name: "cheeseburger"
                                                 }
                                                 categories {
                                                   index: 963
                                                   score: 0.049443405
                                                   category_name: "meat loaf"
                                                 }
                                                 categories {
                                                   index: 925
                                                   score: 0.047918003
                                                   category_name: "guacamole"
                                                 }
                                                 timestamp_ms: 0
                                               }
                                               head_index: 0
                                               head_name: "probability"
                                             })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(Classifications{
      /*categories=*/{
          {/*index=*/934, /*score=*/0.6371766,
           /*category_name=*/"cheeseburger"},
          {/*index=*/963, /*score=*/0.049443405, /*category_name=*/"meat loaf"},
          {/*index=*/925, /*score=*/0.047918003,
           /*category_name=*/"guacamole"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {

@@ -624,20 +533,13 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
                                            image, image_processing_options));

  ExpectApproximatelyEqual(results,
                           ParseTextProtoOrDie<ClassificationResult>(
                               R"pb(classifications {
                                      entries {
                                        categories {
                                          index: 560
                                          score: 0.6522213
                                          category_name: "folding chair"
                                        }
                                        timestamp_ms: 0
                                      }
                                      head_index: 0
                                      head_name: "probability"
                                    })pb"));
  ImageClassifierResult expected;
  expected.classifications.emplace_back(
      Classifications{/*categories=*/{{/*index=*/560, /*score=*/0.6522213,
                                       /*category_name=*/"folding chair"}},
                      /*head_index=*/0,
                      /*head_name=*/"probability"});
  ExpectApproximatelyEqual(results, expected);
}

// Testing all these once with ImageClassifier.

@@ -774,7 +676,7 @@ TEST_F(VideoModeTest, Succeeds) {
  for (int i = 0; i < iterations; ++i) {
    MP_ASSERT_OK_AND_ASSIGN(auto results,
                            image_classifier->ClassifyForVideo(image, i));
    ExpectApproximatelyEqual(results, GenerateBurgerResults(i));
    ExpectApproximatelyEqual(results, GenerateBurgerResults());
  }
  MP_ASSERT_OK(image_classifier->Close());
}

@@ -800,7 +702,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
    MP_ASSERT_OK_AND_ASSIGN(
        auto results,
        image_classifier->ClassifyForVideo(image, i, image_processing_options));
    ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
    ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
  }
  MP_ASSERT_OK(image_classifier->Close());
}

@@ -815,7 +717,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  options->result_callback = [](absl::StatusOr<ClassificationResult>,
  options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
                                const Image& image, int64 timestamp_ms) {};
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                          ImageClassifier::Create(std::move(options)));

@@ -846,7 +748,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  options->result_callback = [](absl::StatusOr<ClassificationResult>,
  options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
                                const Image& image, int64 timestamp_ms) {};
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
                          ImageClassifier::Create(std::move(options)));

@@ -864,7 +766,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
}

struct LiveStreamModeResults {
  ClassificationResult classification_result;
  ImageClassifierResult classification_result;
  std::pair<int, int> image_size;
  int64 timestamp_ms;
};

@@ -881,7 +783,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
  options->running_mode = core::RunningMode::LIVE_STREAM;
  options->classifier_options.max_results = 3;
  options->result_callback =
      [&results](absl::StatusOr<ClassificationResult> classification_result,
      [&results](absl::StatusOr<ImageClassifierResult> classification_result,
                 const Image& image, int64 timestamp_ms) {
        MP_ASSERT_OK(classification_result.status());
        results.push_back(

@@ -908,7 +810,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
    EXPECT_EQ(result.image_size.first, image.width());
    EXPECT_EQ(result.image_size.second, image.height());
    ExpectApproximatelyEqual(result.classification_result,
                             GenerateBurgerResults(timestamp_ms));
                             GenerateBurgerResults());
  }
}

@@ -924,7 +826,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
  options->running_mode = core::RunningMode::LIVE_STREAM;
  options->classifier_options.max_results = 1;
  options->result_callback =
      [&results](absl::StatusOr<ClassificationResult> classification_result,
      [&results](absl::StatusOr<ImageClassifierResult> classification_result,
                 const Image& image, int64 timestamp_ms) {
        MP_ASSERT_OK(classification_result.status());
        results.push_back(

@@ -955,7 +857,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
    EXPECT_EQ(result.image_size.first, image.width());
    EXPECT_EQ(result.image_size.second, image.height());
    ExpectApproximatelyEqual(result.classification_result,
                             GenerateSoccerBallResults(timestamp_ms));
                             GenerateSoccerBallResults());
  }
}
