Cleanup after migration to new classification output format.
PiperOrigin-RevId: 489921603
This commit is contained in:
parent
13c6b9a8c6
commit
7acbf557a1
|
@ -37,7 +37,6 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:packet",
|
"//mediapipe/framework/api2:packet",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
],
|
],
|
||||||
|
|
|
@ -25,14 +25,12 @@
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
|
||||||
|
|
||||||
// Aggregates ClassificationLists into either a ClassificationResult object
|
// Aggregates ClassificationLists into either a ClassificationResult object
|
||||||
// representing the classification results aggregated by classifier head, or
|
// representing the classification results aggregated by classifier head, or
|
||||||
|
@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||||
// The classification result aggregated by timestamp, then by head. Must be
|
// The classification result aggregated by timestamp, then by head. Must be
|
||||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||||
// timestamp aggregation is required.
|
// timestamp aggregation is required.
|
||||||
// // TODO: remove output once migration is over.
|
|
||||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
|
||||||
// The aggregated classification result.
|
|
||||||
//
|
//
|
||||||
// Example without timestamp aggregation:
|
// Example without timestamp aggregation:
|
||||||
// node {
|
// node {
|
||||||
|
@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node {
|
||||||
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
|
ClassificationResult ConvertToClassificationResult(CalculatorContext* cc);
|
||||||
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
|
std::vector<ClassificationResult> ConvertToTimestampedClassificationResults(
|
||||||
CalculatorContext* cc);
|
CalculatorContext* cc);
|
||||||
// TODO: deprecate this function once migration is over.
|
|
||||||
ClassificationResult LegacyConvertToClassificationResult(
|
|
||||||
CalculatorContext* cc);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Status ClassificationAggregationCalculator::UpdateContract(
|
absl::Status ClassificationAggregationCalculator::UpdateContract(
|
||||||
|
@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract(
|
||||||
<< "The size of classifications input streams should match the "
|
<< "The size of classifications input streams should match the "
|
||||||
"size of head names specified in the calculator options";
|
"size of head names specified in the calculator options";
|
||||||
}
|
}
|
||||||
// TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if
|
if (kTimestampsIn(cc).IsConnected()) {
|
||||||
// TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is
|
RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected());
|
||||||
// not connected. All dependent tasks must be updated to use these outputs
|
} else {
|
||||||
// first.
|
RET_CHECK(kClassificationsOut(cc).IsConnected());
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process(
|
||||||
if (kTimestampsIn(cc).IsEmpty()) {
|
if (kTimestampsIn(cc).IsEmpty()) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
classification_result = LegacyConvertToClassificationResult(cc);
|
|
||||||
kTimestampedClassificationsOut(cc).Send(
|
kTimestampedClassificationsOut(cc).Send(
|
||||||
ConvertToTimestampedClassificationResults(cc));
|
ConvertToTimestampedClassificationResults(cc));
|
||||||
} else {
|
} else {
|
||||||
classification_result = LegacyConvertToClassificationResult(cc);
|
|
||||||
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
|
kClassificationsOut(cc).Send(ConvertToClassificationResult(cc));
|
||||||
}
|
}
|
||||||
kClassificationResultOut(cc).Send(classification_result);
|
kClassificationResultOut(cc).Send(classification_result);
|
||||||
|
@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults(
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
ClassificationResult
|
|
||||||
ClassificationAggregationCalculator::LegacyConvertToClassificationResult(
|
|
||||||
CalculatorContext* cc) {
|
|
||||||
ClassificationResult result;
|
|
||||||
Timestamp first_timestamp(0);
|
|
||||||
std::vector<Timestamp> timestamps;
|
|
||||||
if (time_aggregation_enabled_) {
|
|
||||||
timestamps = kTimestampsIn(cc).Get();
|
|
||||||
first_timestamp = timestamps[0];
|
|
||||||
} else {
|
|
||||||
timestamps = {cc->InputTimestamp()};
|
|
||||||
}
|
|
||||||
for (Timestamp timestamp : timestamps) {
|
|
||||||
int count = cached_classifications_[timestamp.Value()].size();
|
|
||||||
for (int i = 0; i < count; ++i) {
|
|
||||||
Classifications* c;
|
|
||||||
if (result.classifications_size() <= i) {
|
|
||||||
c = result.add_classifications();
|
|
||||||
if (!head_names_.empty()) {
|
|
||||||
c->set_head_index(i);
|
|
||||||
c->set_head_name(head_names_[i]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c = result.mutable_classifications(i);
|
|
||||||
}
|
|
||||||
auto* entry = c->add_entries();
|
|
||||||
for (const auto& elem :
|
|
||||||
cached_classifications_[timestamp.Value()][i].classification()) {
|
|
||||||
auto* category = entry->add_categories();
|
|
||||||
if (elem.has_index()) {
|
|
||||||
category->set_index(elem.index());
|
|
||||||
}
|
|
||||||
if (elem.has_score()) {
|
|
||||||
category->set_score(elem.score());
|
|
||||||
}
|
|
||||||
if (elem.has_label()) {
|
|
||||||
category->set_category_name(elem.label());
|
|
||||||
}
|
|
||||||
if (elem.has_display_name()) {
|
|
||||||
category->set_display_name(elem.display_name());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) /
|
|
||||||
1000);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator);
|
MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator);
|
||||||
|
|
||||||
} // namespace api2
|
} // namespace api2
|
||||||
|
|
|
@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "category_proto",
|
|
||||||
srcs = ["category.proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "classifications_proto",
|
name = "classifications_proto",
|
||||||
srcs = ["classifications.proto"],
|
srcs = ["classifications.proto"],
|
||||||
deps = [
|
deps = [
|
||||||
":category_proto",
|
|
||||||
"//mediapipe/framework/formats:classification_proto",
|
"//mediapipe/framework/formats:classification_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,41 +0,0 @@
|
||||||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
syntax = "proto2";
|
|
||||||
|
|
||||||
package mediapipe.tasks.components.containers.proto;
|
|
||||||
|
|
||||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
|
||||||
option java_outer_classname = "CategoryProto";
|
|
||||||
|
|
||||||
// TODO: deprecate this message once migration is over.
|
|
||||||
// A single classification result.
|
|
||||||
message Category {
|
|
||||||
// The index of the category in the corresponding label map, usually packed in
|
|
||||||
// the TFLite Model Metadata [1].
|
|
||||||
//
|
|
||||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
|
|
||||||
optional int32 index = 1;
|
|
||||||
// The score for this category, e.g. (but not necessarily) a probability in
|
|
||||||
// [0,1].
|
|
||||||
optional float score = 2;
|
|
||||||
// A human readable name of the category filled from the label map.
|
|
||||||
optional string display_name = 3;
|
|
||||||
// An ID for the category, not necessarily human-readable, e.g. a Google
|
|
||||||
// Knowledge Graph ID [1], filled from the label map.
|
|
||||||
//
|
|
||||||
// [1]: https://developers.google.com/knowledge-graph
|
|
||||||
optional string category_name = 4;
|
|
||||||
}
|
|
|
@ -18,27 +18,12 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.components.containers.proto;
|
package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/formats/classification.proto";
|
import "mediapipe/framework/formats/classification.proto";
|
||||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
|
||||||
|
|
||||||
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
option java_package = "com.google.mediapipe.tasks.components.containers.proto";
|
||||||
option java_outer_classname = "ClassificationsProto";
|
option java_outer_classname = "ClassificationsProto";
|
||||||
|
|
||||||
// TODO: deprecate this message once migration is over.
|
|
||||||
// List of predicted categories with an optional timestamp.
|
|
||||||
message ClassificationEntry {
|
|
||||||
// The array of predicted categories, usually sorted by descending scores,
|
|
||||||
// e.g., from high to low probability.
|
|
||||||
repeated Category categories = 1;
|
|
||||||
// The optional timestamp (in milliseconds) associated to the classifcation
|
|
||||||
// entry. This is useful for time series use cases, e.g., audio
|
|
||||||
// classification.
|
|
||||||
optional int64 timestamp_ms = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Classifications for a given classifier head, i.e. for a given output tensor.
|
// Classifications for a given classifier head, i.e. for a given output tensor.
|
||||||
message Classifications {
|
message Classifications {
|
||||||
// TODO: deprecate this field once migration is over.
|
|
||||||
repeated ClassificationEntry entries = 1;
|
|
||||||
// The classification results for this head.
|
// The classification results for this head.
|
||||||
optional mediapipe.ClassificationList classification_list = 4;
|
optional mediapipe.ClassificationList classification_list = 4;
|
||||||
// The index of the classifier head these categories refer to. This is useful
|
// The index of the classifier head these categories refer to. This is useful
|
||||||
|
@ -48,6 +33,8 @@ message Classifications {
|
||||||
// name.
|
// name.
|
||||||
// TODO: Add github link to metadata_schema.fbs.
|
// TODO: Add github link to metadata_schema.fbs.
|
||||||
optional string head_name = 3;
|
optional string head_name = 3;
|
||||||
|
// Reserved fields.
|
||||||
|
reserved 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Classifications for a given classifier model.
|
// Classifications for a given classifier model.
|
||||||
|
|
|
@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
|
||||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||||
|
|
||||||
constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
|
constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES";
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
|
||||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kScoresTag[] = "SCORES";
|
constexpr char kScoresTag[] = "SCORES";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||||
|
|
||||||
// Struct holding the different output streams produced by the graph.
|
// Struct holding the different output streams produced by the graph.
|
||||||
struct ClassificationPostprocessingOutputStreams {
|
struct ClassificationPostprocessingOutputStreams {
|
||||||
Source<ClassificationResult> classification_result;
|
|
||||||
Source<ClassificationResult> classifications;
|
Source<ClassificationResult> classifications;
|
||||||
Source<std::vector<ClassificationResult>> timestamped_classifications;
|
Source<std::vector<ClassificationResult>> timestamped_classifications;
|
||||||
};
|
};
|
||||||
|
@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph(
|
||||||
// The classification result aggregated by timestamp, then by head. Must be
|
// The classification result aggregated by timestamp, then by head. Must be
|
||||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||||
// timestamp aggregation is required.
|
// timestamp aggregation is required.
|
||||||
// // TODO: remove output once migration is over.
|
|
||||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
|
||||||
// The aggregated classification result.
|
|
||||||
//
|
//
|
||||||
// The recommended way of using this graph is through the GraphBuilder API
|
// The recommended way of using this graph is through the GraphBuilder API
|
||||||
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
|
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
|
||||||
|
@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
|
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
|
||||||
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||||
output_streams.classification_result >>
|
|
||||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
|
||||||
output_streams.classifications >>
|
output_streams.classifications >>
|
||||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||||
output_streams.timestamped_classifications >>
|
output_streams.timestamped_classifications >>
|
||||||
|
@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
|
|
||||||
// Connects output.
|
// Connects output.
|
||||||
ClassificationPostprocessingOutputStreams output_streams{
|
ClassificationPostprocessingOutputStreams output_streams{
|
||||||
/*classification_result=*/result_aggregation
|
|
||||||
[Output<ClassificationResult>(kClassificationResultTag)],
|
|
||||||
/*classifications=*/
|
/*classifications=*/
|
||||||
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
|
result_aggregation[Output<ClassificationResult>(kClassificationsTag)],
|
||||||
/*timestamped_classifications=*/
|
/*timestamped_classifications=*/
|
||||||
|
|
|
@ -58,9 +58,6 @@ namespace processors {
|
||||||
// The classification result aggregated by timestamp, then by head. Must be
|
// The classification result aggregated by timestamp, then by head. Must be
|
||||||
// connected if the TIMESTAMPS input is connected, as it signals that
|
// connected if the TIMESTAMPS input is connected, as it signals that
|
||||||
// timestamp aggregation is required.
|
// timestamp aggregation is required.
|
||||||
// // TODO: remove output once migration is over.
|
|
||||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
|
||||||
// The aggregated classification result.
|
|
||||||
absl::Status ConfigureClassificationPostprocessingGraph(
|
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||||
const tasks::core::ModelResources& model_resources,
|
const tasks::core::ModelResources& model_resources,
|
||||||
const proto::ClassifierOptions& classifier_options,
|
const proto::ClassifierOptions& classifier_options,
|
||||||
|
|
|
@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kTensorsName[] = "tensors";
|
constexpr char kTensorsName[] = "tensors";
|
||||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
constexpr char kTimestampsName[] = "timestamps";
|
constexpr char kTimestampsName[] = "timestamps";
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
|
||||||
constexpr char kClassificationResultName[] = "classification_result";
|
|
||||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kClassificationsName[] = "classifications";
|
constexpr char kClassificationsName[] = "classifications";
|
||||||
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS";
|
||||||
|
@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||||
})pb")}));
|
})pb")}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove these tests once migration is over.
|
|
||||||
class LegacyPostprocessingTest : public tflite_shims::testing::Test {
|
|
||||||
protected:
|
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
|
||||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
|
||||||
bool connect_timestamps = false) {
|
|
||||||
ASSIGN_OR_RETURN(auto model_resources,
|
|
||||||
CreateModelResourcesForModel(model_name));
|
|
||||||
|
|
||||||
Graph graph;
|
|
||||||
auto& postprocessing = graph.AddNode(
|
|
||||||
"mediapipe.tasks.components.processors."
|
|
||||||
"ClassificationPostprocessingGraph");
|
|
||||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
|
||||||
*model_resources, options,
|
|
||||||
&postprocessing
|
|
||||||
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
|
|
||||||
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
|
||||||
postprocessing.In(kTensorsTag);
|
|
||||||
if (connect_timestamps) {
|
|
||||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
|
|
||||||
kTimestampsName) >>
|
|
||||||
postprocessing.In(kTimestampsTag);
|
|
||||||
}
|
|
||||||
postprocessing.Out(kClassificationResultTag)
|
|
||||||
.SetName(kClassificationResultName) >>
|
|
||||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
|
|
||||||
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
|
|
||||||
kClassificationResultName));
|
|
||||||
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
|
|
||||||
return poller;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void AddTensor(
|
|
||||||
const std::vector<T>& tensor, const Tensor::ElementType& element_type,
|
|
||||||
const Tensor::QuantizationParameters& quantization_parameters = {}) {
|
|
||||||
tensors_->emplace_back(element_type,
|
|
||||||
Tensor::Shape{1, static_cast<int>(tensor.size())},
|
|
||||||
quantization_parameters);
|
|
||||||
auto view = tensors_->back().GetCpuWriteView();
|
|
||||||
T* buffer = view.buffer<T>();
|
|
||||||
std::copy(tensor.begin(), tensor.end(), buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::Status Run(
|
|
||||||
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
|
|
||||||
int timestamp = 0) {
|
|
||||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
|
||||||
kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
|
|
||||||
// Reset tensors for future calls.
|
|
||||||
tensors_ = absl::make_unique<std::vector<Tensor>>();
|
|
||||||
if (aggregation_timestamps.has_value()) {
|
|
||||||
auto packet = absl::make_unique<std::vector<Timestamp>>();
|
|
||||||
for (const auto& timestamp : *aggregation_timestamps) {
|
|
||||||
packet->emplace_back(Timestamp(timestamp));
|
|
||||||
}
|
|
||||||
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
|
|
||||||
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
|
|
||||||
}
|
|
||||||
return absl::OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::StatusOr<ClassificationResult> GetClassificationResult(
|
|
||||||
OutputStreamPoller& poller) {
|
|
||||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
|
|
||||||
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
|
|
||||||
|
|
||||||
Packet packet;
|
|
||||||
if (!poller.Next(&packet)) {
|
|
||||||
return absl::InternalError("Unable to get output packet");
|
|
||||||
}
|
|
||||||
auto result = packet.Get<ClassificationResult>();
|
|
||||||
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
CalculatorGraph calculator_graph_;
|
|
||||||
std::unique_ptr<std::vector<Tensor>> tensors_ =
|
|
||||||
absl::make_unique<std::vector<Tensor>>();
|
|
||||||
};
|
|
||||||
|
|
||||||
TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) {
|
|
||||||
// Build graph.
|
|
||||||
proto::ClassifierOptions options;
|
|
||||||
options.set_max_results(3);
|
|
||||||
options.set_score_threshold(0.5);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto poller,
|
|
||||||
BuildGraph(kQuantizedImageClassifierWithoutMetadata, options));
|
|
||||||
// Build input tensors.
|
|
||||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
|
||||||
tensor[1] = 18;
|
|
||||||
tensor[2] = 16;
|
|
||||||
|
|
||||||
// Send tensors and get results.
|
|
||||||
AddTensor(tensor, Tensor::ElementType::kUInt8,
|
|
||||||
/*quantization_parameters=*/{0.1, 10});
|
|
||||||
MP_ASSERT_OK(Run());
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
|
|
||||||
|
|
||||||
// Validate results.
|
|
||||||
EXPECT_THAT(results, EqualsProto(R"pb(classifications {
|
|
||||||
entries {
|
|
||||||
categories { index: 1 score: 0.8 }
|
|
||||||
categories { index: 2 score: 0.6 }
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
})pb"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) {
|
|
||||||
// Build graph.
|
|
||||||
proto::ClassifierOptions options;
|
|
||||||
options.set_max_results(3);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
|
||||||
// Build input tensors.
|
|
||||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
|
||||||
tensor[1] = 12;
|
|
||||||
tensor[2] = 14;
|
|
||||||
tensor[3] = 16;
|
|
||||||
tensor[4] = 18;
|
|
||||||
|
|
||||||
// Send tensors and get results.
|
|
||||||
AddTensor(tensor, Tensor::ElementType::kUInt8,
|
|
||||||
/*quantization_parameters=*/{0.1, 10});
|
|
||||||
MP_ASSERT_OK(Run());
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
|
|
||||||
|
|
||||||
// Validate results.
|
|
||||||
EXPECT_THAT(
|
|
||||||
results,
|
|
||||||
EqualsProto(
|
|
||||||
R"pb(classifications {
|
|
||||||
entries {
|
|
||||||
categories {
|
|
||||||
index: 4
|
|
||||||
score: 0.8
|
|
||||||
category_name: "tiger shark"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 3
|
|
||||||
score: 0.6
|
|
||||||
category_name: "great white shark"
|
|
||||||
}
|
|
||||||
categories { index: 2 score: 0.4 category_name: "goldfish" }
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) {
|
|
||||||
// Build graph.
|
|
||||||
proto::ClassifierOptions options;
|
|
||||||
options.set_max_results(3);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto poller,
|
|
||||||
BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options));
|
|
||||||
// Build input tensors.
|
|
||||||
std::vector<uint8> tensor(kMobileNetNumClasses, 0);
|
|
||||||
tensor[1] = 12;
|
|
||||||
tensor[2] = 14;
|
|
||||||
tensor[3] = 16;
|
|
||||||
tensor[4] = 18;
|
|
||||||
|
|
||||||
// Send tensors and get results.
|
|
||||||
AddTensor(tensor, Tensor::ElementType::kUInt8,
|
|
||||||
/*quantization_parameters=*/{0.1, 10});
|
|
||||||
MP_ASSERT_OK(Run());
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
|
|
||||||
|
|
||||||
// Validate results.
|
|
||||||
EXPECT_THAT(results, EqualsProto(
|
|
||||||
R"pb(classifications {
|
|
||||||
entries {
|
|
||||||
categories {
|
|
||||||
index: 4
|
|
||||||
score: 0.6899744811
|
|
||||||
category_name: "tiger shark"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 3
|
|
||||||
score: 0.6456563062
|
|
||||||
category_name: "great white shark"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 2
|
|
||||||
score: 0.5986876601
|
|
||||||
category_name: "goldfish"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) {
|
|
||||||
// Build graph.
|
|
||||||
proto::ClassifierOptions options;
|
|
||||||
options.set_max_results(2);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto poller,
|
|
||||||
BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options));
|
|
||||||
// Build input tensors.
|
|
||||||
std::vector<float> tensor_0(kTwoHeadsNumClasses[0], 0);
|
|
||||||
tensor_0[1] = 0.2;
|
|
||||||
tensor_0[2] = 0.4;
|
|
||||||
tensor_0[3] = 0.6;
|
|
||||||
std::vector<float> tensor_1(kTwoHeadsNumClasses[1], 0);
|
|
||||||
tensor_1[1] = 0.2;
|
|
||||||
tensor_1[2] = 0.4;
|
|
||||||
tensor_1[3] = 0.6;
|
|
||||||
|
|
||||||
// Send tensors and get results.
|
|
||||||
AddTensor(tensor_0, Tensor::ElementType::kFloat32);
|
|
||||||
AddTensor(tensor_1, Tensor::ElementType::kFloat32);
|
|
||||||
MP_ASSERT_OK(Run());
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
|
|
||||||
|
|
||||||
EXPECT_THAT(results, EqualsProto(
|
|
||||||
R"pb(classifications {
|
|
||||||
entries {
|
|
||||||
categories {
|
|
||||||
index: 3
|
|
||||||
score: 0.6
|
|
||||||
category_name: "Narration, monologue"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 2
|
|
||||||
score: 0.4
|
|
||||||
category_name: "Conversation"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "yamnet_classification"
|
|
||||||
}
|
|
||||||
classifications {
|
|
||||||
entries {
|
|
||||||
categories {
|
|
||||||
index: 3
|
|
||||||
score: 0.6
|
|
||||||
category_name: "Azara\'s Spinetail"
|
|
||||||
}
|
|
||||||
categories {
|
|
||||||
index: 2
|
|
||||||
score: 0.4
|
|
||||||
category_name: "House Sparrow"
|
|
||||||
}
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
head_index: 1
|
|
||||||
head_name: "bird_classification"
|
|
||||||
})pb"));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) {
|
|
||||||
// Build graph.
|
|
||||||
proto::ClassifierOptions options;
|
|
||||||
options.set_max_results(2);
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
|
||||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
|
||||||
/*connect_timestamps=*/true));
|
|
||||||
// Build input tensors.
|
|
||||||
std::vector<uint8> tensor_0(kMobileNetNumClasses, 0);
|
|
||||||
tensor_0[1] = 12;
|
|
||||||
tensor_0[2] = 14;
|
|
||||||
tensor_0[3] = 16;
|
|
||||||
std::vector<uint8> tensor_1(kMobileNetNumClasses, 0);
|
|
||||||
tensor_1[5] = 12;
|
|
||||||
tensor_1[6] = 14;
|
|
||||||
tensor_1[7] = 16;
|
|
||||||
|
|
||||||
// Send tensors and get results.
|
|
||||||
AddTensor(tensor_0, Tensor::ElementType::kUInt8,
|
|
||||||
/*quantization_parameters=*/{0.1, 10});
|
|
||||||
MP_ASSERT_OK(Run());
|
|
||||||
AddTensor(tensor_1, Tensor::ElementType::kUInt8,
|
|
||||||
/*quantization_parameters=*/{0.1, 10});
|
|
||||||
MP_ASSERT_OK(Run(
|
|
||||||
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
|
|
||||||
/*timestamp=*/1000));
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller));
|
|
||||||
|
|
||||||
// Validate results.
|
|
||||||
EXPECT_THAT(
|
|
||||||
results,
|
|
||||||
EqualsProto(
|
|
||||||
R"pb(classifications {
|
|
||||||
entries {
|
|
||||||
categories {
|
|
||||||
index: 3
|
|
||||||
score: 0.6
|
|
||||||
category_name: "great white shark"
|
|
||||||
}
|
|
||||||
categories { index: 2 score: 0.4 category_name: "goldfish" }
|
|
||||||
timestamp_ms: 0
|
|
||||||
}
|
|
||||||
entries {
|
|
||||||
categories { index: 7 score: 0.6 category_name: "stingray" }
|
|
||||||
categories {
|
|
||||||
index: 6
|
|
||||||
score: 0.4
|
|
||||||
category_name: "electric ray"
|
|
||||||
}
|
|
||||||
timestamp_ms: 1
|
|
||||||
}
|
|
||||||
head_index: 0
|
|
||||||
head_name: "probability"
|
|
||||||
})pb"));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace processors
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
|
|
|
@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
|
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
|
||||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kTextTag[] = "TEXT";
|
constexpr char kTextTag[] = "TEXT";
|
||||||
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
|
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
|
||||||
constexpr char kTensorsTag[] = "TENSORS";
|
constexpr char kTensorsTag[] = "TENSORS";
|
||||||
|
|
||||||
// TODO: remove once Java API migration is over.
|
|
||||||
// Struct holding the different output streams produced by the text classifier.
|
|
||||||
struct TextClassifierOutputStreams {
|
|
||||||
Source<ClassificationResult> classification_result;
|
|
||||||
Source<ClassificationResult> classifications;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// A "TextClassifierGraph" performs Natural Language classification (including
|
// A "TextClassifierGraph" performs Natural Language classification (including
|
||||||
|
@ -72,10 +64,6 @@ struct TextClassifierOutputStreams {
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// CLASSIFICATIONS - ClassificationResult @Optional
|
// CLASSIFICATIONS - ClassificationResult @Optional
|
||||||
// The classification results aggregated by classifier head.
|
// The classification results aggregated by classifier head.
|
||||||
// TODO: remove once Java API migration is over.
|
|
||||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
|
||||||
// The aggregated classification result object that has 3 dimensions:
|
|
||||||
// (classification head, classification timestamp, classification category).
|
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
|
@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {
|
||||||
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
|
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
|
||||||
Graph graph;
|
Graph graph;
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto output_streams,
|
auto classifications,
|
||||||
BuildTextClassifierTask(
|
BuildTextClassifierTask(
|
||||||
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
|
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
|
||||||
graph[Input<std::string>(kTextTag)], graph));
|
graph[Input<std::string>(kTextTag)], graph));
|
||||||
output_streams.classification_result >>
|
classifications >> graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
|
||||||
output_streams.classifications >>
|
|
||||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
|
||||||
return graph.GetConfig();
|
return graph.GetConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,7 +109,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
|
||||||
// TextClassifier model file with model metadata.
|
// TextClassifier model file with model metadata.
|
||||||
// text_in: (std::string) stream to run text classification on.
|
// text_in: (std::string) stream to run text classification on.
|
||||||
// graph: the mediapipe builder::Graph instance to be updated.
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
|
absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
|
||||||
const proto::TextClassifierGraphOptions& task_options,
|
const proto::TextClassifierGraphOptions& task_options,
|
||||||
const ModelResources& model_resources, Source<std::string> text_in,
|
const ModelResources& model_resources, Source<std::string> text_in,
|
||||||
Graph& graph) {
|
Graph& graph) {
|
||||||
|
@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
|
||||||
|
|
||||||
// Outputs the aggregated classification result as the subgraph output
|
// Outputs the aggregated classification result as the subgraph output
|
||||||
// stream.
|
// stream.
|
||||||
return TextClassifierOutputStreams{
|
return postprocessing[Output<ClassificationResult>(kClassificationsTag)];
|
||||||
/*classification_result=*/postprocessing[Output<ClassificationResult>(
|
|
||||||
kClassificationResultTag)],
|
|
||||||
/*classifications=*/postprocessing[Output<ClassificationResult>(
|
|
||||||
kClassificationsTag)]};
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
|
||||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||||
|
|
||||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
|
||||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||||
|
@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS";
|
||||||
// Struct holding the different output streams produced by the image classifier
|
// Struct holding the different output streams produced by the image classifier
|
||||||
// subgraph.
|
// subgraph.
|
||||||
struct ImageClassifierOutputStreams {
|
struct ImageClassifierOutputStreams {
|
||||||
Source<ClassificationResult> classification_result;
|
|
||||||
Source<ClassificationResult> classifications;
|
Source<ClassificationResult> classifications;
|
||||||
Source<Image> image;
|
Source<Image> image;
|
||||||
};
|
};
|
||||||
|
@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams {
|
||||||
// The classification results aggregated by classifier head.
|
// The classification results aggregated by classifier head.
|
||||||
// IMAGE - Image
|
// IMAGE - Image
|
||||||
// The image that object detection runs on.
|
// The image that object detection runs on.
|
||||||
// TODO: remove this output once Java API migration is over.
|
|
||||||
// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
|
|
||||||
// The aggregated classification result.
|
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
// node {
|
// node {
|
||||||
|
@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
|
sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
|
||||||
graph[Input<Image>(kImageTag)],
|
graph[Input<Image>(kImageTag)],
|
||||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||||
output_streams.classification_result >>
|
|
||||||
graph[Output<ClassificationResult>(kClassificationResultTag)];
|
|
||||||
output_streams.classifications >>
|
output_streams.classifications >>
|
||||||
graph[Output<ClassificationResult>(kClassificationsTag)];
|
graph[Output<ClassificationResult>(kClassificationsTag)];
|
||||||
output_streams.image >> graph[Output<Image>(kImageTag)];
|
output_streams.image >> graph[Output<Image>(kImageTag)];
|
||||||
|
@ -174,8 +167,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
// Outputs the aggregated classification result as the subgraph output
|
// Outputs the aggregated classification result as the subgraph output
|
||||||
// stream.
|
// stream.
|
||||||
return ImageClassifierOutputStreams{
|
return ImageClassifierOutputStreams{
|
||||||
/*classification_result=*/postprocessing[Output<ClassificationResult>(
|
|
||||||
kClassificationResultTag)],
|
|
||||||
/*classifications=*/
|
/*classifications=*/
|
||||||
postprocessing[Output<ClassificationResult>(kClassificationsTag)],
|
postprocessing[Output<ClassificationResult>(kClassificationsTag)],
|
||||||
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
||||||
|
|
|
@ -48,7 +48,6 @@ android_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
|
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
|
||||||
|
|
|
@ -97,7 +97,6 @@ android_library(
|
||||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
|
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||||
|
|
|
@ -68,7 +68,7 @@ py_library(
|
||||||
name = "category",
|
name = "category",
|
||||||
srcs = ["category.py"],
|
srcs = ["category.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
|
"//mediapipe/framework/formats:classification_py_pb2",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,10 +16,10 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.containers.proto import category_pb2
|
from mediapipe.framework.formats import classification_pb2
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_CategoryProto = category_pb2.Category
|
_ClassificationProto = classification_pb2.Classification
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
|
@ -45,23 +45,23 @@ class Category:
|
||||||
category_name: Optional[str] = None
|
category_name: Optional[str] = None
|
||||||
|
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def to_pb2(self) -> _CategoryProto:
|
def to_pb2(self) -> _ClassificationProto:
|
||||||
"""Generates a Category protobuf object."""
|
"""Generates a Category protobuf object."""
|
||||||
return _CategoryProto(
|
return _ClassificationProto(
|
||||||
index=self.index,
|
index=self.index,
|
||||||
score=self.score,
|
score=self.score,
|
||||||
display_name=self.display_name,
|
label=self.category_name,
|
||||||
category_name=self.category_name)
|
display_name=self.display_name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@doc_controls.do_not_generate_docs
|
@doc_controls.do_not_generate_docs
|
||||||
def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category':
|
def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category':
|
||||||
"""Creates a `Category` object from the given protobuf object."""
|
"""Creates a `Category` object from the given protobuf object."""
|
||||||
return Category(
|
return Category(
|
||||||
index=pb2_obj.index,
|
index=pb2_obj.index,
|
||||||
score=pb2_obj.score,
|
score=pb2_obj.score,
|
||||||
display_name=pb2_obj.display_name,
|
display_name=pb2_obj.display_name,
|
||||||
category_name=pb2_obj.category_name)
|
category_name=pb2_obj.label)
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
"""Checks if this object is equal to the given object.
|
"""Checks if this object is equal to the given object.
|
||||||
|
|
|
@ -49,11 +49,7 @@ class Classifications:
|
||||||
"""Generates a Classifications protobuf object."""
|
"""Generates a Classifications protobuf object."""
|
||||||
classification_list_proto = _ClassificationListProto()
|
classification_list_proto = _ClassificationListProto()
|
||||||
for category in self.categories:
|
for category in self.categories:
|
||||||
classification_proto = _ClassificationProto(
|
classification_proto = category.to_pb2()
|
||||||
index=category.index,
|
|
||||||
score=category.score,
|
|
||||||
label=category.category_name,
|
|
||||||
display_name=category.display_name)
|
|
||||||
classification_list_proto.classification.append(classification_proto)
|
classification_list_proto.classification.append(classification_proto)
|
||||||
return _ClassificationsProto(
|
return _ClassificationsProto(
|
||||||
classification_list=classification_list_proto,
|
classification_list=classification_list_proto,
|
||||||
|
@ -65,14 +61,9 @@ class Classifications:
|
||||||
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications':
|
||||||
"""Creates a `Classifications` object from the given protobuf object."""
|
"""Creates a `Classifications` object from the given protobuf object."""
|
||||||
categories = []
|
categories = []
|
||||||
for entry in pb2_obj.classification_list.classification:
|
for classification in pb2_obj.classification_list.classification:
|
||||||
categories.append(
|
categories.append(
|
||||||
category_module.Category(
|
category_module.Category.create_from_pb2(classification))
|
||||||
index=entry.index,
|
|
||||||
score=entry.score,
|
|
||||||
display_name=entry.display_name,
|
|
||||||
category_name=entry.label))
|
|
||||||
|
|
||||||
return Classifications(
|
return Classifications(
|
||||||
categories=categories,
|
categories=categories,
|
||||||
head_index=pb2_obj.head_index,
|
head_index=pb2_obj.head_index,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user