Replace SourceOrNodeOutput with Source.

PiperOrigin-RevId: 504883990
This commit is contained in:
MediaPipe Team 2023-01-26 10:43:36 -08:00 committed by Copybara-Service
parent 2547f07c77
commit 29001234d5
7 changed files with 17 additions and 91 deletions

View File

@@ -48,7 +48,6 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
@@ -90,7 +89,6 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",

View File

@@ -40,7 +40,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
@@ -68,7 +67,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::ProcessUnit;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@@ -455,12 +454,13 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
}
// If output tensors are quantized, they must be dequantized first.
TensorsSource dequantized_tensors(&tensors_in);
TensorsSource dequantized_tensors = tensors_in;
if (options.has_quantized_outputs()) {
GenericNode* tensors_dequantization_node =
&graph.AddNode("TensorsDequantizationCalculator");
tensors_in >> tensors_dequantization_node->In(kTensorsTag);
dequantized_tensors = {tensors_dequantization_node, kTensorsTag};
dequantized_tensors = tensors_dequantization_node->Out(kTensorsTag)
.Cast<std::vector<Tensor>>();
}
// If there are multiple classification heads, the output tensors need to be
@@ -477,7 +477,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
auto* range = split_tensor_vector_options.add_ranges();
range->set_begin(i);
range->set_end(i + 1);
split_tensors.emplace_back(split_tensor_vector_node, i);
split_tensors.push_back(
split_tensor_vector_node->Out(i).Cast<std::vector<Tensor>>());
}
dequantized_tensors >> split_tensor_vector_node->In(0);
} else {
@@ -494,8 +495,9 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
.CopyFrom(options.score_calibration_options().at(i));
split_tensors[i] >> score_calibration_node->In(kScoresTag);
calibrated_tensors.emplace_back(score_calibration_node,
kCalibratedScoresTag);
calibrated_tensors.push_back(
score_calibration_node->Out(kCalibratedScoresTag)
.Cast<std::vector<Tensor>>());
} else {
calibrated_tensors.emplace_back(split_tensors[i]);
}

View File

@@ -31,7 +31,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/schema/schema_generated.h"
@@ -51,8 +50,6 @@ using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::ModelResources;
using TensorsSource =
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
@@ -229,12 +226,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
Source<std::vector<Tensor>> tensors_in,
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
// If output tensors are quantized, they must be dequantized first.
TensorsSource dequantized_tensors(&tensors_in);
Source<std::vector<Tensor>> dequantized_tensors = tensors_in;
if (options.has_quantized_outputs()) {
GenericNode& tensors_dequantization_node =
graph.AddNode("TensorsDequantizationCalculator");
tensors_in >> tensors_dequantization_node.In(kTensorsTag);
dequantized_tensors = {&tensors_dequantization_node, kTensorsTag};
dequantized_tensors = tensors_dequantization_node.Out(kTensorsTag)
.Cast<std::vector<Tensor>>();
}
// Adds TensorsToEmbeddingsCalculator.

View File

@@ -14,12 +14,6 @@
package(default_visibility = ["//mediapipe/tasks:internal"])
cc_library(
name = "source_or_node_output",
hdrs = ["source_or_node_output.h"],
deps = ["//mediapipe/framework/api2:builder"],
)
cc_library(
name = "cosine_similarity",
srcs = ["cosine_similarity.cc"],

View File

@@ -1,66 +0,0 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_
#include "mediapipe/framework/api2/builder.h"
namespace mediapipe {
namespace tasks {
// Helper class representing either a Source object or a GenericNode output.
//
// Source and MultiSource (the output of a GenericNode) are widely incompatible,
// but being able to represent either of these in temporary variables and
// connect them later on facilitates graph building.
template <typename T>
class SourceOrNodeOutput {
public:
SourceOrNodeOutput() = delete;
// The caller is responsible for ensuring 'source' outlives this object.
explicit SourceOrNodeOutput(mediapipe::api2::builder::Source<T>* source)
: source_(source) {}
// The caller is responsible for ensuring 'node' outlives this object.
SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node,
std::string tag)
: node_(node), tag_(tag) {}
// The caller is responsible for ensuring 'node' outlives this object.
SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, int index)
: node_(node), index_(index) {}
// Connects the source or node output to the provided destination.
template <typename U>
void operator>>(const U& dest) {
if (source_ != nullptr) {
*source_ >> dest;
} else {
if (index_ < 0) {
node_->Out(tag_) >> dest;
} else {
node_->Out(index_) >> dest;
}
}
}
private:
mediapipe::api2::builder::Source<T>* source_ = nullptr;
mediapipe::api2::builder::GenericNode* node_ = nullptr;
std::string tag_ = "";
int index_ = -1;
};
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_

View File

@@ -74,7 +74,6 @@ cc_library(
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",

View File

@@ -34,7 +34,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
@@ -69,7 +68,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
using ObjectDetectorOptionsProto =
object_detector::proto::ObjectDetectorOptions;
using TensorsSource =
mediapipe::tasks::SourceOrNodeOutput<std::vector<mediapipe::Tensor>>;
mediapipe::api2::builder::Source<std::vector<mediapipe::Tensor>>;
constexpr int kDefaultLocationsIndex = 0;
constexpr int kDefaultCategoriesIndex = 1;
@@ -584,7 +583,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
auto post_processing_specs,
BuildPostProcessingSpecs(task_options, metadata_extractor));
// Calculators to perform score calibration, if specified in the metadata.
TensorsSource calibrated_tensors = {&inference, kTensorTag};
TensorsSource calibrated_tensors =
inference.Out(kTensorTag).Cast<std::vector<Tensor>>();
if (post_processing_specs.score_calibration_options.has_value()) {
// Split tensors.
auto* split_tensor_vector_node =
@@ -623,7 +623,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
concatenate_tensor_vector_node->In(i);
}
}
calibrated_tensors = {concatenate_tensor_vector_node, 0};
calibrated_tensors =
concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
}
// Calculator to convert output tensors to a detection proto vector.
// Connects TensorsToDetectionsCalculator's input stream to the output