Expose tool calculators in headers to enable dynamic registration by superusers.

PiperOrigin-RevId: 557174440
This commit is contained in:
MediaPipe Team 2023-08-15 10:28:29 -07:00 committed by Copybara-Service
parent a392561b31
commit c1d7e6023a
5 changed files with 123 additions and 67 deletions

View File

@ -335,6 +335,7 @@ mediapipe_cc_test(
cc_library( cc_library(
name = "packet_generator_wrapper_calculator", name = "packet_generator_wrapper_calculator",
srcs = ["packet_generator_wrapper_calculator.cc"], srcs = ["packet_generator_wrapper_calculator.cc"],
hdrs = ["packet_generator_wrapper_calculator.h"],
visibility = ["//mediapipe/framework:__subpackages__"], visibility = ["//mediapipe/framework:__subpackages__"],
deps = [ deps = [
":packet_generator_wrapper_calculator_cc_proto", ":packet_generator_wrapper_calculator_cc_proto",
@ -342,6 +343,9 @@ cc_library(
"//mediapipe/framework:calculator_registry", "//mediapipe/framework:calculator_registry",
"//mediapipe/framework:output_side_packet", "//mediapipe/framework:output_side_packet",
"//mediapipe/framework:packet_generator", "//mediapipe/framework:packet_generator",
"//mediapipe/framework:packet_set",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -386,21 +390,22 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":name_util", ":name_util",
":status_util",
"//mediapipe/calculators/internal:callback_packet_calculator", "//mediapipe/calculators/internal:callback_packet_calculator",
"//mediapipe/calculators/internal:callback_packet_calculator_cc_proto", "//mediapipe/calculators/internal:callback_packet_calculator_cc_proto",
"//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_graph", "//mediapipe/framework:calculator_graph",
"//mediapipe/framework:calculator_registry", "//mediapipe/framework:calculator_registry",
"//mediapipe/framework:input_stream",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework:packet_type", "//mediapipe/framework:packet_type",
"//mediapipe/framework/port:logging", "//mediapipe/framework:timestamp",
"//mediapipe/framework/port:source_location", "//mediapipe/framework/port:source_location",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -1,29 +1,33 @@
#include "mediapipe/framework/tool/packet_generator_wrapper_calculator.h"
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_registry.h" #include "mediapipe/framework/calculator_registry.h"
#include "mediapipe/framework/output_side_packet.h" #include "mediapipe/framework/output_side_packet.h"
#include "mediapipe/framework/packet_generator.h" #include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/tool/packet_generator_wrapper_calculator.pb.h" #include "mediapipe/framework/tool/packet_generator_wrapper_calculator.pb.h"
namespace mediapipe { namespace mediapipe {
class PacketGeneratorWrapperCalculator : public CalculatorBase { absl::Status PacketGeneratorWrapperCalculator::GetContract(
public: CalculatorContract* cc) {
static absl::Status GetContract(CalculatorContract* cc) {
const auto& options = const auto& options =
cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>(); cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>();
ASSIGN_OR_RETURN(auto static_access, ASSIGN_OR_RETURN(auto static_access,
mediapipe::internal::StaticAccessToGeneratorRegistry:: mediapipe::internal::StaticAccessToGeneratorRegistry::
CreateByNameInNamespace(options.package(), CreateByNameInNamespace(options.package(),
options.packet_generator())); options.packet_generator()));
MP_RETURN_IF_ERROR(static_access->FillExpectations( MP_RETURN_IF_ERROR(static_access->FillExpectations(options.options(),
options.options(), &cc->InputSidePackets(), &cc->InputSidePackets(),
&cc->OutputSidePackets())) &cc->OutputSidePackets()))
.SetPrepend() .SetPrepend()
<< options.packet_generator() << "::FillExpectations() failed: "; << options.packet_generator() << "::FillExpectations() failed: ";
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status PacketGeneratorWrapperCalculator::Open(CalculatorContext* cc) {
const auto& options = const auto& options =
cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>(); cc->Options<::mediapipe::PacketGeneratorWrapperCalculatorOptions>();
ASSIGN_OR_RETURN(auto static_access, ASSIGN_OR_RETURN(auto static_access,
@ -36,17 +40,16 @@ class PacketGeneratorWrapperCalculator : public CalculatorBase {
&output_packets)) &output_packets))
.SetPrepend() .SetPrepend()
<< options.packet_generator() << "::Generate() failed: "; << options.packet_generator() << "::Generate() failed: ";
for (auto id = output_packets.BeginId(); id < output_packets.EndId(); for (auto id = output_packets.BeginId(); id < output_packets.EndId(); ++id) {
++id) {
cc->OutputSidePackets().Get(id).Set(output_packets.Get(id)); cc->OutputSidePackets().Get(id).Set(output_packets.Get(id));
} }
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) override { absl::Status PacketGeneratorWrapperCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
};
REGISTER_CALCULATOR(PacketGeneratorWrapperCalculator); REGISTER_CALCULATOR(PacketGeneratorWrapperCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_PACKET_GENERATOR_WRAPPER_CALCULATOR_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_PACKET_GENERATOR_WRAPPER_CALCULATOR_H_
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_base.h"
namespace mediapipe {
class PacketGeneratorWrapperCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
};
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_TOOL_PACKET_GENERATOR_WRAPPER_CALCULATOR_H_

View File

@ -18,54 +18,58 @@
#include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/sink.h"
#include <stdio.h>
#include <functional>
#include <map>
#include <memory> #include <memory>
#include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_join.h"
#include "mediapipe/calculators/internal/callback_packet_calculator.pb.h" #include "mediapipe/calculators/internal/callback_packet_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_graph.h" #include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/calculator_registry.h" #include "mediapipe/framework/calculator_registry.h"
#include "mediapipe/framework/input_stream.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/source_location.h" #include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/name_util.h" #include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/status_util.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
namespace {
// Produces an output packet with the PostStream timestamp containing the absl::Status MediaPipeInternalSidePacketToPacketStreamCalculator::GetContract(
// input side packet. CalculatorContract* cc) {
class MediaPipeInternalSidePacketToPacketStreamCalculator
: public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Index(0).SetAny(); cc->InputSidePackets().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0));
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status MediaPipeInternalSidePacketToPacketStreamCalculator::Open(
CalculatorContext* cc) {
cc->Outputs().Index(0).AddPacket( cc->Outputs().Index(0).AddPacket(
cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); cc->InputSidePackets().Index(0).At(Timestamp::PostStream()));
cc->Outputs().Index(0).Close(); cc->Outputs().Index(0).Close();
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) final { absl::Status MediaPipeInternalSidePacketToPacketStreamCalculator::Process(
CalculatorContext* cc) {
// The framework treats this calculator as a source calculator. // The framework treats this calculator as a source calculator.
return mediapipe::tool::StatusStop(); return mediapipe::tool::StatusStop();
} }
};
REGISTER_CALCULATOR(MediaPipeInternalSidePacketToPacketStreamCalculator); REGISTER_CALCULATOR(MediaPipeInternalSidePacketToPacketStreamCalculator);
} // namespace
void AddVectorSink(const std::string& stream_name, // void AddVectorSink(const std::string& stream_name, //
CalculatorGraphConfig* config, // CalculatorGraphConfig* config, //

View File

@ -28,10 +28,12 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_SINK_H_ #ifndef MEDIAPIPE_FRAMEWORK_TOOL_SINK_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_SINK_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_SINK_H_
#include <functional>
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/packet_type.h" #include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
@ -205,6 +207,16 @@ class CallbackWithHeaderCalculator : public CalculatorBase {
Packet header_packet_; Packet header_packet_;
}; };
// Produces an output packet with the PostStream timestamp containing the
// input side packet.
class MediaPipeInternalSidePacketToPacketStreamCalculator
: public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) final;
absl::Status Process(CalculatorContext* cc) final;
};
} // namespace tool } // namespace tool
} // namespace mediapipe } // namespace mediapipe