diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index cec44a9e3..10bc0726a 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -48,7 +48,6 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", "//mediapipe/tasks/metadata:metadata_schema_cc", @@ -90,7 +89,6 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/metadata:metadata_extractor", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 5a0472f5c..cfb3b02cf 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -40,7 +40,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" -#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" @@ -68,7 +67,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; using ::tflite::TensorMetadata; using LabelItems = mediapipe::proto_ns::Map; -using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; +using TensorsSource = mediapipe::api2::builder::Source>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); @@ -455,12 +454,13 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { } // If output tensors are quantized, they must be dequantized first. - TensorsSource dequantized_tensors(&tensors_in); + TensorsSource dequantized_tensors = tensors_in; if (options.has_quantized_outputs()) { GenericNode* tensors_dequantization_node = &graph.AddNode("TensorsDequantizationCalculator"); tensors_in >> tensors_dequantization_node->In(kTensorsTag); - dequantized_tensors = {tensors_dequantization_node, kTensorsTag}; + dequantized_tensors = tensors_dequantization_node->Out(kTensorsTag) + .Cast>(); } // If there are multiple classification heads, the output tensors need to be @@ -477,7 +477,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { auto* range = split_tensor_vector_options.add_ranges(); range->set_begin(i); range->set_end(i + 1); - split_tensors.emplace_back(split_tensor_vector_node, i); + split_tensors.push_back( + split_tensor_vector_node->Out(i).Cast>()); } dequantized_tensors >> split_tensor_vector_node->In(0); } else { @@ -494,8 +495,9 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { score_calibration_node->GetOptions() .CopyFrom(options.score_calibration_options().at(i)); split_tensors[i] >> score_calibration_node->In(kScoresTag); - calibrated_tensors.emplace_back(score_calibration_node, - kCalibratedScoresTag); + calibrated_tensors.push_back( + score_calibration_node->Out(kCalibratedScoresTag) + .Cast>()); } else { calibrated_tensors.emplace_back(split_tensors[i]); } diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index ad4881e12..7b023ba41 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -51,8 +50,6 @@ using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; using ::mediapipe::tasks::core::ModelResources; -using TensorsSource = - ::mediapipe::tasks::SourceOrNodeOutput>; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; @@ -229,12 +226,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { Source> tensors_in, Source> timestamps_in, Graph& graph) { // If output tensors are quantized, they must be dequantized first. - TensorsSource dequantized_tensors(&tensors_in); + Source> dequantized_tensors = tensors_in; if (options.has_quantized_outputs()) { GenericNode& tensors_dequantization_node = graph.AddNode("TensorsDequantizationCalculator"); tensors_in >> tensors_dequantization_node.In(kTensorsTag); - dequantized_tensors = {&tensors_dequantization_node, kTensorsTag}; + dequantized_tensors = tensors_dequantization_node.Out(kTensorsTag) + .Cast>(); } // Adds TensorsToEmbeddingsCalculator. diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index 8bb5b8415..2e0ea3ce6 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -14,12 +14,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) -cc_library( - name = "source_or_node_output", - hdrs = ["source_or_node_output.h"], - deps = ["//mediapipe/framework/api2:builder"], -) - cc_library( name = "cosine_similarity", srcs = ["cosine_similarity.cc"], diff --git a/mediapipe/tasks/cc/components/utils/source_or_node_output.h b/mediapipe/tasks/cc/components/utils/source_or_node_output.h deleted file mode 100644 index 55805d5a3..000000000 --- a/mediapipe/tasks/cc/components/utils/source_or_node_output.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ - -#include "mediapipe/framework/api2/builder.h" - -namespace mediapipe { -namespace tasks { - -// Helper class representing either a Source object or a GenericNode output. -// -// Source and MultiSource (the output of a GenericNode) are widely incompatible, -// but being able to represent either of these in temporary variables and -// connect them later on facilitates graph building. -template -class SourceOrNodeOutput { - public: - SourceOrNodeOutput() = delete; - // The caller is responsible for ensuring 'source' outlives this object. - explicit SourceOrNodeOutput(mediapipe::api2::builder::Source* source) - : source_(source) {} - // The caller is responsible for ensuring 'node' outlives this object. - SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, - std::string tag) - : node_(node), tag_(tag) {} - // The caller is responsible for ensuring 'node' outlives this object. - SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, int index) - : node_(node), index_(index) {} - - // Connects the source or node output to the provided destination. - template - void operator>>(const U& dest) { - if (source_ != nullptr) { - *source_ >> dest; - } else { - if (index_ < 0) { - node_->Out(tag_) >> dest; - } else { - node_->Out(index_) >> dest; - } - } - } - - private: - mediapipe::api2::builder::Source* source_ = nullptr; - mediapipe::api2::builder::GenericNode* node_ = nullptr; - std::string tag_ = ""; - int index_ = -1; -}; - -} // namespace tasks -} // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_ diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 5269796ae..0238449c7 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -74,7 +74,6 @@ cc_library( "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index e5af7544d..cb85fc46f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -34,7 +34,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" #include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" -#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" @@ -69,7 +68,7 @@ using LabelItems = mediapipe::proto_ns::Map; using ObjectDetectorOptionsProto = object_detector::proto::ObjectDetectorOptions; using TensorsSource = - mediapipe::tasks::SourceOrNodeOutput>; + mediapipe::api2::builder::Source>; constexpr int kDefaultLocationsIndex = 0; constexpr int kDefaultCategoriesIndex = 1; @@ -584,7 +583,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { auto post_processing_specs, BuildPostProcessingSpecs(task_options, metadata_extractor)); // Calculators to perform score calibration, if specified in the metadata. - TensorsSource calibrated_tensors = {&inference, kTensorTag}; + TensorsSource calibrated_tensors = + inference.Out(kTensorTag).Cast>(); if (post_processing_specs.score_calibration_options.has_value()) { // Split tensors. auto* split_tensor_vector_node = @@ -623,7 +623,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { concatenate_tensor_vector_node->In(i); } } - calibrated_tensors = {concatenate_tensor_vector_node, 0}; + calibrated_tensors = + concatenate_tensor_vector_node->Out(0).Cast>(); } // Calculator to convert output tensors to a detection proto vector. // Connects TensorsToDetectionsCalculator's input stream to the output