Cleanup after migration to new classification output format.

PiperOrigin-RevId: 489921603
This commit is contained in:
MediaPipe Team 2022-11-21 01:55:49 -08:00 committed by Copybara-Service
parent 13c6b9a8c6
commit 7acbf557a1
15 changed files with 23 additions and 515 deletions

View File

@ -37,7 +37,6 @@ cc_library(
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/status",
],

View File

@ -25,14 +25,12 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
// Aggregates ClassificationLists into either a ClassificationResult object
// representing the classification results aggregated by classifier head, or
@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications;
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// Example without timestamp aggregation:
// node {
@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node {
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
CalculatorContext* cc);
// TODO: deprecate this function once migration is over.
ClassificationResult LegacyConvertToClassificationResult(
CalculatorContext* cc);
};
absl::Status ClassificationAggregationCalculator::UpdateContract(
@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
<< "The size of classifications input streams should match the "
"size of head names specified in the calculator options";
}
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
// not connected. All dependent tasks must be updated to use these outputs
// first.
if (kTimestampsIn(cc).IsConnected()) {
RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected());
} else {
RET_CHECK(kClassificationsOut(cc).IsConnected());
}
return absl::OkStatus();
}
@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process(
if (kTimestampsIn(cc).IsEmpty()) {
return absl::OkStatus();
}
classification_result = LegacyConvertToClassificationResult(cc);
kTimestampedClassificationsOut(cc).Send(
ConvertToTimestampedClassificationResults(cc));
} else {
classification_result = LegacyConvertToClassificationResult(cc);
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
}
kClassificationResultOut(cc).Send(classification_result);
@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
return results;
}
ClassificationResult
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
CalculatorContext* cc) {
ClassificationResult result;
Timestamp first_timestamp(0);
std::vector<Timestamp> timestamps;
if (time_aggregation_enabled_) {
timestamps = kTimestampsIn(cc).Get();
first_timestamp = timestamps[0];
} else {
timestamps = {cc->InputTimestamp()};
}
for (Timestamp timestamp : timestamps) {
int count = cached_classifications_[timestamp.Value()].size();
for (int i = 0; i < count; ++i) {
Classifications* c;
if (result.classifications_size() <= i) {
c = result.add_classifications();
if (!head_names_.empty()) {
c->set_head_index(i);
c->set_head_name(head_names_[i]);
}
} else {
c = result.mutable_classifications(i);
}
auto* entry = c->add_entries();
for (const auto& elem :
cached_classifications_[timestamp.Value()][i].classification()) {
auto* category = entry->add_categories();
if (elem.has_index()) {
category->set_index(elem.index());
}
if (elem.has_score()) {
category->set_score(elem.score());
}
if (elem.has_label()) {
category->set_category_name(elem.label());
}
if (elem.has_display_name()) {
category->set_display_name(elem.display_name());
}
}
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
1000);
}
}
return result;
}
MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator);
} // namespace api2

View File

@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "category_proto",
srcs = ["category.proto"],
)
mediapipe_proto_library(
name = "classifications_proto",
srcs = ["classifications.proto"],
deps = [
":category_proto",
"//mediapipe/framework/formats:classification_proto",
],
)

View File

@ -1,41 +0,0 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto";
// TODO: deprecate this message once migration is over.
// A single classification result.
message Category {
// The index of the category in the corresponding label map, usually packed in
// the TFLite Model Metadata [1].
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
optional int32 index = 1;
// The score for this category, e.g. (but not necessarily) a probability in
// [0,1].
optional float score = 2;
// A human readable name of the category filled from the label map.
optional string display_name = 3;
// An ID for the category, not necessarily human-readable, e.g. a Google
// Knowledge Graph ID [1], filled from the label map.
//
// [1]: https://developers.google.com/knowledge-graph
optional string category_name = 4;
}

View File

@ -18,27 +18,12 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
import "mediapipe/framework/formats/classification.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto";
// TODO: deprecate this message once migration is over.
// List of predicted categories with an optional timestamp.
message ClassificationEntry {
// The array of predicted categories, usually sorted by descending scores,
// e.g., from high to low probability.
repeated Category categories = 1;
// The optional timestamp (in milliseconds) associated to the classifcation
// entry. This is useful for time series use cases, e.g., audio
// classification.
optional int64 timestamp_ms = 2;
}
// Classifications for a given classifier head, i.e. for a given output tensor.
message Classifications {
// TODO: deprecate this field once migration is over.
repeated ClassificationEntry entries = 1;
// The classification results for this head.
optional mediapipe.ClassificationList classification_list = 4;
// The index of the classifier head these categories refer to. This is useful
@ -48,6 +33,8 @@ message Classifications {
// name.
// TODO: Add github link to metadata_schema.fbs.
optional string head_name = 3;
// Reserved fields.
reserved 1;
}
// Classifications for a given classifier model.

View File

@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS";
@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
// Struct holding the different output streams produced by the graph.
struct ClassificationPostprocessingOutputStreams {
Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
Source<std::vector<ClassificationResult>> timestamped_classifications;
};
@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph(
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)],
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.timestamped_classifications >>
@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
// Connects output.
ClassificationPostprocessingOutputStreams output_streams{
/*classification_result=*/result_aggregation
[Output<ClassificationResult>(kClassificationResultTag)],
/*classifications=*/
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
/*timestamped_classifications=*/

View File

@ -58,9 +58,6 @@ namespace processors {
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
absl::Status ConfigureClassificationPostprocessingGraph(
const tasks::core::ModelResources& model_resources,
const proto::ClassifierOptions& classifier_options,

View File

@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTensorsName[] = "tensors";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationResultName[] = "classification_result";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
})pb")}));
}
// TODO: remove these tests once migration is over.
class LegacyPostprocessingTest : public tflite_shims::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const proto::ClassifierOptions& options,
bool connect_timestamps = false) {
ASSIGN_OR_RETURN(auto model_resources,
CreateModelResourcesForModel(model_name));
Graph graph;
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors."
"ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
*model_resources, options,
&postprocessing
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
postprocessing.In(kTensorsTag);
if (connect_timestamps) {
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
kTimestampsName) >>
postprocessing.In(kTimestampsTag);
}
postprocessing.Out(kClassificationResultTag)
.SetName(kClassificationResultName) >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kClassificationResultName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
template <typename T>
void AddTensor(
const std::vector<T>& tensor, const Tensor::ElementType& element_type,
const Tensor::QuantizationParameters& quantization_parameters = {}) {
tensors_->emplace_back(element_type,
Tensor::Shape{1, static_cast<int>(tensor.size())},
quantization_parameters);
auto view = tensors_->back().GetCpuWriteView();
T* buffer = view.buffer<T>();
std::copy(tensor.begin(), tensor.end(), buffer);
}
absl::Status Run(
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
int timestamp = 0) {
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
// Reset tensors for future calls.
tensors_ = absl::make_unique<std::vector<Tensor>>();
if (aggregation_timestamps.has_value()) {
auto packet = absl::make_unique<std::vector<Timestamp>>();
for (const auto& timestamp : *aggregation_timestamps) {
packet->emplace_back(Timestamp(timestamp));
}
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
}
return absl::OkStatus();
}
absl::StatusOr<ClassificationResult> GetClassificationResult(
OutputStreamPoller& poller) {
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
Packet packet;
if (!poller.Next(&packet)) {
return absl::InternalError("Unable to get output packet");
}
auto result = packet.Get<ClassificationResult>();
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
return result;
}
private:
CalculatorGraph calculator_graph_;
std::unique_ptr<std::vector<Tensor>> tensors_ =
absl::make_unique<std::vector<Tensor>>();
};
TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(3);
options.set_score_threshold(0.5);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
tensor[1] = 18;
tensor[2] = 16;
// Send tensors and get results.
AddTensor(tensor, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(results, EqualsProto(R"pb(classifications {
entries {
categories { index: 1 score: 0.8 }
categories { index: 2 score: 0.6 }
timestamp_ms: 0
}
})pb"));
}
TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
tensor[1] = 12;
tensor[2] = 14;
tensor[3] = 16;
tensor[4] = 18;
// Send tensors and get results.
AddTensor(tensor, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(
results,
EqualsProto(
R"pb(classifications {
entries {
categories {
index: 4
score: 0.8
category_name: "tiger shark"
}
categories {
index: 3
score: 0.6
category_name: "great white shark"
}
categories { index: 2 score: 0.4 category_name: "goldfish" }
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}
TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
// Build input tensors.
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
tensor[1] = 12;
tensor[2] = 14;
tensor[3] = 16;
tensor[4] = 18;
// Send tensors and get results.
AddTensor(tensor, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(results, EqualsProto(
R"pb(classifications {
entries {
categories {
index: 4
score: 0.6899744811
category_name: "tiger shark"
}
categories {
index: 3
score: 0.6456563062
category_name: "great white shark"
}
categories {
index: 2
score: 0.5986876601
category_name: "goldfish"
}
timestamp_ms: 0
}
head_index: 0
head_name: "probability"
})pb"));
}
TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options));
// Build input tensors.
std::vector<float> tensor_0(kTwoHeadsNumClasses[0], 0);
tensor_0[1] = 0.2;
tensor_0[2] = 0.4;
tensor_0[3] = 0.6;
std::vector<float> tensor_1(kTwoHeadsNumClasses[1], 0);
tensor_1[1] = 0.2;
tensor_1[2] = 0.4;
tensor_1[3] = 0.6;
// Send tensors and get results.
AddTensor(tensor_0, Tensor::ElementType::kFloat32);
AddTensor(tensor_1, Tensor::ElementType::kFloat32);
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
EXPECT_THAT(results, EqualsProto(
R"pb(classifications {
entries {
categories {
index: 3
score: 0.6
category_name: "Narration, monologue"
}
categories {
index: 2
score: 0.4
category_name: "Conversation"
}
timestamp_ms: 0
}
head_index: 0
head_name: "yamnet_classification"
}
classifications {
entries {
categories {
index: 3
score: 0.6
category_name: "Azara\'s Spinetail"
}
categories {
index: 2
score: 0.4
category_name: "House Sparrow"
}
timestamp_ms: 0
}
head_index: 1
head_name: "bird_classification"
})pb"));
}
TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) {
// Build graph.
proto::ClassifierOptions options;
options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
/*connect_timestamps=*/true));
// Build input tensors.
std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
tensor_0[1] = 12;
tensor_0[2] = 14;
tensor_0[3] = 16;
std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
tensor_1[5] = 12;
tensor_1[6] = 14;
tensor_1[7] = 16;
// Send tensors and get results.
AddTensor(tensor_0, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run());
AddTensor(tensor_1, Tensor::ElementType::kUInt8,
/*quantization_parameters=*/{0.1, 10});
MP_ASSERT_OK(Run(
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
/*timestamp=*/1000));
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
// Validate results.
EXPECT_THAT(
results,
EqualsProto(
R"pb(classifications {
entries {
categories {
index: 3
score: 0.6
category_name: "great white shark"
}
categories { index: 2 score: 0.4 category_name: "goldfish" }
timestamp_ms: 0
}
entries {
categories { index: 7 score: 0.6 category_name: "stingray" }
categories {
index: 6
score: 0.4
category_name: "electric ray"
}
timestamp_ms: 1
}
head_index: 0
head_name: "probability"
})pb"));
}
} // namespace
} // namespace processors
} // namespace components

View File

@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";
// TODO: remove once Java API migration is over.
// Struct holding the different output streams produced by the text classifier.
struct TextClassifierOutputStreams {
Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
};
} // namespace
// A "TextClassifierGraph" performs Natural Language classification (including
@ -72,10 +64,6 @@ struct TextClassifierOutputStreams {
// Outputs:
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by classifier head.
// TODO: remove once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result object that has 3 dimensions:
// (classification head, classification timestamp, classification category).
//
// Example:
// node {
@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto output_streams,
auto classifications,
BuildTextClassifierTask(
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
graph[Input<std::string>(kTextTag)], graph));
output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
classifications >> graph[Output<ClassificationResult>(kClassificationsTag)];
return graph.GetConfig();
}
@ -124,7 +109,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
// TextClassifier model file with model metadata.
// text_in: (std::string) stream to run text classification on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
const proto::TextClassifierGraphOptions& task_options,
const ModelResources& model_resources, Source<std::string> text_in,
Graph& graph) {
@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
// Outputs the aggregated classification result as the subgraph output
// stream.
return TextClassifierOutputStreams{
/*classification_result=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)],
/*classifications=*/postprocessing[Output<ClassificationResult>(
kClassificationsTag)]};
return postprocessing[Output<ClassificationResult>(kClassificationsTag)];
}
};

View File

@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS";
// Struct holding the different output streams produced by the image classifier
// subgraph.
struct ImageClassifierOutputStreams {
Source<ClassificationResult> classification_result;
Source<ClassificationResult> classifications;
Source<Image> image;
};
@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams {
// The classification results aggregated by classifier head.
// IMAGE - Image
// The image that object detection runs on.
// TODO: remove this output once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// Example:
// node {
@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.classification_result >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
output_streams.classifications >>
graph[Output<ClassificationResult>(kClassificationsTag)];
output_streams.image >> graph[Output<Image>(kImageTag)];
@ -174,8 +167,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// Outputs the aggregated classification result as the subgraph output
// stream.
return ImageClassifierOutputStreams{
/*classification_result=*/postprocessing[Output<ClassificationResult>(
kClassificationResultTag)],
/*classifications=*/
postprocessing[Output<ClassificationResult>(kClassificationsTag)],
/*image=*/preprocessing[Output<Image>(kImageTag)]};

View File

@ -48,7 +48,6 @@ android_library(
deps = [
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",

View File

@ -97,7 +97,6 @@ android_library(
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",

View File

@ -68,7 +68,7 @@ py_library(
name = "category",
srcs = ["category.py"],
deps = [
"//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -16,10 +16,10 @@
import dataclasses
from typing import Any, Optional
from mediapipe.tasks.cc.components.containers.proto import category_pb2
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_CategoryProto = category_pb2.Category
_ClassificationProto = classification_pb2.Classification
@dataclasses.dataclass
@ -45,23 +45,23 @@ class Category:
category_name: Optional[str] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _CategoryProto:
def to_pb2(self) -> _ClassificationProto:
"""Generates a Category protobuf object."""
return _CategoryProto(
return _ClassificationProto(
index=self.index,
score=self.score,
display_name=self.display_name,
category_name=self.category_name)
label=self.category_name,
display_name=self.display_name)
@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category':
def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category':
"""Creates a `Category` object from the given protobuf object."""
return Category(
index=pb2_obj.index,
score=pb2_obj.score,
display_name=pb2_obj.display_name,
category_name=pb2_obj.category_name)
category_name=pb2_obj.label)
def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.

View File

@ -49,11 +49,7 @@ class Classifications:
"""Generates a Classifications protobuf object."""
classification_list_proto = _ClassificationListProto()
for category in self.categories:
classification_proto = _ClassificationProto(
index=category.index,
score=category.score,
label=category.category_name,
display_name=category.display_name)
classification_proto = category.to_pb2()
classification_list_proto.classification.append(classification_proto)
return _ClassificationsProto(
classification_list=classification_list_proto,
@ -65,14 +61,9 @@ class Classifications:
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
"""Creates a `Classifications` object from the given protobuf object."""
categories = []
for entry in pb2_obj.classification_list.classification:
for classification in pb2_obj.classification_list.classification:
categories.append(
category_module.Category(
index=entry.index,
score=entry.score,
display_name=entry.display_name,
category_name=entry.label))
category_module.Category.create_from_pb2(classification))
return Classifications(
categories=categories,
head_index=pb2_obj.head_index,