// mediapipe/mediapipe/framework/calculator_base.h
// MediaPipe Team 350fbb2100 Project import generated by Copybara.
// GitOrigin-RevId: d073f8e21be2fcc0e503cb97c6695078b6b75310
// 2021-02-27 03:30:05 -05:00
//
// 238 lines
// 10 KiB
// C++

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines CalculatorBase, the base class for feature computation.
#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_
#define MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_
#include <type_traits>
#include "absl/memory/memory.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
// Experimental: CalculatorBase will eventually replace Calculator as the
// base class of leaf (non-subgraph) nodes in a CalculatorGraph.
//
// The base calculator class. A subclass must, at a minimum, provide the
// implementation of GetContract(), Process(), and register the calculator
// using REGISTER_CALCULATOR(MyClass).
//
// The framework calls four primary functions on a calculator.
// On initialization of the graph, a static function is called:
// GetContract()
// Then, for each run of the graph on a set of input side packets, the
// following sequence will occur:
// Open()
// Process() (repeatedly)
// Close()
//
// The entire calculator is constructed and destroyed for each graph run
// (set of input side packets, which could mean once per video, or once
// per image). Any expensive operations and large objects should therefore
// be provided as input side packets.
//
// The framework calls Open() to initialize the calculator.
// If appropriate, Open() should call cc->SetOffset() or
// cc->Outputs().Get(id)->SetNextTimestampBound() to allow the framework to
// better optimize packet queueing.
//
// The framework calls Process() for every packet received on the input
// streams. The framework guarantees that cc->InputTimestamp() will
// increase with every call to Process(). An empty packet will be on an
// input stream if there is no packet on that particular input stream (but
// some other input stream has a packet).
//
// The framework calls Close() after all calls to Process().
//
// Calculators with no inputs are referred to as "sources" and are handled
// slightly differently than non-sources (see the function comments for
// Process() for more details).
//
// Calculators must be thread-compatible.
// The framework does not call the non-const methods of a calculator from
// multiple threads at the same time. However, the thread that calls the
// methods of a calculator is not fixed. Therefore, calculators should not
// use ThreadLocal objects.
//
// NOTE: the order of the virtual methods below defines the vtable layout;
// do not reorder them.
class CalculatorBase {
public:
// Constructor and destructor are only declared here; their definitions are
// out of line (not in this header). The destructor is virtual so that
// calculators can be destroyed through a CalculatorBase pointer.
CalculatorBase();
virtual ~CalculatorBase();
// The subclasses of CalculatorBase must implement GetContract.
// The calculator cannot be registered without it. Notice that although
// this function is static, the registration macro provides access to
// each subclass' GetContract function.
//
// static absl::Status GetContract(CalculatorContract* cc);
//
// GetContract fills in the calculator's contract with the framework, such
// as its expectations of what packets it will receive. When this function
// is called, the numbers of inputs, outputs, and input side packets will
// have already been determined by the calculator graph. You can use
// indexes, tags, or tag:index to access input streams, output streams,
// or input side packets.
//
// Example (uses tags for inputs and indexes for outputs and input side
// packets):
// cc->Inputs().Tag("VIDEO").Set<ImageFrame>("Input Image Frames.");
// cc->Inputs().Tag("AUDIO").Set<Matrix>("Input Audio Frames.");
// cc->Outputs().Index(0).Set<Matrix>("Output FooBar feature.");
// cc->InputSidePackets().Index(0).Set<MyModel>(
// "Model used for FooBar feature extraction.");
//
// Example (same number and type of outputs as inputs):
// for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
// // SetAny() is used to specify that whatever the type of the
// // stream is, it's acceptable. This does not mean that any
// // packet is acceptable. Packets in the stream still have a
// // particular type. SetAny() has the same effect as explicitly
// // setting the type to be the stream's type.
// cc->Inputs().Index(i).SetAny(absl::StrCat("Generic Input Stream ", i));
// // Set each output to accept the same specific type as the
// // corresponding input.
// cc->Outputs().Index(i).SetSameAs(
// &cc->Inputs().Index(i), absl::StrCat("Generic Output Stream ", i));
// }

// Open is called before any Process() calls, on a freshly constructed
// calculator. Subclasses may override this method to perform necessary
// setup, and possibly output Packets and/or set output streams' headers.
// Must return absl::OkStatus() to indicate success. On failure any
// other status code can be returned. If failure is returned then the
// framework will call neither Process() nor Close() on the calculator (so any
// necessary cleanup should be done before returning failure or in the
// destructor).
//
// The default implementation does nothing and returns absl::OkStatus().
virtual absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
// Processes the incoming inputs. May call the methods on cc to access
// inputs and produce outputs. This method is pure virtual: every
// calculator must override it.
//
// Process() called on a non-source node must return
// absl::OkStatus() to indicate that all went well, or any other
// status code to signal an error.
// For example:
// absl::UnknownError("Failure Message");
// The absl::*Error() helpers (such as absl::UnknownError above) are a
// convenient way to construct these error statuses.
// If a non-source Calculator returns tool::StatusStop(), then this
// signals the graph is being cancelled early. In this case, all
// source Calculators and graph input streams will be closed (and
// remaining Packets will propagate through the graph).
//
// A source node will continue to have Process() called on it as long
// as it returns absl::OkStatus(). To indicate that there is
// no more data to be generated return tool::StatusStop(). Any other
// status indicates an error has occurred.
virtual absl::Status Process(CalculatorContext* cc) = 0;
// Called if and only if Open() was called and succeeded. Is called either
// immediately after processing is complete or after a graph run has ended
// (if an error occurred in the graph). Must return absl::OkStatus()
// to indicate success. On failure any other status code can be returned.
// Packets may be output during a call to Close(). However, output packets
// are silently discarded if Close() is called after a graph run has ended.
//
// NOTE: If Close() needs to perform an action only when processing is
// complete, Close() must check if cc->GraphStatus() is OK.
//
// The default implementation does nothing and returns absl::OkStatus().
virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); }
// Returns a value according to which the framework selects
// the next source calculator to Process(); a smaller value means
// Process() first. The default implementation (defined out of line)
// returns the smallest NextTimestampBound value over all the output
// streams, but subclasses may override this. If a calculator is not a
// source, this method is not called.
// TODO: Does this method need to be virtual? No Calculator
// subclasses override the SourceProcessOrder method.
virtual Timestamp SourceProcessOrder(const CalculatorContext* cc) const;
};
// Forward declaration only. api2::Node is referenced by name (never by
// definition) in the std::is_base_of check of CalculatorBaseFactoryFor's
// partial specialization, which excludes api2::Node subclasses.
// std::is_base_of allows the base type to be incomplete, so no api2 header
// needs to be included here.
namespace api2 {
class Node;
} // namespace api2
namespace internal {
// Gives access to the static functions within subclasses of CalculatorBase.
// This adds functionality akin to virtual static functions: a concrete
// factory object is created per calculator type, and its virtual methods
// forward to that type's static members (e.g. CalculatorBaseFactoryFor<T>
// forwards GetContract to T::GetContract).
class CalculatorBaseFactory {
 public:
  // Polymorphic base: factories are owned and destroyed through
  // std::unique_ptr<CalculatorBaseFactory>.
  virtual ~CalculatorBaseFactory() = default;

  // Fills in `cc` with the calculator's contract (its expected inputs,
  // outputs and input side packets). Returns a non-OK status on failure.
  virtual absl::Status GetContract(CalculatorContract* cc) = 0;

  // Creates a new calculator instance owned by the caller.
  // `calculator_context` is made available to implementations that need it;
  // implementations may ignore it.
  virtual std::unique_ptr<CalculatorBase> CreateCalculator(
      CalculatorContext* calculator_context) = 0;

  // Name of the static contract-filling method of the calculator type.
  // Defaults to "GetContract"; other factory kinds may override it.
  // NOTE(review): presumably used in diagnostic messages — confirm with
  // callers.
  virtual std::string ContractMethodName() { return "GetContract"; }
};
// Functions for checking that the calculator has the required GetContract.
template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
typedef absl::Status (*GetContractType)(CalculatorContract * cc);
return std::is_same<decltype(&T::GetContract), GetContractType>::value;
}
template <class T>
constexpr bool CalculatorHasGetContract(...) {
return false;
}
// Generic CalculatorBaseFactoryFor<T>, selected only when no partial
// specialization matches. Its sole purpose is to reject, at compile time and
// with a readable message, any registered type that does not derive from
// mediapipe::CalculatorBase.
// NOTE(review): api2::Node subclasses also fall through to this template
// unless a matching specialization is declared elsewhere — confirm.
template <class T, class Enable = void>
class CalculatorBaseFactoryFor : public CalculatorBaseFactory {
  static_assert(std::is_base_of<mediapipe::CalculatorBase, T>{},
                "Classes registered with REGISTER_CALCULATOR must be "
                "subclasses of mediapipe::CalculatorBase.");
};
// Factory for ordinary calculators: T derives from mediapipe::CalculatorBase
// but not from mediapipe::api2::Node, so api2 nodes never match this
// specialization (presumably they are handled elsewhere — TODO confirm).
template <class T>
class CalculatorBaseFactoryFor<
    T, std::enable_if_t<std::is_base_of<mediapipe::CalculatorBase, T>::value &&
                        !std::is_base_of<mediapipe::api2::Node, T>::value>>
    : public CalculatorBaseFactory {
 public:
  // Reject calculators lacking `static absl::Status GetContract(
  // CalculatorContract*)` as soon as this factory is instantiated.
  static_assert(CalculatorHasGetContract<T>(nullptr),
                "GetContract() must be defined with the correct signature in "
                "every calculator.");

  // Delegates to T's static GetContract; the base class has no default, so
  // every calculator type must supply one (enforced above).
  absl::Status GetContract(CalculatorContract* cc) final {
    return T::GetContract(cc);
  }

  // Returns a freshly default-constructed T; the context is not consulted.
  std::unique_ptr<CalculatorBase> CreateCalculator(
      CalculatorContext* /*calculator_context*/) final {
    return absl::make_unique<T>();
  }
};
} // namespace internal
// Global registry whose entries are owned CalculatorBaseFactory objects.
// NOTE(review): presumably populated by REGISTER_CALCULATOR and keyed by
// calculator name — confirm against GlobalFactoryRegistry in
// mediapipe/framework/deps/registration.h.
using CalculatorBaseRegistry =
GlobalFactoryRegistry<std::unique_ptr<internal::CalculatorBaseFactory>>;
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_