Cleanup after migration to new classification output format.
PiperOrigin-RevId: 489921603
parent 13c6b9a8c6
commit 7acbf557a1
@@ -37,7 +37,6 @@ cc_library(
        "//mediapipe/framework/api2:packet",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "@com_google_absl//absl/status",
    ],

@@ -25,14 +25,12 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;

// Aggregates ClassificationLists into either a ClassificationResult object
// representing the classification results aggregated by classifier head, or
@@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications;
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// Example without timestamp aggregation:
// node {
@@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node {
  ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
  std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
      CalculatorContext* cc);
  // TODO: deprecate this function once migration is over.
  ClassificationResult LegacyConvertToClassificationResult(
      CalculatorContext* cc);
};

absl::Status ClassificationAggregationCalculator::UpdateContract(
@@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
        << "The size of classifications input streams should match the "
           "size of head names specified in the calculator options";
  }
  // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
  // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
  // not connected. All dependent tasks must be updated to use these outputs
  // first.
  if (kTimestampsIn(cc).IsConnected()) {
    RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected());
  } else {
    RET_CHECK(kClassificationsOut(cc).IsConnected());
  }
  return absl::OkStatus();
}

@@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process(
    if (kTimestampsIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    classification_result = LegacyConvertToClassificationResult(cc);
    kTimestampedClassificationsOut(cc).Send(
        ConvertToTimestampedClassificationResults(cc));
  } else {
    classification_result = LegacyConvertToClassificationResult(cc);
    kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
  }
  kClassificationResultOut(cc).Send(classification_result);
@@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
  return results;
}

ClassificationResult
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
    CalculatorContext* cc) {
  ClassificationResult result;
  Timestamp first_timestamp(0);
  std::vector<Timestamp> timestamps;
  if (time_aggregation_enabled_) {
    timestamps = kTimestampsIn(cc).Get();
    first_timestamp = timestamps[0];
  } else {
    timestamps = {cc->InputTimestamp()};
  }
  for (Timestamp timestamp : timestamps) {
    int count = cached_classifications_[timestamp.Value()].size();
    for (int i = 0; i < count; ++i) {
      Classifications* c;
      if (result.classifications_size() <= i) {
        c = result.add_classifications();
        if (!head_names_.empty()) {
          c->set_head_index(i);
          c->set_head_name(head_names_[i]);
        }
      } else {
        c = result.mutable_classifications(i);
      }
      auto* entry = c->add_entries();
      for (const auto& elem :
           cached_classifications_[timestamp.Value()][i].classification()) {
        auto* category = entry->add_categories();
        if (elem.has_index()) {
          category->set_index(elem.index());
        }
        if (elem.has_score()) {
          category->set_score(elem.score());
        }
        if (elem.has_label()) {
          category->set_category_name(elem.label());
        }
        if (elem.has_display_name()) {
          category->set_display_name(elem.display_name());
        }
      }
      entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
                              1000);
    }
  }
  return result;
}

MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator);

}  // namespace api2

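For readers migrating off the deprecated CLASSIFICATION_RESULT output, here is a minimal consumer-side sketch. It is not part of this commit; it assumes the generated C++ API of classifications.pb.h and the field names visible in the diffs above (classifications, head_index, head_name, classification_list).

#include <iostream>

#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

// Prints one line per category, grouped by classifier head. In the new format
// the categories live in a mediapipe.ClassificationList per head instead of
// the deprecated entries/categories representation.
void PrintClassificationResult(
    const mediapipe::tasks::components::containers::proto::
        ClassificationResult& result) {
  for (const auto& head : result.classifications()) {
    std::cout << "head " << head.head_index() << " (" << head.head_name()
              << ")\n";
    for (const auto& classification :
         head.classification_list().classification()) {
      std::cout << "  " << classification.label() << ": "
                << classification.score() << "\n";
    }
  }
}

The deprecated path instead iterated entries() and categories(), with category_name in place of label, as the legacy converter above shows.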
@@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_proto_library(
    name = "category_proto",
    srcs = ["category.proto"],
)

mediapipe_proto_library(
    name = "classifications_proto",
    srcs = ["classifications.proto"],
    deps = [
        ":category_proto",
        "//mediapipe/framework/formats:classification_proto",
    ],
)

@@ -1,41 +0,0 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.components.containers.proto;

option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "CategoryProto";

// TODO: deprecate this message once migration is over.
// A single classification result.
message Category {
  // The index of the category in the corresponding label map, usually packed in
  // the TFLite Model Metadata [1].
  //
  // [1]: https://www.tensorflow.org/lite/convert/metadata
  optional int32 index = 1;
  // The score for this category, e.g. (but not necessarily) a probability in
  // [0,1].
  optional float score = 2;
  // A human readable name of the category filled from the label map.
  optional string display_name = 3;
  // An ID for the category, not necessarily human-readable, e.g. a Google
  // Knowledge Graph ID [1], filled from the label map.
  //
  // [1]: https://developers.google.com/knowledge-graph
  optional string category_name = 4;
}

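With category.proto deleted, the role of Category is taken over by the standard mediapipe.Classification proto. Below is a hedged sketch of the field mapping, as implied by the legacy converter and the Python changes later in this commit; the values are made up for illustration.

#include "mediapipe/framework/formats/classification.pb.h"

// Builds a mediapipe.Classification carrying what a Category used to carry.
mediapipe::Classification CategoryAsClassification() {
  mediapipe::Classification classification;
  classification.set_index(2);                  // was Category.index
  classification.set_score(0.4f);               // was Category.score
  classification.set_display_name("goldfish");  // was Category.display_name
  classification.set_label("goldfish");         // was Category.category_name
  return classification;
}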
@@ -18,27 +18,12 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;

import "mediapipe/framework/formats/classification.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto";

option java_package = "com.google.mediapipe.tasks.components.containers.proto";
option java_outer_classname = "ClassificationsProto";

// TODO: deprecate this message once migration is over.
// List of predicted categories with an optional timestamp.
message ClassificationEntry {
  // The array of predicted categories, usually sorted by descending scores,
  // e.g., from high to low probability.
  repeated Category categories = 1;
  // The optional timestamp (in milliseconds) associated with the classification
  // entry. This is useful for time series use cases, e.g., audio
  // classification.
  optional int64 timestamp_ms = 2;
}

// Classifications for a given classifier head, i.e. for a given output tensor.
message Classifications {
  // TODO: deprecate this field once migration is over.
  repeated ClassificationEntry entries = 1;
  // The classification results for this head.
  optional mediapipe.ClassificationList classification_list = 4;
  // The index of the classifier head these categories refer to. This is useful
@@ -48,6 +33,8 @@ message Classifications {
  // name.
  // TODO: Add github link to metadata_schema.fbs.
  optional string head_name = 3;
  // Reserved fields.
  reserved 1;
}

// Classifications for a given classifier model.

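To make the new layout concrete, here is a small sketch (not part of this commit) of populating Classifications through its classification_list field, which replaces the now-reserved entries field; it assumes the generated C++ proto API.

#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

namespace containers = mediapipe::tasks::components::containers::proto;

// One Classifications message per classifier head; categories are stored in a
// mediapipe.ClassificationList rather than in ClassificationEntry messages.
containers::Classifications MakeHead() {
  containers::Classifications head;
  head.set_head_index(0);
  head.set_head_name("probability");
  auto* classification =
      head.mutable_classification_list()->add_classification();
  classification->set_index(4);
  classification->set_score(0.8f);
  classification->set_label("tiger shark");
  return head;
}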
@@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();

constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kScoresTag[] = "SCORES";
constexpr char kTensorsTag[] = "TENSORS";
@@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";

// Struct holding the different output streams produced by the graph.
struct ClassificationPostprocessingOutputStreams {
  Source<ClassificationResult> classification_result;
  Source<ClassificationResult> classifications;
  Source<std::vector<ClassificationResult>> timestamped_classifications;
};
@@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph(
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
@@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
            sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
            graph[Input<std::vector<Tensor>>(kTensorsTag)],
            graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
    output_streams.classification_result >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    output_streams.classifications >>
        graph[Output<ClassificationResult>(kClassificationsTag)];
    output_streams.timestamped_classifications >>
@@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {

    // Connects output.
    ClassificationPostprocessingOutputStreams output_streams{
        /*classification_result=*/result_aggregation
            [Output<ClassificationResult>(kClassificationResultTag)],
        /*classifications=*/
        result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
        /*timestamped_classifications=*/

@@ -58,9 +58,6 @@ namespace processors {
// The classification result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// // TODO: remove output once migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
absl::Status ConfigureClassificationPostprocessingGraph(
    const tasks::core::ModelResources& model_resources,
    const proto::ClassifierOptions& classifier_options,

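As a usage illustration (not part of this commit), the subgraph can be configured and wired so that only the new CLASSIFICATIONS output is exposed. This sketch mirrors the test wiring further down in this commit and assumes the same headers and the api2 graph builder; includes are omitted and error handling is reduced to the status macros.

absl::StatusOr<mediapipe::CalculatorGraphConfig> BuildPostprocessingConfig(
    const mediapipe::tasks::core::ModelResources& model_resources,
    const mediapipe::tasks::components::processors::proto::ClassifierOptions&
        classifier_options) {
  using ::mediapipe::api2::builder::Graph;
  using ::mediapipe::api2::builder::Input;
  using ::mediapipe::api2::builder::Output;
  using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
  namespace processors = ::mediapipe::tasks::components::processors;

  Graph graph;
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors."
      "ClassificationPostprocessingGraph");
  MP_RETURN_IF_ERROR(processors::ConfigureClassificationPostprocessingGraph(
      model_resources, classifier_options,
      &postprocessing.GetOptions<
          processors::proto::ClassificationPostprocessingGraphOptions>()));
  // Feed tensors in and expose only the new CLASSIFICATIONS stream.
  graph[Input<std::vector<mediapipe::Tensor>>("TENSORS")] >>
      postprocessing.In("TENSORS");
  postprocessing.Out("CLASSIFICATIONS") >>
      graph[Output<ClassificationResult>("CLASSIFICATIONS")];
  return graph.GetConfig();
}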
@@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTensorsName[] = "tensors";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationResultName[] = "classification_result";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kClassificationsName[] = "classifications";
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
@@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
})pb")}));
}

// TODO: remove these tests once migration is over.
class LegacyPostprocessingTest : public tflite_shims::testing::Test {
 protected:
  absl::StatusOr<OutputStreamPoller> BuildGraph(
      absl::string_view model_name, const proto::ClassifierOptions& options,
      bool connect_timestamps = false) {
    ASSIGN_OR_RETURN(auto model_resources,
                     CreateModelResourcesForModel(model_name));

    Graph graph;
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "ClassificationPostprocessingGraph");
    MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
        *model_resources, options,
        &postprocessing
             .GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
    graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
        postprocessing.In(kTensorsTag);
    if (connect_timestamps) {
      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
          kTimestampsName) >>
          postprocessing.In(kTimestampsTag);
    }
    postprocessing.Out(kClassificationResultTag)
            .SetName(kClassificationResultName) >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];

    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                      kClassificationResultName));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  template <typename T>
  void AddTensor(
      const std::vector<T>& tensor, const Tensor::ElementType& element_type,
      const Tensor::QuantizationParameters& quantization_parameters = {}) {
    tensors_->emplace_back(element_type,
                           Tensor::Shape{1, static_cast<int>(tensor.size())},
                           quantization_parameters);
    auto view = tensors_->back().GetCpuWriteView();
    T* buffer = view.buffer<T>();
    std::copy(tensor.begin(), tensor.end(), buffer);
  }

  absl::Status Run(
      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
      int timestamp = 0) {
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
    // Reset tensors for future calls.
    tensors_ = absl::make_unique<std::vector<Tensor>>();
    if (aggregation_timestamps.has_value()) {
      auto packet = absl::make_unique<std::vector<Timestamp>>();
      for (const auto& timestamp : *aggregation_timestamps) {
        packet->emplace_back(Timestamp(timestamp));
      }
      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
    }
    return absl::OkStatus();
  }

  absl::StatusOr<ClassificationResult> GetClassificationResult(
      OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());

    Packet packet;
    if (!poller.Next(&packet)) {
      return absl::InternalError("Unable to get output packet");
    }
    auto result = packet.Get<ClassificationResult>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return result;
  }

 private:
  CalculatorGraph calculator_graph_;
  std::unique_ptr<std::vector<Tensor>> tensors_ =
      absl::make_unique<std::vector<Tensor>>();
};

TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) {
  // Build graph.
  proto::ClassifierOptions options;
  options.set_max_results(3);
  options.set_score_threshold(0.5);
  MP_ASSERT_OK_AND_ASSIGN(
      auto poller,
      BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
  // Build input tensors.
  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
  tensor[1] = 18;
  tensor[2] = 16;

  // Send tensors and get results.
  AddTensor(tensor, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));

  // Validate results.
  EXPECT_THAT(results, EqualsProto(R"pb(classifications {
                                          entries {
                                            categories { index: 1 score: 0.8 }
                                            categories { index: 2 score: 0.6 }
                                            timestamp_ms: 0
                                          }
                                        })pb"));
}

TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) {
  // Build graph.
  proto::ClassifierOptions options;
  options.set_max_results(3);
  MP_ASSERT_OK_AND_ASSIGN(
      auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
  // Build input tensors.
  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
  tensor[1] = 12;
  tensor[2] = 14;
  tensor[3] = 16;
  tensor[4] = 18;

  // Send tensors and get results.
  AddTensor(tensor, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));

  // Validate results.
  EXPECT_THAT(
      results,
      EqualsProto(
          R"pb(classifications {
                 entries {
                   categories {
                     index: 4
                     score: 0.8
                     category_name: "tiger shark"
                   }
                   categories {
                     index: 3
                     score: 0.6
                     category_name: "great white shark"
                   }
                   categories { index: 2 score: 0.4 category_name: "goldfish" }
                   timestamp_ms: 0
                 }
                 head_index: 0
                 head_name: "probability"
               })pb"));
}

TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) {
  // Build graph.
  proto::ClassifierOptions options;
  options.set_max_results(3);
  MP_ASSERT_OK_AND_ASSIGN(
      auto poller,
      BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
  // Build input tensors.
  std::vector<uint8> tensor(kMobileNetNumClasses, 0);
  tensor[1] = 12;
  tensor[2] = 14;
  tensor[3] = 16;
  tensor[4] = 18;

  // Send tensors and get results.
  AddTensor(tensor, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));

  // Validate results.
  EXPECT_THAT(results, EqualsProto(
                           R"pb(classifications {
                                  entries {
                                    categories {
                                      index: 4
                                      score: 0.6899744811
                                      category_name: "tiger shark"
                                    }
                                    categories {
                                      index: 3
                                      score: 0.6456563062
                                      category_name: "great white shark"
                                    }
                                    categories {
                                      index: 2
                                      score: 0.5986876601
                                      category_name: "goldfish"
                                    }
                                    timestamp_ms: 0
                                  }
                                  head_index: 0
                                  head_name: "probability"
                                })pb"));
}

TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) {
  // Build graph.
  proto::ClassifierOptions options;
  options.set_max_results(2);
  MP_ASSERT_OK_AND_ASSIGN(
      auto poller,
      BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options));
  // Build input tensors.
  std::vector<float> tensor_0(kTwoHeadsNumClasses[0], 0);
  tensor_0[1] = 0.2;
  tensor_0[2] = 0.4;
  tensor_0[3] = 0.6;
  std::vector<float> tensor_1(kTwoHeadsNumClasses[1], 0);
  tensor_1[1] = 0.2;
  tensor_1[2] = 0.4;
  tensor_1[3] = 0.6;

  // Send tensors and get results.
  AddTensor(tensor_0, Tensor::ElementType::kFloat32);
  AddTensor(tensor_1, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));

  EXPECT_THAT(results, EqualsProto(
                           R"pb(classifications {
                                  entries {
                                    categories {
                                      index: 3
                                      score: 0.6
                                      category_name: "Narration, monologue"
                                    }
                                    categories {
                                      index: 2
                                      score: 0.4
                                      category_name: "Conversation"
                                    }
                                    timestamp_ms: 0
                                  }
                                  head_index: 0
                                  head_name: "yamnet_classification"
                                }
                                classifications {
                                  entries {
                                    categories {
                                      index: 3
                                      score: 0.6
                                      category_name: "Azara\'s Spinetail"
                                    }
                                    categories {
                                      index: 2
                                      score: 0.4
                                      category_name: "House Sparrow"
                                    }
                                    timestamp_ms: 0
                                  }
                                  head_index: 1
                                  head_name: "bird_classification"
                                })pb"));
}

TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) {
  // Build graph.
  proto::ClassifierOptions options;
  options.set_max_results(2);
  MP_ASSERT_OK_AND_ASSIGN(
      auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
                              /*connect_timestamps=*/true));
  // Build input tensors.
  std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
  tensor_0[1] = 12;
  tensor_0[2] = 14;
  tensor_0[3] = 16;
  std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
  tensor_1[5] = 12;
  tensor_1[6] = 14;
  tensor_1[7] = 16;

  // Send tensors and get results.
  AddTensor(tensor_0, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run());
  AddTensor(tensor_1, Tensor::ElementType::kUInt8,
            /*quantization_parameters=*/{0.1, 10});
  MP_ASSERT_OK(Run(
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
      /*timestamp=*/1000));

  MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));

  // Validate results.
  EXPECT_THAT(
      results,
      EqualsProto(
          R"pb(classifications {
                 entries {
                   categories {
                     index: 3
                     score: 0.6
                     category_name: "great white shark"
                   }
                   categories { index: 2 score: 0.4 category_name: "goldfish" }
                   timestamp_ms: 0
                 }
                 entries {
                   categories { index: 7 score: 0.6 category_name: "stingray" }
                   categories {
                     index: 6
                     score: 0.4
                     category_name: "electric ray"
                   }
                   timestamp_ms: 1
                 }
                 head_index: 0
                 head_name: "probability"
               })pb"));
}

}  // namespace
}  // namespace processors
}  // namespace components

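For comparison, here is a hedged sketch (not part of this commit) of how the SucceedsWithMetadata expectation above would read against the new CLASSIFICATIONS output, based on the proto definitions in this commit; the expected values are assumed to match the legacy test's inputs.

// Same categories as the legacy expectation, but carried in a
// mediapipe.ClassificationList: label replaces category_name and the
// per-entry timestamp_ms field is gone.
EXPECT_THAT(results, EqualsProto(
                         R"pb(classifications {
                                classification_list {
                                  classification {
                                    index: 4
                                    score: 0.8
                                    label: "tiger shark"
                                  }
                                  classification {
                                    index: 3
                                    score: 0.6
                                    label: "great white shark"
                                  }
                                  classification {
                                    index: 2
                                    score: 0.4
                                    label: "goldfish"
                                  }
                                }
                                head_index: 0
                                head_name: "probability"
                              })pb"));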
@@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;

constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";

// TODO: remove once Java API migration is over.
// Struct holding the different output streams produced by the text classifier.
struct TextClassifierOutputStreams {
  Source<ClassificationResult> classification_result;
  Source<ClassificationResult> classifications;
};

}  // namespace

// A "TextClassifierGraph" performs Natural Language classification (including
@@ -72,10 +64,6 @@ struct TextClassifierOutputStreams {
// Outputs:
// CLASSIFICATIONS - ClassificationResult @Optional
// The classification results aggregated by classifier head.
// TODO: remove once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result object that has 3 dimensions:
// (classification head, classification timestamp, classification category).
//
// Example:
// node {
@@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {
        CreateModelResources<proto::TextClassifierGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        auto output_streams,
        auto classifications,
        BuildTextClassifierTask(
            sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
            graph[Input<std::string>(kTextTag)], graph));
    output_streams.classification_result >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    output_streams.classifications >>
        graph[Output<ClassificationResult>(kClassificationsTag)];
    classifications >> graph[Output<ClassificationResult>(kClassificationsTag)];
    return graph.GetConfig();
  }

@@ -124,7 +109,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
  // TextClassifier model file with model metadata.
  // text_in: (std::string) stream to run text classification on.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
  absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
      const proto::TextClassifierGraphOptions& task_options,
      const ModelResources& model_resources, Source<std::string> text_in,
      Graph& graph) {
@@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {

    // Outputs the aggregated classification result as the subgraph output
    // stream.
    return TextClassifierOutputStreams{
        /*classification_result=*/postprocessing[Output<ClassificationResult>(
            kClassificationResultTag)],
        /*classifications=*/postprocessing[Output<ClassificationResult>(
            kClassificationsTag)]};
    return postprocessing[Output<ClassificationResult>(kClassificationsTag)];
  }
};

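The builder pattern behind this simplification, sketched with a hypothetical helper (not part of this commit): a task-graph builder now returns a single Source<ClassificationResult> for the CLASSIFICATIONS stream, and the caller pipes it straight into the graph output with operator>>.

// BuildSomeClassifierTask is a hypothetical stand-in for helpers shaped like
// BuildTextClassifierTask above: it adds nodes to `graph` and returns the
// stream carrying the aggregated ClassificationResult.
absl::StatusOr<mediapipe::CalculatorGraphConfig> BuildTaskGraph() {
  using ::mediapipe::api2::builder::Graph;
  using ::mediapipe::api2::builder::Output;
  using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

  Graph graph;
  ASSIGN_OR_RETURN(auto classifications, BuildSomeClassifierTask(graph));
  // A Source<T> connects to a typed graph output of the same T.
  classifications >> graph[Output<ClassificationResult>("CLASSIFICATIONS")];
  return graph.GetConfig();
}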
@@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();

constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
@@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS";
// Struct holding the different output streams produced by the image classifier
// subgraph.
struct ImageClassifierOutputStreams {
  Source<ClassificationResult> classification_result;
  Source<ClassificationResult> classifications;
  Source<Image> image;
};
@@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams {
// The classification results aggregated by classifier head.
// IMAGE - Image
// The image that image classification runs on.
// TODO: remove this output once Java API migration is over.
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
// The aggregated classification result.
//
// Example:
// node {
@@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
            sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
            graph[Input<Image>(kImageTag)],
            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
    output_streams.classification_result >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    output_streams.classifications >>
        graph[Output<ClassificationResult>(kClassificationsTag)];
    output_streams.image >> graph[Output<Image>(kImageTag)];
@@ -174,8 +167,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
    // Outputs the aggregated classification result as the subgraph output
    // stream.
    return ImageClassifierOutputStreams{
        /*classification_result=*/postprocessing[Output<ClassificationResult>(
            kClassificationResultTag)],
        /*classifications=*/
        postprocessing[Output<ClassificationResult>(kClassificationsTag)],
        /*image=*/preprocessing[Output<Image>(kImageTag)]};

@@ -48,7 +48,6 @@ android_library(
    deps = [
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",

@@ -97,7 +97,6 @@ android_library(
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",

@@ -68,7 +68,7 @@ py_library(
    name = "category",
    srcs = ["category.py"],
    deps = [
        "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
        "//mediapipe/framework/formats:classification_py_pb2",
        "//mediapipe/tasks/python/core:optional_dependencies",
    ],
)

@@ -16,10 +16,10 @@
import dataclasses
from typing import Any, Optional

from mediapipe.tasks.cc.components.containers.proto import category_pb2
from mediapipe.framework.formats import classification_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls

_CategoryProto = category_pb2.Category
_ClassificationProto = classification_pb2.Classification


@dataclasses.dataclass
@@ -45,23 +45,23 @@ class Category:
  category_name: Optional[str] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _CategoryProto:
  def to_pb2(self) -> _ClassificationProto:
    """Generates a Category protobuf object."""
    return _CategoryProto(
    return _ClassificationProto(
        index=self.index,
        score=self.score,
        display_name=self.display_name,
        category_name=self.category_name)
        label=self.category_name,
        display_name=self.display_name)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category':
  def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category':
    """Creates a `Category` object from the given protobuf object."""
    return Category(
        index=pb2_obj.index,
        score=pb2_obj.score,
        display_name=pb2_obj.display_name,
        category_name=pb2_obj.category_name)
        category_name=pb2_obj.label)

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

@@ -49,11 +49,7 @@ class Classifications:
    """Generates a Classifications protobuf object."""
    classification_list_proto = _ClassificationListProto()
    for category in self.categories:
      classification_proto = _ClassificationProto(
          index=category.index,
          score=category.score,
          label=category.category_name,
          display_name=category.display_name)
      classification_proto = category.to_pb2()
      classification_list_proto.classification.append(classification_proto)
    return _ClassificationsProto(
        classification_list=classification_list_proto,
@@ -65,14 +61,9 @@ class Classifications:
  def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
    """Creates a `Classifications` object from the given protobuf object."""
    categories = []
    for entry in pb2_obj.classification_list.classification:
    for classification in pb2_obj.classification_list.classification:
      categories.append(
          category_module.Category(
              index=entry.index,
              score=entry.score,
              display_name=entry.display_name,
              category_name=entry.label))

          category_module.Category.create_from_pb2(classification))
    return Classifications(
        categories=categories,
        head_index=pb2_obj.head_index,