Add classification result structs and use then in ImageClassifier & TextClassifier.
PiperOrigin-RevId: 485567133
This commit is contained in:
parent
aab5f84aae
commit
d61ab92b90
|
@ -29,3 +29,23 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "category",
|
||||||
|
srcs = ["category.cc"],
|
||||||
|
hdrs = ["category.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "classification_result",
|
||||||
|
srcs = ["classification_result.cc"],
|
||||||
|
hdrs = ["classification_result.h"],
|
||||||
|
deps = [
|
||||||
|
":category",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
38
mediapipe/tasks/cc/components/containers/category.cc
Normal file
38
mediapipe/tasks/cc/components/containers/category.cc
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
Category ConvertToCategory(const mediapipe::Classification& proto) {
|
||||||
|
Category category;
|
||||||
|
category.index = proto.index();
|
||||||
|
category.score = proto.score();
|
||||||
|
if (proto.has_label()) {
|
||||||
|
category.category_name = proto.label();
|
||||||
|
}
|
||||||
|
if (proto.has_display_name()) {
|
||||||
|
category.display_name = proto.display_name();
|
||||||
|
}
|
||||||
|
return category;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
52
mediapipe/tasks/cc/components/containers/category.h
Normal file
52
mediapipe/tasks/cc/components/containers/category.h
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
// Defines a single classification result.
|
||||||
|
//
|
||||||
|
// The label maps packed into the TFLite Model Metadata [1] are used to populate
|
||||||
|
// the 'category_name' and 'display_name' fields.
|
||||||
|
//
|
||||||
|
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||||
|
struct Category {
|
||||||
|
// The index of the category in the classification model output.
|
||||||
|
int index;
|
||||||
|
// The score for this category, e.g. (but not necessarily) a probability in
|
||||||
|
// [0,1].
|
||||||
|
float score;
|
||||||
|
// The optional ID for the category, read from the label map packed in the
|
||||||
|
// TFLite Model Metadata if present. Not necessarily human-readable.
|
||||||
|
std::optional<std::string> category_name = std::nullopt;
|
||||||
|
// The optional human-readable name for the category, read from the label map
|
||||||
|
// packed in the TFLite Model Metadata if present.
|
||||||
|
std::optional<std::string> display_name = std::nullopt;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Utility function to convert from mediapipe::Classification proto to Category
|
||||||
|
// struct.
|
||||||
|
Category ConvertToCategory(const mediapipe::Classification& proto);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
|
|
@ -0,0 +1,57 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
Classifications ConvertToClassifications(const proto::Classifications& proto) {
|
||||||
|
Classifications classifications;
|
||||||
|
classifications.categories.reserve(
|
||||||
|
proto.classification_list().classification_size());
|
||||||
|
for (const auto& classification :
|
||||||
|
proto.classification_list().classification()) {
|
||||||
|
classifications.categories.push_back(ConvertToCategory(classification));
|
||||||
|
}
|
||||||
|
classifications.head_index = proto.head_index();
|
||||||
|
if (proto.has_head_name()) {
|
||||||
|
classifications.head_name = proto.head_name();
|
||||||
|
}
|
||||||
|
return classifications;
|
||||||
|
}
|
||||||
|
|
||||||
|
ClassificationResult ConvertToClassificationResult(
|
||||||
|
const proto::ClassificationResult& proto) {
|
||||||
|
ClassificationResult classification_result;
|
||||||
|
classification_result.classifications.reserve(proto.classifications_size());
|
||||||
|
for (const auto& classifications : proto.classifications()) {
|
||||||
|
classification_result.classifications.push_back(
|
||||||
|
ConvertToClassifications(classifications));
|
||||||
|
}
|
||||||
|
if (proto.has_timestamp_ms()) {
|
||||||
|
classification_result.timestamp_ms = proto.timestamp_ms();
|
||||||
|
}
|
||||||
|
return classification_result;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
|
@ -0,0 +1,68 @@
|
||||||
|
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::components::containers {
|
||||||
|
|
||||||
|
// Defines classification results for a given classifier head.
|
||||||
|
struct Classifications {
|
||||||
|
// The array of predicted categories, usually sorted by descending scores,
|
||||||
|
// e.g. from high to low probability.
|
||||||
|
std::vector<Category> categories;
|
||||||
|
// The index of the classifier head (i.e. output tensor) these categories
|
||||||
|
// refer to. This is useful for multi-head models.
|
||||||
|
int head_index;
|
||||||
|
// The optional name of the classifier head, as provided in the TFLite Model
|
||||||
|
// Metadata [1] if present. This is useful for multi-head models.
|
||||||
|
//
|
||||||
|
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
||||||
|
std::optional<std::string> head_name = std::nullopt;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Defines classification results of a model.
|
||||||
|
struct ClassificationResult {
|
||||||
|
// The classification results for each head of the model.
|
||||||
|
std::vector<Classifications> classifications;
|
||||||
|
// The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
// corresponding to these results.
|
||||||
|
//
|
||||||
|
// This is only used for classification on time series (e.g. audio
|
||||||
|
// classification). In these use cases, the amount of data to process might
|
||||||
|
// exceed the maximum size that the model can process: to solve this, the
|
||||||
|
// input data is split into multiple chunks starting at different timestamps.
|
||||||
|
std::optional<int64_t> timestamp_ms = std::nullopt;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Utility function to convert from Classifications proto to
|
||||||
|
// Classifications struct.
|
||||||
|
Classifications ConvertToClassifications(const proto::Classifications& proto);
|
||||||
|
|
||||||
|
// Utility function to convert from ClassificationResult proto to
|
||||||
|
// ClassificationResult struct.
|
||||||
|
ClassificationResult ConvertToClassificationResult(
|
||||||
|
const proto::ClassificationResult& proto);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::components::containers
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
|
|
@ -49,6 +49,8 @@ cc_library(
|
||||||
":text_classifier_graph",
|
":text_classifier_graph",
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:category",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||||
|
@ -76,7 +78,8 @@ cc_test(
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/deps:file_path",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers:category",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||||
"@com_google_absl//absl/flags:flag",
|
"@com_google_absl//absl/flags:flag",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/framework/api2/builder.h"
|
#include "mediapipe/framework/api2/builder.h"
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
#include "mediapipe/tasks/cc/core/task_api_factory.h"
|
||||||
|
@ -37,12 +38,13 @@ namespace text_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
|
||||||
constexpr char kTextStreamName[] = "text_in";
|
constexpr char kTextStreamName[] = "text_in";
|
||||||
constexpr char kTextTag[] = "TEXT";
|
constexpr char kTextTag[] = "TEXT";
|
||||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
|
constexpr char kClassificationsStreamName[] = "classifications_out";
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kSubgraphTypeName[] =
|
constexpr char kSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
|
||||||
|
|
||||||
|
@ -54,9 +56,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
|
subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
|
||||||
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
|
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
|
||||||
subgraph.Out(kClassificationResultTag)
|
subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
|
||||||
.SetName(kClassificationResultStreamName) >>
|
graph.Out(kClassificationsTag);
|
||||||
graph.Out(kClassificationResultTag);
|
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,14 +89,14 @@ absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
|
||||||
std::move(options->base_options.op_resolver));
|
std::move(options->base_options.op_resolver));
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<ClassificationResult> TextClassifier::Classify(
|
absl::StatusOr<TextClassifierResult> TextClassifier::Classify(
|
||||||
absl::string_view text) {
|
absl::string_view text) {
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_packets,
|
auto output_packets,
|
||||||
runner_->Process(
|
runner_->Process(
|
||||||
{{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
|
{{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
|
||||||
return output_packets[kClassificationResultStreamName]
|
return ConvertToClassificationResult(
|
||||||
.Get<ClassificationResult>();
|
output_packets[kClassificationsStreamName].Get<ClassificationResult>());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace text_classifier
|
} // namespace text_classifier
|
||||||
|
|
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
#include "mediapipe/tasks/cc/core/base_task_api.h"
|
||||||
|
@ -31,6 +31,10 @@ namespace tasks {
|
||||||
namespace text {
|
namespace text {
|
||||||
namespace text_classifier {
|
namespace text_classifier {
|
||||||
|
|
||||||
|
// Alias the shared ClassificationResult struct as result type.
|
||||||
|
using TextClassifierResult =
|
||||||
|
::mediapipe::tasks::components::containers::ClassificationResult;
|
||||||
|
|
||||||
// The options for configuring a MediaPipe text classifier task.
|
// The options for configuring a MediaPipe text classifier task.
|
||||||
struct TextClassifierOptions {
|
struct TextClassifierOptions {
|
||||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
|
@ -81,8 +85,7 @@ class TextClassifier : core::BaseTaskApi {
|
||||||
std::unique_ptr<TextClassifierOptions> options);
|
std::unique_ptr<TextClassifierOptions> options);
|
||||||
|
|
||||||
// Performs classification on the input `text`.
|
// Performs classification on the input `text`.
|
||||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
absl::StatusOr<TextClassifierResult> Classify(absl::string_view text);
|
||||||
absl::string_view text);
|
|
||||||
|
|
||||||
// Shuts down the TextClassifier when all the work is done.
|
// Shuts down the TextClassifier when all the work is done.
|
||||||
absl::Status Close() { return runner_->Close(); }
|
absl::Status Close() { return runner_->Close(); }
|
||||||
|
|
|
@ -47,10 +47,18 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
|
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
||||||
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kTextTag[] = "TEXT";
|
constexpr char kTextTag[] = "TEXT";
|
||||||
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
|
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
|
||||||
|
// TODO: remove once Java API migration is over.
|
||||||
|
// Struct holding the different output streams produced by the text classifier.
|
||||||
|
struct TextClassifierOutputStreams {
|
||||||
|
Source<ClassificationResult> classification_result;
|
||||||
|
Source<ClassificationResult> classifications;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// A "TextClassifierGraph" performs Natural Language classification (including
|
// A "TextClassifierGraph" performs Natural Language classification (including
|
||||||
|
@ -62,7 +70,10 @@ constexpr char kTensorsTag[] = "TENSORS";
|
||||||
// Input text to perform classification on.
|
// Input text to perform classification on.
|
||||||
//
|
//
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// CLASSIFICATION_RESULT - ClassificationResult
|
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||||
|
// The classification results aggregated by classifier head.
|
||||||
|
// TODO: remove once Java API migration is over.
|
||||||
|
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||||
// The aggregated classification result object that has 3 dimensions:
|
// The aggregated classification result object that has 3 dimensions:
|
||||||
// (classification head, classification timestamp, classification category).
|
// (classification head, classification timestamp, classification category).
|
||||||
//
|
//
|
||||||
|
@ -70,7 +81,7 @@ constexpr char kTensorsTag[] = "TENSORS";
|
||||||
// node {
|
// node {
|
||||||
// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
|
// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
|
||||||
// input_stream: "TEXT:text_in"
|
// input_stream: "TEXT:text_in"
|
||||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
// output_stream: "CLASSIFICATIONS:classifications_out"
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
|
// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
|
||||||
// {
|
// {
|
||||||
|
@ -91,12 +102,14 @@ class TextClassifierGraph : public core::ModelTaskGraph {
|
||||||
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
|
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
Source<ClassificationResult> classification_result_out,
|
auto output_streams,
|
||||||
BuildTextClassifierTask(
|
BuildTextClassifierTask(
|
||||||
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
|
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
|
||||||
graph[Input<std::string>(kTextTag)], graph));
|
graph[Input<std::string>(kTextTag)], graph));
|
||||||
classification_result_out >>
|
output_streams.classification_result >>
|
||||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
||||||
|
output_streams.classifications >>
|
||||||
|
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +124,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
|
||||||
// TextClassifier model file with model metadata.
|
// TextClassifier model file with model metadata.
|
||||||
// text_in: (std::string) stream to run text classification on.
|
// text_in: (std::string) stream to run text classification on.
|
||||||
// graph: the mediapipe builder::Graph instance to be updated.
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
|
absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
|
||||||
const proto::TextClassifierGraphOptions& task_options,
|
const proto::TextClassifierGraphOptions& task_options,
|
||||||
const ModelResources& model_resources, Source<std::string> text_in,
|
const ModelResources& model_resources, Source<std::string> text_in,
|
||||||
Graph& graph) {
|
Graph& graph) {
|
||||||
|
@ -148,8 +161,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {
|
||||||
|
|
||||||
// Outputs the aggregated classification result as the subgraph output
|
// Outputs the aggregated classification result as the subgraph output
|
||||||
// stream.
|
// stream.
|
||||||
return postprocessing[Output<ClassificationResult>(
|
return TextClassifierOutputStreams{
|
||||||
kClassificationResultTag)];
|
/*classification_result=*/postprocessing[Output<ClassificationResult>(
|
||||||
|
kClassificationResultTag)],
|
||||||
|
/*classifications=*/postprocessing[Output<ClassificationResult>(
|
||||||
|
kClassificationsTag)]};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,8 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
|
@ -43,14 +44,13 @@ namespace text {
|
||||||
namespace text_classifier {
|
namespace text_classifier {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::EqualsProto;
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::kMediaPipeTasksPayload;
|
using ::mediapipe::tasks::kMediaPipeTasksPayload;
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::Category;
|
||||||
|
using ::mediapipe::tasks::components::containers::Classifications;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
|
||||||
constexpr float kEpsilon = 0.001;
|
|
||||||
constexpr int kMaxSeqLen = 128;
|
constexpr int kMaxSeqLen = 128;
|
||||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
|
||||||
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
|
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
|
||||||
|
@ -64,6 +64,30 @@ std::string GetFullPath(absl::string_view file_name) {
|
||||||
return JoinPath("./", kTestDataDirectory, file_name);
|
return JoinPath("./", kTestDataDirectory, file_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that the two provided `TextClassifierResult` are equal, with a
|
||||||
|
// tolerancy on floating-point score to account for numerical instabilities.
|
||||||
|
// TODO: create shared matcher for ClassificationResult.
|
||||||
|
void ExpectApproximatelyEqual(const TextClassifierResult& actual,
|
||||||
|
const TextClassifierResult& expected) {
|
||||||
|
const float kPrecision = 1e-6;
|
||||||
|
ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
|
||||||
|
for (int i = 0; i < actual.classifications.size(); ++i) {
|
||||||
|
const Classifications& a = actual.classifications[i];
|
||||||
|
const Classifications& b = expected.classifications[i];
|
||||||
|
EXPECT_EQ(a.head_index, b.head_index);
|
||||||
|
EXPECT_EQ(a.head_name, b.head_name);
|
||||||
|
EXPECT_EQ(a.categories.size(), b.categories.size());
|
||||||
|
for (int j = 0; j < a.categories.size(); ++j) {
|
||||||
|
const Category& x = a.categories[j];
|
||||||
|
const Category& y = b.categories[j];
|
||||||
|
EXPECT_EQ(x.index, y.index);
|
||||||
|
EXPECT_NEAR(x.score, y.score, kPrecision);
|
||||||
|
EXPECT_EQ(x.category_name, y.category_name);
|
||||||
|
EXPECT_EQ(x.display_name, y.display_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class TextClassifierTest : public tflite_shims::testing::Test {};
|
class TextClassifierTest : public tflite_shims::testing::Test {};
|
||||||
|
|
||||||
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
|
||||||
|
@ -107,6 +131,92 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
|
||||||
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
|
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TextClassifierTest, TextClassifierWithBert) {
|
||||||
|
auto options = std::make_unique<TextClassifierOptions>();
|
||||||
|
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||||
|
TextClassifier::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextClassifierResult negative_result,
|
||||||
|
classifier->Classify("unflinchingly bleak and desperate"));
|
||||||
|
TextClassifierResult negative_expected;
|
||||||
|
negative_expected.classifications.emplace_back(Classifications{
|
||||||
|
/*categories=*/{
|
||||||
|
{/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
|
||||||
|
{/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
|
||||||
|
/*head_index=*/0,
|
||||||
|
/*head_name=*/"probability"});
|
||||||
|
ExpectApproximatelyEqual(negative_result, negative_expected);
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextClassifierResult positive_result,
|
||||||
|
classifier->Classify("it's a charming and often affecting journey"));
|
||||||
|
TextClassifierResult positive_expected;
|
||||||
|
positive_expected.classifications.emplace_back(Classifications{
|
||||||
|
/*categories=*/{
|
||||||
|
{/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
|
||||||
|
{/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
|
||||||
|
/*head_index=*/0,
|
||||||
|
/*head_name=*/"probability"});
|
||||||
|
ExpectApproximatelyEqual(positive_result, positive_expected);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(classifier->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
|
||||||
|
auto options = std::make_unique<TextClassifierOptions>();
|
||||||
|
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||||
|
TextClassifier::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult negative_result,
|
||||||
|
classifier->Classify("What a waste of my time."));
|
||||||
|
TextClassifierResult negative_expected;
|
||||||
|
negative_expected.classifications.emplace_back(Classifications{
|
||||||
|
/*categories=*/{
|
||||||
|
{/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"},
|
||||||
|
{/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}},
|
||||||
|
/*head_index=*/0,
|
||||||
|
/*head_name=*/"probability"});
|
||||||
|
ExpectApproximatelyEqual(negative_result, negative_expected);
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
TextClassifierResult positive_result,
|
||||||
|
classifier->Classify("This is the best movie I’ve seen in recent years."
|
||||||
|
"Strongly recommend it!"));
|
||||||
|
TextClassifierResult positive_expected;
|
||||||
|
positive_expected.classifications.emplace_back(Classifications{
|
||||||
|
/*categories=*/{
|
||||||
|
{/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"},
|
||||||
|
{/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}},
|
||||||
|
/*head_index=*/0,
|
||||||
|
/*head_name=*/"probability"});
|
||||||
|
ExpectApproximatelyEqual(positive_result, positive_expected);
|
||||||
|
|
||||||
|
MP_ASSERT_OK(classifier->Close());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
|
||||||
|
auto options = std::make_unique<TextClassifierOptions>();
|
||||||
|
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
|
||||||
|
options->base_options.op_resolver = CreateCustomResolver();
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||||
|
TextClassifier::Create(std::move(options)));
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
|
||||||
|
classifier->Classify("hello"));
|
||||||
|
|
||||||
|
// Binary outputs causes flaky ordering, so we compare manually.
|
||||||
|
ASSERT_EQ(result.classifications.size(), 1);
|
||||||
|
ASSERT_EQ(result.classifications[0].head_index, 0);
|
||||||
|
ASSERT_EQ(result.classifications[0].categories.size(), 3);
|
||||||
|
ASSERT_EQ(result.classifications[0].categories[0].score, 1);
|
||||||
|
ASSERT_LT(result.classifications[0].categories[0].index, 2); // i.e O or 1.
|
||||||
|
ASSERT_EQ(result.classifications[0].categories[1].score, 1);
|
||||||
|
ASSERT_LT(result.classifications[0].categories[1].index, 2); // i.e 0 or 1.
|
||||||
|
ASSERT_EQ(result.classifications[0].categories[2].score, 0);
|
||||||
|
ASSERT_EQ(result.classifications[0].categories[2].index, 2);
|
||||||
|
MP_ASSERT_OK(classifier->Close());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace text_classifier
|
} // namespace text_classifier
|
||||||
} // namespace text
|
} // namespace text
|
||||||
|
|
|
@ -50,6 +50,7 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||||
|
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
|
@ -46,8 +47,8 @@ namespace image_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
|
constexpr char kClassificationsStreamName[] = "classifications_out";
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kImageInStreamName[] = "image_in";
|
constexpr char kImageInStreamName[] = "image_in";
|
||||||
constexpr char kImageOutStreamName[] = "image_out";
|
constexpr char kImageOutStreamName[] = "image_out";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
@ -57,6 +58,7 @@ constexpr char kSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::core::PacketMap;
|
using ::mediapipe::tasks::core::PacketMap;
|
||||||
|
|
||||||
|
@ -73,15 +75,13 @@ CalculatorGraphConfig CreateGraphConfig(
|
||||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||||
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
|
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
|
||||||
options_proto.get());
|
options_proto.get());
|
||||||
task_subgraph.Out(kClassificationResultTag)
|
task_subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
|
||||||
.SetName(kClassificationResultStreamName) >>
|
graph.Out(kClassificationsTag);
|
||||||
graph.Out(kClassificationResultTag);
|
|
||||||
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
|
||||||
graph.Out(kImageTag);
|
graph.Out(kImageTag);
|
||||||
if (enable_flow_limiting) {
|
if (enable_flow_limiting) {
|
||||||
return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
|
return tasks::core::AddFlowLimiterCalculator(
|
||||||
{kImageTag, kNormRectTag},
|
graph, task_subgraph, {kImageTag, kNormRectTag}, kClassificationsTag);
|
||||||
kClassificationResultTag);
|
|
||||||
}
|
}
|
||||||
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
graph.In(kImageTag) >> task_subgraph.In(kImageTag);
|
||||||
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
|
graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
|
||||||
|
@ -125,13 +125,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
|
||||||
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
|
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Packet classification_result_packet =
|
Packet classifications_packet =
|
||||||
status_or_packets.value()[kClassificationResultStreamName];
|
status_or_packets.value()[kClassificationsStreamName];
|
||||||
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
|
||||||
result_callback(
|
result_callback(
|
||||||
classification_result_packet.Get<ClassificationResult>(),
|
ConvertToClassificationResult(
|
||||||
|
classifications_packet.Get<ClassificationResult>()),
|
||||||
image_packet.Get<Image>(),
|
image_packet.Get<Image>(),
|
||||||
classification_result_packet.Timestamp().Value() /
|
classifications_packet.Timestamp().Value() /
|
||||||
kMicroSecondsPerMilliSecond);
|
kMicroSecondsPerMilliSecond);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -144,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
|
||||||
std::move(packets_callback));
|
std::move(packets_callback));
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
|
absl::StatusOr<ImageClassifierResult> ImageClassifier::Classify(
|
||||||
Image image,
|
Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -160,11 +161,11 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
|
||||||
ProcessImageData(
|
ProcessImageData(
|
||||||
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
{{kImageInStreamName, MakePacket<Image>(std::move(image))},
|
||||||
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
{kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
|
||||||
return output_packets[kClassificationResultStreamName]
|
return ConvertToClassificationResult(
|
||||||
.Get<ClassificationResult>();
|
output_packets[kClassificationsStreamName].Get<ClassificationResult>());
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
|
absl::StatusOr<ImageClassifierResult> ImageClassifier::ClassifyForVideo(
|
||||||
Image image, int64 timestamp_ms,
|
Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
std::optional<core::ImageProcessingOptions> image_processing_options) {
|
||||||
if (image.UsesGpu()) {
|
if (image.UsesGpu()) {
|
||||||
|
@ -184,8 +185,8 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
|
||||||
{kNormRectName,
|
{kNormRectName,
|
||||||
MakePacket<NormalizedRect>(std::move(norm_rect))
|
MakePacket<NormalizedRect>(std::move(norm_rect))
|
||||||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
|
||||||
return output_packets[kClassificationResultStreamName]
|
return ConvertToClassificationResult(
|
||||||
.Get<ClassificationResult>();
|
output_packets[kClassificationsStreamName].Get<ClassificationResult>());
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ImageClassifier::ClassifyAsync(
|
absl::Status ImageClassifier::ClassifyAsync(
|
||||||
|
|
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
|
@ -34,6 +34,10 @@ namespace tasks {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace image_classifier {
|
namespace image_classifier {
|
||||||
|
|
||||||
|
// Alias the shared ClassificationResult struct as result type.
|
||||||
|
using ImageClassifierResult =
|
||||||
|
::mediapipe::tasks::components::containers::ClassificationResult;
|
||||||
|
|
||||||
// The options for configuring a Mediapipe image classifier task.
|
// The options for configuring a Mediapipe image classifier task.
|
||||||
struct ImageClassifierOptions {
|
struct ImageClassifierOptions {
|
||||||
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
// Base options for configuring MediaPipe Tasks, such as specifying the model
|
||||||
|
@ -56,9 +60,8 @@ struct ImageClassifierOptions {
|
||||||
// The user-defined result callback for processing live stream data.
|
// The user-defined result callback for processing live stream data.
|
||||||
// The result callback should only be specified when the running mode is set
|
// The result callback should only be specified when the running mode is set
|
||||||
// to RunningMode::LIVE_STREAM.
|
// to RunningMode::LIVE_STREAM.
|
||||||
std::function<void(
|
std::function<void(absl::StatusOr<ImageClassifierResult>, const Image&,
|
||||||
absl::StatusOr<components::containers::proto::ClassificationResult>,
|
int64)>
|
||||||
const Image&, int64)>
|
|
||||||
result_callback = nullptr;
|
result_callback = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -122,7 +125,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA.
|
// The image can be of any size with format RGB or RGBA.
|
||||||
// TODO: describe exact preprocessing steps once
|
// TODO: describe exact preprocessing steps once
|
||||||
// YUVToImageCalculator is integrated.
|
// YUVToImageCalculator is integrated.
|
||||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
absl::StatusOr<ImageClassifierResult> Classify(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
std::nullopt);
|
std::nullopt);
|
||||||
|
@ -144,10 +147,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA. It's required to
|
// The image can be of any size with format RGB or RGBA. It's required to
|
||||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
// must be monotonically increasing.
|
// must be monotonically increasing.
|
||||||
absl::StatusOr<components::containers::proto::ClassificationResult>
|
absl::StatusOr<ImageClassifierResult> ClassifyForVideo(
|
||||||
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
|
mediapipe::Image image, int64 timestamp_ms,
|
||||||
std::optional<core::ImageProcessingOptions>
|
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||||
image_processing_options = std::nullopt);
|
std::nullopt);
|
||||||
|
|
||||||
// Sends live image data to image classification, and the results will be
|
// Sends live image data to image classification, and the results will be
|
||||||
// available via the "result_callback" provided in the ImageClassifierOptions.
|
// available via the "result_callback" provided in the ImageClassifierOptions.
|
||||||
|
@ -170,7 +173,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// increasing.
|
// increasing.
|
||||||
//
|
//
|
||||||
// The "result_callback" provides:
|
// The "result_callback" provides:
|
||||||
// - The classification results as a ClassificationResult object.
|
// - The classification results as an ImageClassifierResult object.
|
||||||
// - The const reference to the corresponding input image that the image
|
// - The const reference to the corresponding input image that the image
|
||||||
// classifier runs on. Note that the const reference to the image will no
|
// classifier runs on. Note that the const reference to the image will no
|
||||||
// longer be valid when the callback returns. To access the image data
|
// longer be valid when the callback returns. To access the image data
|
||||||
|
|
|
@ -48,6 +48,7 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||||
|
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
||||||
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
@ -56,6 +57,7 @@ constexpr char kTensorsTag[] = "TENSORS";
|
||||||
// subgraph.
|
// subgraph.
|
||||||
struct ImageClassifierOutputStreams {
|
struct ImageClassifierOutputStreams {
|
||||||
Source<ClassificationResult> classification_result;
|
Source<ClassificationResult> classification_result;
|
||||||
|
Source<ClassificationResult> classifications;
|
||||||
Source<Image> image;
|
Source<Image> image;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -71,17 +73,19 @@ struct ImageClassifierOutputStreams {
|
||||||
// Describes region of image to perform classification on.
|
// Describes region of image to perform classification on.
|
||||||
// @Optional: rect covering the whole image is used if not specified.
|
// @Optional: rect covering the whole image is used if not specified.
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// CLASSIFICATION_RESULT - ClassificationResult
|
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||||
// The aggregated classification result object has two dimensions:
|
// The classification results aggregated by classifier head.
|
||||||
// (classification head, classification category)
|
|
||||||
// IMAGE - Image
|
// IMAGE - Image
|
||||||
// The image that object detection runs on.
|
// The image that object detection runs on.
|
||||||
|
// TODO: remove this output once Java API migration is over.
|
||||||
|
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
||||||
|
// The aggregated classification result.
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
|
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
|
||||||
// input_stream: "IMAGE:image_in"
|
// input_stream: "IMAGE:image_in"
|
||||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
// output_stream: "CLASSIFICATIONS:classifications_out"
|
||||||
// output_stream: "IMAGE:image_out"
|
// output_stream: "IMAGE:image_out"
|
||||||
// options {
|
// options {
|
||||||
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
|
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
|
||||||
|
@ -115,6 +119,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||||
output_streams.classification_result >>
|
output_streams.classification_result >>
|
||||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
||||||
|
output_streams.classifications >>
|
||||||
|
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||||
output_streams.image >> graph[Output<Image>(kImageTag)];
|
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
@ -170,6 +176,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
return ImageClassifierOutputStreams{
|
return ImageClassifierOutputStreams{
|
||||||
/*classification_result=*/postprocessing[Output<ClassificationResult>(
|
/*classification_result=*/postprocessing[Output<ClassificationResult>(
|
||||||
kClassificationResultTag)],
|
kClassificationResultTag)],
|
||||||
|
/*classifications=*/
|
||||||
|
postprocessing[Output<ClassificationResult>(kClassificationsTag)],
|
||||||
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -32,8 +32,8 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/category.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
#include "mediapipe/tasks/cc/components/containers/rect.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
@ -50,10 +50,9 @@ namespace image_classifier {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::containers::Category;
|
||||||
|
using ::mediapipe::tasks::components::containers::Classifications;
|
||||||
using ::mediapipe::tasks::components::containers::Rect;
|
using ::mediapipe::tasks::components::containers::Rect;
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
|
||||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
@ -65,83 +64,56 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
|
||||||
constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] =
|
constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] =
|
||||||
"mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
|
"mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
|
||||||
|
|
||||||
// Checks that the two provided `ClassificationResult` are equal, with a
|
// Checks that the two provided `ImageClassifierResult` are equal, with a
|
||||||
// tolerancy on floating-point score to account for numerical instabilities.
|
// tolerancy on floating-point score to account for numerical instabilities.
|
||||||
void ExpectApproximatelyEqual(const ClassificationResult& actual,
|
void ExpectApproximatelyEqual(const ImageClassifierResult& actual,
|
||||||
const ClassificationResult& expected) {
|
const ImageClassifierResult& expected) {
|
||||||
const float kPrecision = 1e-6;
|
const float kPrecision = 1e-6;
|
||||||
ASSERT_EQ(actual.classifications_size(), expected.classifications_size());
|
ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
|
||||||
for (int i = 0; i < actual.classifications_size(); ++i) {
|
for (int i = 0; i < actual.classifications.size(); ++i) {
|
||||||
const Classifications& a = actual.classifications(i);
|
const Classifications& a = actual.classifications[i];
|
||||||
const Classifications& b = expected.classifications(i);
|
const Classifications& b = expected.classifications[i];
|
||||||
EXPECT_EQ(a.head_index(), b.head_index());
|
EXPECT_EQ(a.head_index, b.head_index);
|
||||||
EXPECT_EQ(a.head_name(), b.head_name());
|
EXPECT_EQ(a.head_name, b.head_name);
|
||||||
EXPECT_EQ(a.entries_size(), b.entries_size());
|
EXPECT_EQ(a.categories.size(), b.categories.size());
|
||||||
for (int j = 0; j < a.entries_size(); ++j) {
|
for (int j = 0; j < a.categories.size(); ++j) {
|
||||||
const ClassificationEntry& x = a.entries(j);
|
const Category& x = a.categories[j];
|
||||||
const ClassificationEntry& y = b.entries(j);
|
const Category& y = b.categories[j];
|
||||||
EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms());
|
EXPECT_EQ(x.index, y.index);
|
||||||
EXPECT_EQ(x.categories_size(), y.categories_size());
|
EXPECT_NEAR(x.score, y.score, kPrecision);
|
||||||
for (int k = 0; k < x.categories_size(); ++k) {
|
EXPECT_EQ(x.category_name, y.category_name);
|
||||||
EXPECT_EQ(x.categories(k).index(), y.categories(k).index());
|
EXPECT_EQ(x.display_name, y.display_name);
|
||||||
EXPECT_EQ(x.categories(k).category_name(),
|
|
||||||
y.categories(k).category_name());
|
|
||||||
EXPECT_EQ(x.categories(k).display_name(),
|
|
||||||
y.categories(k).display_name());
|
|
||||||
EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(),
|
|
||||||
kPrecision);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
|
// Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
|
||||||
// with max_results set to 3.
|
// with max_results set to 3.
|
||||||
ClassificationResult GenerateBurgerResults(int64 timestamp) {
|
ImageClassifierResult GenerateBurgerResults() {
|
||||||
return ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult result;
|
||||||
absl::StrFormat(R"pb(classifications {
|
result.classifications.emplace_back(Classifications{
|
||||||
entries {
|
/*categories=*/{
|
||||||
categories {
|
{/*index=*/934, /*score=*/0.793959200,
|
||||||
index: 934
|
/*category_name=*/"cheeseburger"},
|
||||||
score: 0.7939592
|
{/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"},
|
||||||
category_name: "cheeseburger"
|
{/*index=*/925, /*score=*/0.019340655,
|
||||||
}
|
/*category_name=*/"guacamole"}},
|
||||||
categories {
|
/*head_index=*/0,
|
||||||
index: 932
|
/*head_name=*/"probability"});
|
||||||
score: 0.027392805
|
return result;
|
||||||
category_name: "bagel"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 925
|
|
||||||
score: 0.019340655
|
|
||||||
category_name: "guacamole"
|
|
||||||
}
|
|
||||||
timestamp_ms: %d
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb",
|
|
||||||
timestamp));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates expected results for "multi_objects.jpg" using
|
// Generates expected results for "multi_objects.jpg" using
|
||||||
// kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
|
// kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
|
||||||
// box set around the soccer ball.
|
// box set around the soccer ball.
|
||||||
ClassificationResult GenerateSoccerBallResults(int64 timestamp) {
|
ImageClassifierResult GenerateSoccerBallResults() {
|
||||||
return ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult result;
|
||||||
absl::StrFormat(R"pb(classifications {
|
result.classifications.emplace_back(
|
||||||
entries {
|
Classifications{/*categories=*/{{/*index=*/806, /*score=*/0.996527493,
|
||||||
categories {
|
/*category_name=*/"soccer ball"}},
|
||||||
index: 806
|
/*head_index=*/0,
|
||||||
score: 0.996527493
|
/*head_name=*/"probability"});
|
||||||
category_name: "soccer ball"
|
return result;
|
||||||
}
|
|
||||||
timestamp_ms: %d
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb",
|
|
||||||
timestamp));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A custom OpResolver only containing the Ops required by the test model.
|
// A custom OpResolver only containing the Ops required by the test model.
|
||||||
|
@ -260,7 +232,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata);
|
||||||
options->running_mode = running_mode;
|
options->running_mode = running_mode;
|
||||||
options->result_callback = [](absl::StatusOr<ClassificationResult>,
|
options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
|
||||||
const Image& image, int64 timestamp_ms) {};
|
const Image& image, int64 timestamp_ms) {};
|
||||||
|
|
||||||
auto image_classifier = ImageClassifier::Create(std::move(options));
|
auto image_classifier = ImageClassifier::Create(std::move(options));
|
||||||
|
@ -336,7 +308,7 @@ TEST_F(ImageModeTest, SucceedsWithFloatModel) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, GenerateBurgerResults(0));
|
ExpectApproximatelyEqual(results, GenerateBurgerResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
|
TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
|
||||||
|
@ -355,19 +327,13 @@ TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult expected;
|
||||||
R"pb(classifications {
|
expected.classifications.emplace_back(
|
||||||
entries {
|
Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.97265625,
|
||||||
categories {
|
/*category_name=*/"cheeseburger"}},
|
||||||
index: 934
|
/*head_index=*/0,
|
||||||
score: 0.97265625
|
/*head_name=*/"probability"});
|
||||||
category_name: "cheeseburger"
|
ExpectApproximatelyEqual(results, expected);
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
||||||
|
@ -383,19 +349,13 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult expected;
|
||||||
R"pb(classifications {
|
expected.classifications.emplace_back(
|
||||||
entries {
|
Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.7939592,
|
||||||
categories {
|
/*category_name=*/"cheeseburger"}},
|
||||||
index: 934
|
/*head_index=*/0,
|
||||||
score: 0.7939592
|
/*head_name=*/"probability"});
|
||||||
category_name: "cheeseburger"
|
ExpectApproximatelyEqual(results, expected);
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
||||||
|
@ -411,24 +371,15 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult expected;
|
||||||
R"pb(classifications {
|
expected.classifications.emplace_back(Classifications{
|
||||||
entries {
|
/*categories=*/{
|
||||||
categories {
|
{/*index=*/934, /*score=*/0.7939592,
|
||||||
index: 934
|
/*category_name=*/"cheeseburger"},
|
||||||
score: 0.7939592
|
{/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}},
|
||||||
category_name: "cheeseburger"
|
/*head_index=*/0,
|
||||||
}
|
/*head_name=*/"probability"});
|
||||||
categories {
|
ExpectApproximatelyEqual(results, expected);
|
||||||
index: 932
|
|
||||||
score: 0.027392805
|
|
||||||
category_name: "bagel"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
||||||
|
@ -445,29 +396,17 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult expected;
|
||||||
R"pb(classifications {
|
expected.classifications.emplace_back(Classifications{
|
||||||
entries {
|
/*categories=*/{
|
||||||
categories {
|
{/*index=*/934, /*score=*/0.7939592,
|
||||||
index: 934
|
/*category_name=*/"cheeseburger"},
|
||||||
score: 0.7939592
|
{/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
|
||||||
category_name: "cheeseburger"
|
{/*index=*/963, /*score=*/0.0063278517,
|
||||||
}
|
/*category_name=*/"meat loaf"}},
|
||||||
categories {
|
/*head_index=*/0,
|
||||||
index: 925
|
/*head_name=*/"probability"});
|
||||||
score: 0.019340655
|
ExpectApproximatelyEqual(results, expected);
|
||||||
category_name: "guacamole"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 963
|
|
||||||
score: 0.0063278517
|
|
||||||
category_name: "meat loaf"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
||||||
|
@ -484,29 +423,17 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult expected;
|
||||||
R"pb(classifications {
|
expected.classifications.emplace_back(Classifications{
|
||||||
entries {
|
/*categories=*/{
|
||||||
categories {
|
{/*index=*/934, /*score=*/0.7939592,
|
||||||
index: 934
|
/*category_name=*/"cheeseburger"},
|
||||||
score: 0.7939592
|
{/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
|
||||||
category_name: "cheeseburger"
|
{/*index=*/963, /*score=*/0.0063278517,
|
||||||
}
|
/*category_name=*/"meat loaf"}},
|
||||||
categories {
|
/*head_index=*/0,
|
||||||
index: 925
|
/*head_name=*/"probability"});
|
||||||
score: 0.019340655
|
ExpectApproximatelyEqual(results, expected);
|
||||||
category_name: "guacamole"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 963
|
|
||||||
score: 0.0063278517
|
|
||||||
category_name: "meat loaf"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
||||||
|
@ -525,19 +452,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult expected;
|
||||||
R"pb(classifications {
|
expected.classifications.emplace_back(
|
||||||
entries {
|
Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.725648628,
|
||||||
categories {
|
/*category_name=*/"cheeseburger"}},
|
||||||
index: 934
|
/*head_index=*/0,
|
||||||
score: 0.725648628
|
/*head_name=*/"probability"});
|
||||||
category_name: "cheeseburger"
|
ExpectApproximatelyEqual(results, expected);
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
||||||
|
@ -557,7 +478,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||||
image, image_processing_options));
|
image, image_processing_options));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
|
ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||||
|
@ -581,29 +502,17 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||||
// Results differ slightly from the non-rotated image, but that's expected
|
// Results differ slightly from the non-rotated image, but that's expected
|
||||||
// as models are very sensitive to the slightest numerical differences
|
// as models are very sensitive to the slightest numerical differences
|
||||||
// introduced by the rotation and JPG encoding.
|
// introduced by the rotation and JPG encoding.
|
||||||
ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
|
ImageClassifierResult expected;
|
||||||
R"pb(classifications {
|
expected.classifications.emplace_back(Classifications{
|
||||||
entries {
|
/*categories=*/{
|
||||||
categories {
|
{/*index=*/934, /*score=*/0.6371766,
|
||||||
index: 934
|
/*category_name=*/"cheeseburger"},
|
||||||
score: 0.6371766
|
{/*index=*/963, /*score=*/0.049443405, /*category_name=*/"meat loaf"},
|
||||||
category_name: "cheeseburger"
|
{/*index=*/925, /*score=*/0.047918003,
|
||||||
}
|
/*category_name=*/"guacamole"}},
|
||||||
categories {
|
/*head_index=*/0,
|
||||||
index: 963
|
/*head_name=*/"probability"});
|
||||||
score: 0.049443405
|
ExpectApproximatelyEqual(results, expected);
|
||||||
category_name: "meat loaf"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 925
|
|
||||||
score: 0.047918003
|
|
||||||
category_name: "guacamole"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
|
@ -624,20 +533,13 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
|
||||||
image, image_processing_options));
|
image, image_processing_options));
|
||||||
|
|
||||||
ExpectApproximatelyEqual(results,
|
ImageClassifierResult expected;
|
||||||
ParseTextProtoOrDie<ClassificationResult>(
|
expected.classifications.emplace_back(
|
||||||
R"pb(classifications {
|
Classifications{/*categories=*/{{/*index=*/560, /*score=*/0.6522213,
|
||||||
entries {
|
/*category_name=*/"folding chair"}},
|
||||||
categories {
|
/*head_index=*/0,
|
||||||
index: 560
|
/*head_name=*/"probability"});
|
||||||
score: 0.6522213
|
ExpectApproximatelyEqual(results, expected);
|
||||||
category_name: "folding chair"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Testing all these once with ImageClassifier.
|
// Testing all these once with ImageClassifier.
|
||||||
|
@ -774,7 +676,7 @@ TEST_F(VideoModeTest, Succeeds) {
|
||||||
for (int i = 0; i < iterations; ++i) {
|
for (int i = 0; i < iterations; ++i) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
MP_ASSERT_OK_AND_ASSIGN(auto results,
|
||||||
image_classifier->ClassifyForVideo(image, i));
|
image_classifier->ClassifyForVideo(image, i));
|
||||||
ExpectApproximatelyEqual(results, GenerateBurgerResults(i));
|
ExpectApproximatelyEqual(results, GenerateBurgerResults());
|
||||||
}
|
}
|
||||||
MP_ASSERT_OK(image_classifier->Close());
|
MP_ASSERT_OK(image_classifier->Close());
|
||||||
}
|
}
|
||||||
|
@ -800,7 +702,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto results,
|
auto results,
|
||||||
image_classifier->ClassifyForVideo(image, i, image_processing_options));
|
image_classifier->ClassifyForVideo(image, i, image_processing_options));
|
||||||
ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
|
ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
|
||||||
}
|
}
|
||||||
MP_ASSERT_OK(image_classifier->Close());
|
MP_ASSERT_OK(image_classifier->Close());
|
||||||
}
|
}
|
||||||
|
@ -815,7 +717,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->result_callback = [](absl::StatusOr<ClassificationResult>,
|
options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
|
||||||
const Image& image, int64 timestamp_ms) {};
|
const Image& image, int64 timestamp_ms) {};
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||||
ImageClassifier::Create(std::move(options)));
|
ImageClassifier::Create(std::move(options)));
|
||||||
|
@ -846,7 +748,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
|
JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->result_callback = [](absl::StatusOr<ClassificationResult>,
|
options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
|
||||||
const Image& image, int64 timestamp_ms) {};
|
const Image& image, int64 timestamp_ms) {};
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
|
||||||
ImageClassifier::Create(std::move(options)));
|
ImageClassifier::Create(std::move(options)));
|
||||||
|
@ -864,7 +766,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LiveStreamModeResults {
|
struct LiveStreamModeResults {
|
||||||
ClassificationResult classification_result;
|
ImageClassifierResult classification_result;
|
||||||
std::pair<int, int> image_size;
|
std::pair<int, int> image_size;
|
||||||
int64 timestamp_ms;
|
int64 timestamp_ms;
|
||||||
};
|
};
|
||||||
|
@ -881,7 +783,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->classifier_options.max_results = 3;
|
options->classifier_options.max_results = 3;
|
||||||
options->result_callback =
|
options->result_callback =
|
||||||
[&results](absl::StatusOr<ClassificationResult> classification_result,
|
[&results](absl::StatusOr<ImageClassifierResult> classification_result,
|
||||||
const Image& image, int64 timestamp_ms) {
|
const Image& image, int64 timestamp_ms) {
|
||||||
MP_ASSERT_OK(classification_result.status());
|
MP_ASSERT_OK(classification_result.status());
|
||||||
results.push_back(
|
results.push_back(
|
||||||
|
@ -908,7 +810,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
|
||||||
EXPECT_EQ(result.image_size.first, image.width());
|
EXPECT_EQ(result.image_size.first, image.width());
|
||||||
EXPECT_EQ(result.image_size.second, image.height());
|
EXPECT_EQ(result.image_size.second, image.height());
|
||||||
ExpectApproximatelyEqual(result.classification_result,
|
ExpectApproximatelyEqual(result.classification_result,
|
||||||
GenerateBurgerResults(timestamp_ms));
|
GenerateBurgerResults());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -924,7 +826,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->classifier_options.max_results = 1;
|
options->classifier_options.max_results = 1;
|
||||||
options->result_callback =
|
options->result_callback =
|
||||||
[&results](absl::StatusOr<ClassificationResult> classification_result,
|
[&results](absl::StatusOr<ImageClassifierResult> classification_result,
|
||||||
const Image& image, int64 timestamp_ms) {
|
const Image& image, int64 timestamp_ms) {
|
||||||
MP_ASSERT_OK(classification_result.status());
|
MP_ASSERT_OK(classification_result.status());
|
||||||
results.push_back(
|
results.push_back(
|
||||||
|
@ -955,7 +857,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
||||||
EXPECT_EQ(result.image_size.first, image.width());
|
EXPECT_EQ(result.image_size.first, image.width());
|
||||||
EXPECT_EQ(result.image_size.second, image.height());
|
EXPECT_EQ(result.image_size.second, image.height());
|
||||||
ExpectApproximatelyEqual(result.classification_result,
|
ExpectApproximatelyEqual(result.classification_result,
|
||||||
GenerateSoccerBallResults(timestamp_ms));
|
GenerateSoccerBallResults());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user