Project import generated by Copybara.

GitOrigin-RevId: f4b1fe3f15810450fb6539e733f6a260d3ee082c
This commit is contained in:
MediaPipe Team 2021-09-01 13:49:12 -07:00 committed by jqtang
parent 710fb3de58
commit 6abec128ed
64 changed files with 2384 additions and 161 deletions

View File

@ -157,11 +157,11 @@ http_archive(
http_archive(
name = "pybind11",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz",
"https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz",
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.7.1.tar.gz",
"https://github.com/pybind/pybind11/archive/v2.7.1.tar.gz",
],
sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d",
strip_prefix = "pybind11-2.4.3",
sha256 = "616d1c42e4cf14fa27b2a4ff759d7d7b33006fdc5ad8fd603bb2c22622f27020",
strip_prefix = "pybind11-2.7.1",
build_file = "@pybind11_bazel//:pybind11.BUILD",
)

View File

@ -113,6 +113,10 @@ bazel to build the iOS application. The content of the
5. `Main.storyboard` and `Launch.storyboard`
6. `Assets.xcassets` directory.
Note: In newer versions of Xcode, you may see additional files `SceneDelegate.h`
and `SceneDelegate.m`. Make sure to copy them too and add them to the `BUILD`
file mentioned below.
Copy these files to a directory named `HelloWorld` to a location that can access
the MediaPipe source code. For example, the source code of the application that
we will build in this tutorial is located in
@ -247,6 +251,12 @@ We need to get frames from the `_cameraSource` into our application
`MPPInputSourceDelegate`. So our application `ViewController` can be a delegate
of `_cameraSource`.
Update the interface definition of `ViewController` accordingly:
```
@interface ViewController () <MPPInputSourceDelegate>
```
To handle camera setup and process incoming frames, we should use a queue
different from the main queue. Add the following to the implementation block of
the `ViewController`:
@ -288,6 +298,12 @@ utility called `MPPLayerRenderer` to display images on the screen. This utility
can be used to display `CVPixelBufferRef` objects, which is the type of the
images provided by `MPPCameraInputSource` to its delegates.
In `ViewController.m`, add the following import line:
```
#import "mediapipe/objc/MPPLayerRenderer.h"
```
To display images of the screen, we need to add a new `UIView` object called
`_liveView` to the `ViewController`.
@ -411,6 +427,12 @@ Objective-C++.
### Use the graph in `ViewController`
In `ViewController.m`, add the following import line:
```
#import "mediapipe/objc/MPPGraph.h"
```
Declare a static constant with the name of the graph, the input stream and the
output stream:
@ -549,6 +571,12 @@ method to receive packets on this output stream and display them on the screen:
}
```
Update the interface definition of `ViewController` with `MPPGraphDelegate`:
```
@interface ViewController () <MPPGraphDelegate, MPPInputSourceDelegate>
```
And that is all! Build and run the app on your iOS device. You should see the
results of running the edge detection graph on a live video feed. Congrats!
@ -560,5 +588,5 @@ appropriate `BUILD` file dependencies for the edge detection graph.
[Bazel]:https://bazel.build/
[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt
[common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common)
[helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld)
[common]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common
[helloworld]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld

View File

@ -796,7 +796,7 @@ This will use a Docker image that will isolate mediapipe's installation from the
```bash
$ docker run -it --name mediapipe mediapipe:latest
root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world
root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazelisk run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world
# Should print:
# Hello World!

View File

@ -529,7 +529,7 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http
> ```
> and then run
>
> ```build
> ```bash
> bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR]
> ```
> INPUT_DIR should be the folder with initial asset .obj files to be processed,

View File

@ -141,7 +141,7 @@ Optionally, MediaPipe Pose can predicts a full-body
Please find more detail in the
[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html),
this [paper](https://arxiv.org/abs/2006.10204),
[the model card](./models.md#pose) and the [Output](#Output) section below.
[the model card](./models.md#pose) and the [Output](#output) section below.
## Solution APIs
@ -281,8 +281,8 @@ with mp_pose.Pose(
continue
print(
f'Nose coordinates: ('
f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, '
f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})'
f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, '
f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})'
)
annotated_image = image.copy()
@ -369,6 +369,7 @@ Supported configuration options:
<div class="container">
<video class="input_video"></video>
<canvas class="output_canvas" width="1280px" height="720px"></canvas>
<div class="landmark-grid-container"></div>
</div>
</body>
</html>

View File

@ -262,7 +262,7 @@ to visualize its associated subgraphs, please see
[(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1DoeyGzMmWUsjfVgZfGGecrn7GKzYcEAo/view?usp=sharing)
[`mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu:selfiesegmentationgpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu/BUILD)
* iOS target:
[`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](http:/mediapipe/examples/ios/selfiesegmentationgpu/BUILD)
[`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/selfiesegmentationgpu/BUILD)
### Desktop

View File

@ -13,6 +13,9 @@ has_toc: false
{:toc}
---
MediaPipe offers open source cross-platform, customizable ML solutions for live
and streaming media.
<!-- []() in the first cell is needed to preserve table formatting in GitHub Pages. -->
<!-- Whenever this table is updated, paste a copy to ../external_index.md. -->

View File

@ -42,4 +42,9 @@ REGISTER_CALCULATOR(BeginLoopDetectionCalculator);
typedef BeginLoopCalculator<std::vector<Matrix>> BeginLoopMatrixCalculator;
REGISTER_CALCULATOR(BeginLoopMatrixCalculator);
// A calculator to process std::vector<std::vector<Matrix>>.
typedef BeginLoopCalculator<std::vector<std::vector<Matrix>>>
BeginLoopMatrixVectorCalculator;
REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
} // namespace mediapipe

View File

@ -73,6 +73,7 @@ class InferenceCalculatorCpuImpl
private:
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_;
@ -91,8 +92,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract(
absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadModel(cc));
MP_RETURN_IF_ERROR(LoadDelegate(cc));
return absl::OkStatus();
return LoadDelegateAndAllocateTensors(cc);
}
absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
@ -156,11 +156,19 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) {
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__
return absl::OkStatus();
}
absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors.
CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type !=
RET_CHECK_NE(
interpreter_->tensor(interpreter_->inputs()[0])->quantization.type,
kTfLiteAffineQuantization);
return absl::OkStatus();
}

View File

@ -53,6 +53,7 @@ class InferenceCalculatorGlImpl
absl::Status WriteKernelsToFile();
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is.
@ -119,9 +120,10 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
}
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this,
&cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc);
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
: LoadDelegateAndAllocateTensors(cc);
}));
return absl::OkStatus();
}
@ -324,11 +326,19 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors.
CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type !=
RET_CHECK_NE(
interpreter_->tensor(interpreter_->inputs()[0])->quantization.type,
kTfLiteAffineQuantization);
return absl::OkStatus();
}

View File

@ -92,6 +92,7 @@ class InferenceCalculatorMetalImpl
private:
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_;
@ -130,8 +131,7 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
MP_RETURN_IF_ERROR(LoadDelegate(cc));
return absl::OkStatus();
return LoadDelegateAndAllocateTensors(cc);
}
absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
@ -212,11 +212,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) {
interpreter_->SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
return absl::OkStatus();
}
absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors.
CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type !=
RET_CHECK_NE(
interpreter_->tensor(interpreter_->inputs()[0])->quantization.type,
kTfLiteAffineQuantization);
return absl::OkStatus();
}
@ -236,6 +244,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
id<MTLDevice> device = gpu_helper_.mtlDevice;
// Get input image sizes.

View File

@ -670,7 +670,8 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
detection_scores[i], detection_classes[i], options_.flip_vertically());
const auto& bbox = detection.location_data().relative_bounding_box();
if (bbox.width() < 0 || bbox.height() < 0) {
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
std::isnan(bbox.height())) {
// Decoded detection boxes could have negative values for width/height due
// to model prediction. Filter out those boxes since some downstream
// calculators may assume non-negative values. (b/171391719)

View File

@ -138,7 +138,6 @@ using ::tflite::gpu::gl::GlShader;
// }
// }
//
// Currently only OpenGLES 3.1 and CPU backends supported.
// TODO Refactor and add support for other backends/platforms.
//
class TensorsToSegmentationCalculator : public CalculatorBase {

View File

@ -56,6 +56,8 @@ constexpr char kBboxTag[] = "BBOX";
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kImageTag[] = "IMAGE";
constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
class UnpackMediaSequenceCalculatorTest : public ::testing::Test {

View File

@ -175,7 +175,8 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
auto* text = label_annotation->mutable_text();
std::string display_text = labels[i];
if (cc->Inputs().HasTag(kScoresTag)) {
if (cc->Inputs().HasTag(kScoresTag) ||
options_.display_classification_score()) {
absl::StrAppend(&display_text, ":", scores[i]);
}
text->set_display_text(display_text);

View File

@ -62,4 +62,7 @@ message LabelsToRenderDataCalculatorOptions {
// Uses Classification.display_name field instead of Classification.label.
optional bool use_display_name = 9 [default = false];
// Displays Classification score if enabled.
optional bool display_classification_score = 10 [default = false];
}

View File

@ -225,11 +225,10 @@ class SubgraphImpl : public Subgraph, public Intf {
// registration. Deprecated.
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(calculator_registration, __LINE__)( \
mediapipe::CalculatorBaseRegistry::Register( \
REGISTRY_STATIC_VAR(calculator_registration, \
__LINE__)(mediapipe::CalculatorBaseRegistry::Register( \
Impl::kCalculatorName, \
absl::make_unique< \
mediapipe::internal::CalculatorBaseFactoryFor<Impl>>))
absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>))
// This macro is used to register a non-split-contract calculator. Deprecated.
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)

View File

@ -454,12 +454,12 @@ class OutputShardAccessBase {
if (output_) output_->SetNextTimestampBound(timestamp);
}
bool IsClosed() { return output_ ? output_->IsClosed() : true; }
bool IsClosed() const { return output_ ? output_->IsClosed() : true; }
void Close() {
if (output_) output_->Close();
}
bool IsConnected() { return output_ != nullptr; }
bool IsConnected() const { return output_ != nullptr; }
protected:
const CalculatorContext& context_;
@ -559,7 +559,7 @@ class InputShardAccess : public Packet<T> {
PacketBase packet() const&& { return *this; }
bool IsDone() const { return stream_->IsDone(); }
bool IsConnected() { return stream_ != nullptr; }
bool IsConnected() const { return stream_ != nullptr; }
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
@ -619,7 +619,7 @@ class InputSidePacketAccess : public Packet<T> {
const PacketBase& packet() const& { return *this; }
PacketBase packet() const&& { return *this; }
bool IsConnected() { return connected_; }
bool IsConnected() const { return connected_; }
private:
InputSidePacketAccess(const mediapipe::Packet* packet)
@ -639,8 +639,8 @@ class InputShardOrSideAccess : public Packet<T> {
PacketBase packet() const&& { return *this; }
bool IsDone() const { return stream_->IsDone(); }
bool IsConnected() { return connected_; }
bool IsStream() { return stream_ != nullptr; }
bool IsConnected() const { return connected_; }
bool IsStream() const { return stream_ != nullptr; }
PacketBase Header() const { return FromOldPacket(stream_->Header()); }
@ -662,7 +662,7 @@ class InputShardOrSideAccess : public Packet<T> {
class PacketTypeAccess {
public:
bool IsConnected() { return packet_type_ != nullptr; }
bool IsConnected() const { return packet_type_ != nullptr; }
protected:
PacketTypeAccess(PacketType* pt) : packet_type_(pt) {}
@ -675,7 +675,7 @@ class PacketTypeAccess {
class PacketTypeAccessFallback : public PacketTypeAccess {
public:
bool IsStream() { return is_stream_; }
bool IsStream() const { return is_stream_; }
private:
PacketTypeAccessFallback(PacketType* pt, bool is_stream)

View File

@ -321,6 +321,8 @@ message CalculatorGraphConfig {
// The maximum number of invocations that can be executed in parallel.
// If not specified, the limit is one invocation.
int32 max_in_flight = 16;
// Defines an option value for this Node from graph options or packets.
repeated string option_value = 17;
// DEPRECATED: For backwards compatibility we allow users to
// specify the old name for "input_side_packet" in proto configs.
// These are automatically converted to input_side_packets during

View File

@ -465,7 +465,7 @@ absl::Status CalculatorGraph::ObserveOutputStream(
}
absl::StatusOr<OutputStreamPoller> CalculatorGraph::AddOutputStreamPoller(
const std::string& stream_name) {
const std::string& stream_name, bool observe_timestamp_bounds) {
RET_CHECK(initialized_).SetNoLogging()
<< "CalculatorGraph is not initialized.";
int output_stream_index = validated_graph_->OutputStreamIndex(stream_name);
@ -479,7 +479,7 @@ absl::StatusOr<OutputStreamPoller> CalculatorGraph::AddOutputStreamPoller(
stream_name, &any_packet_type_,
std::bind(&CalculatorGraph::UpdateThrottledNodes, this,
std::placeholders::_1, std::placeholders::_2),
&output_stream_managers_[output_stream_index]));
&output_stream_managers_[output_stream_index], observe_timestamp_bounds));
OutputStreamPoller poller(internal_poller);
graph_output_streams_.push_back(std::move(internal_poller));
return std::move(poller);

View File

@ -164,7 +164,8 @@ class CalculatorGraph {
// polling API for accessing a stream's output. Should only be called before
// Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See
// also the helpers in tool/sink.h.
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name);
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name,
bool observe_timestamp_bounds = false);
// Gets output side packet by name after the graph is done. However, base
// packets (generated by PacketGenerators) can be retrieved before

View File

@ -4348,5 +4348,349 @@ TEST(CalculatorGraph, GraphInputStreamWithTag) {
ASSERT_EQ(5, packet_dump.size());
}
TEST(CalculatorGraph, GraphInputStreamBeforeStartRun) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "VIDEO_METADATA:video_metadata"
input_stream: "max_count"
node {
calculator: "PassThroughCalculator"
input_stream: "FIRST_INPUT:video_metadata"
input_stream: "max_count"
output_stream: "FIRST_INPUT:output_0"
output_stream: "output_1"
}
)pb");
std::vector<Packet> packet_dump;
tool::AddVectorSink("output_0", &config, &packet_dump);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
ASSERT_EQ(graph
.AddPacketToInputStream("video_metadata",
MakePacket<int>(0).At(Timestamp(0)))
.code(),
absl::StatusCode::kFailedPrecondition);
}
// Returns the first packet of the input stream.
class FirstPacketFilterCalculator : public CalculatorBase {
public:
FirstPacketFilterCalculator() {}
~FirstPacketFilterCalculator() override {}
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
if (!seen_first_packet_) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
cc->Outputs().Index(0).Close();
seen_first_packet_ = true;
}
return absl::OkStatus();
}
private:
bool seen_first_packet_ = false;
};
REGISTER_CALCULATOR(FirstPacketFilterCalculator);
constexpr int kDefaultMaxCount = 1000;
TEST(CalculatorGraph, TestPollPacket) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node = config.add_node();
node->set_calculator("CountingSourceCalculator");
node->add_output_stream("output");
node->add_input_side_packet("MAX_COUNT:max_count");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller = graph.AddOutputStreamPoller("output");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet;
int num_packets = 0;
while (poller.Next(&packet)) {
EXPECT_EQ(num_packets, packet.Get<int>());
++num_packets;
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller.Next(&packet));
EXPECT_EQ(kDefaultMaxCount, num_packets);
}
TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node = config.add_node();
node->set_calculator("CountingSourceCalculator");
node->add_output_stream("output");
node->add_input_side_packet("MAX_COUNT:max_count");
for (int queue_size = 1; queue_size < 10; ++queue_size) {
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller = graph.AddOutputStreamPoller("output");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
poller.SetMaxQueueSize(queue_size);
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet;
int num_packets = 0;
while (poller.Next(&packet)) {
EXPECT_EQ(num_packets, packet.Get<int>());
++num_packets;
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller.Next(&packet));
EXPECT_EQ(kDefaultMaxCount, num_packets);
}
}
TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node1 = config.add_node();
node1->set_calculator("CountingSourceCalculator");
node1->add_output_stream("stream1");
node1->add_input_side_packet("MAX_COUNT:max_count");
CalculatorGraphConfig::Node* node2 = config.add_node();
node2->set_calculator("PassThroughCalculator");
node2->add_input_stream("stream1");
node2->add_output_stream("stream2");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller1 = graph.AddOutputStreamPoller("stream1");
ASSERT_TRUE(status_or_poller1.ok());
OutputStreamPoller poller1 = std::move(status_or_poller1.value());
auto status_or_poller2 = graph.AddOutputStreamPoller("stream2");
ASSERT_TRUE(status_or_poller2.ok());
OutputStreamPoller poller2 = std::move(status_or_poller2.value());
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet1;
Packet packet2;
int num_packets1 = 0;
int num_packets2 = 0;
int running_pollers = 2;
while (running_pollers > 0) {
if (poller1.Next(&packet1)) {
EXPECT_EQ(num_packets1++, packet1.Get<int>());
} else {
--running_pollers;
}
if (poller2.Next(&packet2)) {
EXPECT_EQ(num_packets2++, packet2.Get<int>());
} else {
--running_pollers;
}
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller1.Next(&packet1));
EXPECT_FALSE(poller2.Next(&packet2));
EXPECT_EQ(kDefaultMaxCount, num_packets1);
EXPECT_EQ(kDefaultMaxCount, num_packets2);
}
class TimestampBoundTestCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) final {
if (count_ % 50 == 1) {
// Outputs packets at t10 and t60.
cc->Outputs().Index(0).AddPacket(
MakePacket<int>(count_).At(Timestamp(count_)));
} else if (count_ % 15 == 7) {
cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_));
}
absl::SleepFor(absl::Milliseconds(3));
++count_;
if (count_ == 110) {
return tool::StatusStop();
}
return absl::OkStatus();
}
private:
int count_ = 0;
};
REGISTER_CALCULATOR(TimestampBoundTestCalculator);
TEST(CalculatorGraph, TestPollPacketsWithTimestampNotification) {
std::string config_str = R"(
node {
calculator: "TimestampBoundTestCalculator"
output_stream: "foo"
}
)";
CalculatorGraphConfig graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
auto status_or_poller =
graph.AddOutputStreamPoller("foo", /*observe_timestamp_bounds=*/true);
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
Packet packet;
std::vector<int> timestamps;
std::vector<int> values;
MP_ASSERT_OK(graph.StartRun({}));
while (poller.Next(&packet)) {
if (packet.IsEmpty()) {
timestamps.push_back(packet.Timestamp().Value());
} else {
values.push_back(packet.Get<int>());
}
}
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_FALSE(poller.Next(&packet));
ASSERT_FALSE(timestamps.empty());
int prev_t = 0;
for (auto t : timestamps) {
EXPECT_TRUE(t > prev_t && t < 110);
prev_t = t;
}
ASSERT_EQ(3, values.size());
EXPECT_EQ(1, values[0]);
EXPECT_EQ(51, values[1]);
EXPECT_EQ(101, values[2]);
}
// Ensure that when a custom input stream handler is used to handle packets from
// input streams, an error message is outputted with the appropriate link to
// resolve the issue when the calculator doesn't handle inputs in monotonically
// increasing order of timestamps.
TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) {
CalculatorGraph graph;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input0'
input_stream: 'input1'
node {
calculator: 'SimpleMuxCalculator'
input_stream: 'input0'
input_stream: 'input1'
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
output_stream: 'output'
}
)pb");
std::vector<Packet> packet_dump;
tool::AddVectorSink("output", &config, &packet_dump);
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
// Send packets to input stream "input0" at timestamps 0 and 1 consecutively.
Timestamp input0_timestamp = Timestamp(0);
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1).At(input0_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
EXPECT_EQ(1, packet_dump[0].Get<int>());
++input0_timestamp;
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(3).At(input0_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(2, packet_dump.size());
EXPECT_EQ(3, packet_dump[1].Get<int>());
// Send a packet to input stream "input1" at timestamp 0 after sending two
// packets at timestamps 0 and 1 to input stream "input0". This will result
// in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle
// inputs from all streams in monotonically increasing order of timestamps.
Timestamp input1_timestamp = Timestamp(0);
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(2).At(input1_timestamp)));
absl::Status run_status = graph.WaitUntilIdle();
EXPECT_THAT(
run_status.ToString(),
testing::AllOf(
// The core problem.
testing::HasSubstr("timestamp mismatch on a calculator"),
testing::HasSubstr(
"timestamps that are not strictly monotonically increasing"),
// Link to the possible solution.
testing::HasSubstr("ImmediateInputStreamHandler class comment")));
}
void DoTestMultipleGraphRuns(absl::string_view input_stream_handler,
bool select_packet) {
std::string graph_proto = absl::StrFormat(R"(
input_stream: 'input'
input_stream: 'select'
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'select'
input_stream_handler {
input_stream_handler: "%s"
}
output_stream: 'output'
output_stream: 'select_out'
}
)",
input_stream_handler.data());
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
std::vector<Packet> packet_dump;
tool::AddVectorSink("output", &config, &packet_dump);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
struct Run {
Timestamp timestamp;
int value;
};
std::vector<Run> runs = {{.timestamp = Timestamp(2000), .value = 2},
{.timestamp = Timestamp(1000), .value = 1}};
for (const Run& run : runs) {
MP_ASSERT_OK(graph.StartRun({}));
if (select_packet) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(run.timestamp)));
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(run.value).At(run.timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
EXPECT_EQ(run.value, packet_dump[0].Get<int>());
EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
packet_dump.clear();
}
}
TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) {
DoTestMultipleGraphRuns("BarrierInputStreamHandler", true);
DoTestMultipleGraphRuns("DefaultInputStreamHandler", true);
DoTestMultipleGraphRuns("EarlyCloseInputStreamHandler", true);
DoTestMultipleGraphRuns("FixedSizeInputStreamHandler", true);
DoTestMultipleGraphRuns("ImmediateInputStreamHandler", false);
DoTestMultipleGraphRuns("MuxInputStreamHandler", true);
DoTestMultipleGraphRuns("SyncSetInputStreamHandler", true);
DoTestMultipleGraphRuns("TimestampAlignInputStreamHandler", true);
}
} // namespace
} // namespace mediapipe

View File

@ -45,7 +45,7 @@ std::string JoinPathImpl(bool honor_abs,
// This size calculation is worst-case: it assumes one extra "/" for every
// path other than the first.
size_t total_size = paths.size() - 1;
for (const absl::string_view path : paths) total_size += path.size();
for (const absl::string_view& path : paths) total_size += path.size();
result.resize(total_size);
auto begin = result.begin();

View File

@ -81,6 +81,12 @@ mediapipe_proto_library(
deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"],
)
mediapipe_proto_library(
name = "affine_transform_data_proto",
srcs = ["affine_transform_data.proto"],
visibility = ["//visibility:public"],
)
mediapipe_proto_library(
name = "time_series_header_proto",
srcs = ["time_series_header.proto"],
@ -119,6 +125,31 @@ cc_library(
],
)
cc_library(
name = "affine_transform",
srcs = ["affine_transform.cc"],
hdrs = ["affine_transform.h"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:port",
"//mediapipe/framework:type_map",
"//mediapipe/framework/formats:affine_transform_data_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:point",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf",
],
)
cc_library(
name = "image_frame",
srcs = ["image_frame.cc"],

View File

@ -0,0 +1,228 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/formats/affine_transform.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/point2.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/type_map.h"
namespace mediapipe {
using ::mediapipe::AffineTransformData;
AffineTransform::AffineTransform() { SetScale(Point2_f(1, 1)); }
AffineTransform::AffineTransform(
const AffineTransformData& affine_transform_data)
: affine_transform_data_(affine_transform_data), is_dirty_(true) {
// make sure scale is set to default (1, 1) when none provided
if (!affine_transform_data_.has_scale()) {
SetScale(Point2_f(1, 1));
}
}
AffineTransform AffineTransform::Create(const Point2_f& translation,
const Point2_f& scale, float rotation,
const Point2_f& shear) {
AffineTransformData affine_transform_data;
auto* t = affine_transform_data.mutable_translation();
t->set_x(translation.x());
t->set_y(translation.y());
auto* s = affine_transform_data.mutable_scale();
s->set_x(scale.x());
s->set_y(scale.y());
s = affine_transform_data.mutable_shear();
s->set_x(shear.x());
s->set_y(shear.y());
affine_transform_data.set_rotation(rotation);
return AffineTransform(affine_transform_data);
}
// Accessor for the composition matrix
std::vector<float> AffineTransform::GetCompositionMatrix() {
float r = affine_transform_data_.rotation();
const auto t = affine_transform_data_.translation();
const auto sc = affine_transform_data_.scale();
const auto sh = affine_transform_data_.shear();
if (is_dirty_) {
// Composition matrix M = T*R*Sh*Sc
// Column based to match GL matrix store order
float cos_r = std::cos(r);
float sin_r = std::sin(r);
matrix_[0] = (cos_r + sin_r * -sh.y()) * sc.x();
matrix_[1] = (-sin_r + cos_r * -sh.y()) * sc.x();
matrix_[2] = 0;
matrix_[3] = (cos_r * -sh.x() + sin_r) * sc.y();
matrix_[4] = (-sin_r * -sh.x() + cos_r) * sc.y();
matrix_[5] = 0;
matrix_[6] = t.x();
matrix_[7] = -t.y();
matrix_[8] = 1;
is_dirty_ = false;
}
return matrix_;
}
Point2_f AffineTransform::GetScale() const {
return Point2_f(affine_transform_data_.scale().x(),
affine_transform_data_.scale().y());
}
Point2_f AffineTransform::GetTranslation() const {
return Point2_f(affine_transform_data_.translation().x(),
affine_transform_data_.translation().y());
}
Point2_f AffineTransform::GetShear() const {
return Point2_f(affine_transform_data_.shear().x(),
affine_transform_data_.shear().y());
}
float AffineTransform::GetRotation() const {
return affine_transform_data_.rotation();
}
void AffineTransform::SetScale(const Point2_f& scale) {
auto* s = affine_transform_data_.mutable_scale();
s->set_x(scale.x());
s->set_y(scale.y());
is_dirty_ = true;
}
void AffineTransform::SetTranslation(const Point2_f& translation) {
auto* t = affine_transform_data_.mutable_translation();
t->set_x(translation.x());
t->set_y(translation.y());
is_dirty_ = true;
}
void AffineTransform::SetShear(const Point2_f& shear) {
auto* s = affine_transform_data_.mutable_shear();
s->set_x(shear.x());
s->set_y(shear.y());
is_dirty_ = true;
}
void AffineTransform::SetRotation(float rotationInRadians) {
affine_transform_data_.set_rotation(rotationInRadians);
is_dirty_ = true;
}
void AffineTransform::AddScale(const Point2_f& scale) {
auto* s = affine_transform_data_.mutable_scale();
s->set_x(s->x() + scale.x());
s->set_y(s->y() + scale.y());
is_dirty_ = true;
}
void AffineTransform::AddTranslation(const Point2_f& translation) {
auto* t = affine_transform_data_.mutable_translation();
t->set_x(t->x() + translation.x());
t->set_y(t->y() + translation.y());
is_dirty_ = true;
}
void AffineTransform::AddShear(const Point2_f& shear) {
auto* s = affine_transform_data_.mutable_shear();
s->set_x(s->x() + shear.x());
s->set_y(s->y() + shear.y());
is_dirty_ = true;
}
void AffineTransform::AddRotation(float rotationInRadians) {
affine_transform_data_.set_rotation(affine_transform_data_.rotation() +
rotationInRadians);
is_dirty_ = true;
}
void AffineTransform::SetFromProto(const AffineTransformData& proto) {
affine_transform_data_ = proto;
}
void AffineTransform::ConvertToProto(AffineTransformData* proto) const {
*proto = affine_transform_data_;
}
AffineTransformData AffineTransform::ConvertToProto() const {
AffineTransformData affine_transform_data;
ConvertToProto(&affine_transform_data);
return affine_transform_data;
}
bool compare(float lhs, float rhs, float epsilon = 0.001f) {
return std::fabs(lhs - rhs) < epsilon;
}
bool AffineTransform::Equals(const AffineTransform& other,
float epsilon) const {
auto trans1 = GetTranslation();
auto trans2 = other.GetTranslation();
if (!(compare(trans1.x(), trans2.x(), epsilon) &&
compare(trans1.y(), trans2.y(), epsilon)))
return false;
auto scale1 = GetScale();
auto scale2 = other.GetScale();
if (!(compare(scale1.x(), scale2.x(), epsilon) &&
compare(scale1.y(), scale2.y(), epsilon)))
return false;
auto shear1 = GetShear();
auto shear2 = other.GetShear();
if (!(compare(shear1.x(), shear2.x(), epsilon) &&
compare(shear1.y(), shear2.y(), epsilon)))
return false;
auto rot1 = GetRotation();
auto rot2 = other.GetRotation();
if (!compare(rot1, rot2, epsilon)) {
return false;
}
return true;
}
bool AffineTransform::Equal(const AffineTransform& lhs,
const AffineTransform& rhs, float epsilon) {
return lhs.Equals(rhs, epsilon);
}
MEDIAPIPE_REGISTER_TYPE(mediapipe::AffineTransform,
"::mediapipe::AffineTransform", nullptr, nullptr);
} // namespace mediapipe

View File

@ -0,0 +1,86 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// A container for affine transform data
// This wrapper provides two functionalities:
// 1. Factory methods for creation of Transform objects and thus
// AffineTransformData protocol buffers. These methods guarantee a valid
// affine transform data and are the preferred way of creating such.
// 2. Accessors which allow for access of the data and the convertion to proto
// format
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_
#include <memory>
#include <vector>
#include "mediapipe/framework/formats/affine_transform_data.pb.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/point2.h"
namespace mediapipe {
class AffineTransform {
public:
// CREATION METHODS.
AffineTransform();
// Constructs a affine transform wrapping the specified affine transform data.
// Checks the validity of the input and crashes upon failure.
explicit AffineTransform(const AffineTransformData& transform_data);
static AffineTransform Create(const Point2_f& translation = Point2_f(0, 0),
const Point2_f& scale = Point2_f(1, 1),
float rotation = 0,
const Point2_f& shear = Point2_f(0, 0));
// ACCESSORS
// Accessor for the composition matrix
std::vector<float> GetCompositionMatrix();
Point2_f GetScale() const;
Point2_f GetTranslation() const;
Point2_f GetShear() const;
float GetRotation() const;
void SetScale(const Point2_f& scale);
void SetTranslation(const Point2_f& translation);
void SetShear(const Point2_f& shear);
void SetRotation(float rotation);
void AddScale(const Point2_f& scale);
void AddTranslation(const Point2_f& translation);
void AddShear(const Point2_f& shear);
void AddRotation(float rotation);
// Serializes and deserializes the affine transform object.
void ConvertToProto(AffineTransformData* proto) const;
AffineTransformData ConvertToProto() const;
void SetFromProto(const AffineTransformData& proto);
bool Equals(const AffineTransform& other, float epsilon = 0.001f) const;
static bool Equal(const AffineTransform& lhs, const AffineTransform& rhs,
float epsilon = 0.001f);
private:
// The wrapped transform data.
AffineTransformData affine_transform_data_;
std::vector<float> matrix_ = {1, 0, 0, 0, 1, 0, 0, 0, 1};
bool is_dirty_ = false;
};
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_

View File

@ -0,0 +1,33 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
option objc_class_prefix = "MediaPipe";
// Proto for serializing Vector2 data
message Vector2Data {
optional float x = 1;
optional float y = 2;
}
// Proto for serializing Affine Transform data.
message AffineTransformData {
optional Vector2Data translation = 1;
optional Vector2Data scale = 2;
optional Vector2Data shear = 3;
optional float rotation = 4; // in radians
}

View File

@ -0,0 +1,94 @@
#include "mediapipe/framework/formats/affine_transform.h"
#include <string>
#include "base/logging.h"
#include "mediapipe/framework/formats/affine_transform_data.pb.h"
#include "mediapipe/framework/port/point2.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
TEST(AffineTransformTest, TraslationTest) {
AffineTransform transform;
transform.SetTranslation(Point2_f(10, -3));
auto trans = transform.GetTranslation();
EXPECT_FLOAT_EQ(10, trans.x());
EXPECT_FLOAT_EQ(-3, trans.y());
transform.AddTranslation(Point2_f(-10, 3));
trans = transform.GetTranslation();
EXPECT_FLOAT_EQ(0, trans.x());
EXPECT_FLOAT_EQ(0, trans.y());
}
TEST(AffineTransformTest, ScaleTest) {
AffineTransform transform;
transform.SetScale(Point2_f(10, -3));
auto scale = transform.GetScale();
EXPECT_FLOAT_EQ(10, scale.x());
EXPECT_FLOAT_EQ(-3, scale.y());
transform.AddScale(Point2_f(-10, 3));
scale = transform.GetScale();
EXPECT_FLOAT_EQ(0, scale.x());
EXPECT_FLOAT_EQ(0, scale.y());
}
TEST(AffineTransformTest, RotationTest) {
AffineTransform transform;
transform.SetRotation(0.7);
float rot = transform.GetRotation();
EXPECT_FLOAT_EQ(0.7, rot);
transform.AddRotation(-0.7);
rot = transform.GetRotation();
EXPECT_FLOAT_EQ(0, rot);
}
TEST(AffineTransformTest, ShearTest) {
AffineTransform transform;
transform.SetShear(Point2_f(10, -3));
auto shear = transform.GetShear();
EXPECT_FLOAT_EQ(10, shear.x());
EXPECT_FLOAT_EQ(-3, shear.y());
transform.AddShear(Point2_f(-10, 3));
shear = transform.GetShear();
EXPECT_FLOAT_EQ(0, shear.x());
EXPECT_FLOAT_EQ(0, shear.y());
}
TEST(AffineTransformTest, TransformTest) {
AffineTransform transform1;
transform1 = AffineTransform::Create(Point2_f(0.1, -0.2), Point2_f(0.3, -0.4),
0.5, Point2_f(0.6, -0.7));
AffineTransform transform2;
transform2 = AffineTransform::Create(Point2_f(0.1, -0.2), Point2_f(0.3, -0.4),
0.5, Point2_f(0.6, -0.7));
EXPECT_THAT(true, transform1.Equals(transform2));
EXPECT_THAT(true, AffineTransform::Equal(transform1, transform2));
transform1 = AffineTransform::Create(Point2_f(0.00001, -0.00002),
Point2_f(0.00003, -0.00004), 0.00005,
Point2_f(0.00006, -0.00007));
transform2 = AffineTransform::Create(Point2_f(0.00001, -0.00002),
Point2_f(0.00003, -0.00004), 0.00005,
Point2_f(0.00006, -0.00007));
EXPECT_THAT(true, transform1.Equals(transform2, 0.000001));
EXPECT_THAT(true, AffineTransform::Equal(transform1, transform2, 0.000001));
}
} // namespace mediapipe

View File

@ -125,9 +125,10 @@ absl::Status OutputStreamObserver::Notify() {
absl::Status OutputStreamPollerImpl::Initialize(
const std::string& stream_name, const PacketType* packet_type,
std::function<void(InputStreamManager*, bool*)> queue_size_callback,
OutputStreamManager* output_stream_manager) {
OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
MP_RETURN_IF_ERROR(GraphOutputStream::Initialize(stream_name, packet_type,
output_stream_manager));
output_stream_manager,
observe_timestamp_bounds));
input_stream_handler_->SetQueueSizeCallbacks(queue_size_callback,
queue_size_callback);
return absl::OkStatus();
@ -176,11 +177,17 @@ void OutputStreamPollerImpl::NotifyError() {
bool OutputStreamPollerImpl::Next(Packet* packet) {
CHECK(packet);
bool empty_queue = true;
bool timestamp_bound_changed = false;
Timestamp min_timestamp = Timestamp::Unset();
mutex_.Lock();
while (true) {
min_timestamp = input_stream_->MinTimestampOrBound(&empty_queue);
if (graph_has_error_ || !empty_queue ||
if (empty_queue) {
timestamp_bound_changed =
input_stream_handler_->ProcessTimestampBounds() &&
output_timestamp_ < min_timestamp.PreviousAllowedInStream();
}
if (graph_has_error_ || !empty_queue || timestamp_bound_changed ||
min_timestamp == Timestamp::Done()) {
break;
} else {
@ -191,10 +198,16 @@ bool OutputStreamPollerImpl::Next(Packet* packet) {
mutex_.Unlock();
return false;
}
if (empty_queue) {
output_timestamp_ = min_timestamp.PreviousAllowedInStream();
} else {
output_timestamp_ = min_timestamp;
}
mutex_.Unlock();
if (min_timestamp == Timestamp::Done()) {
return false;
}
if (!empty_queue) {
int num_packets_dropped = 0;
bool stream_is_done = false;
*packet = input_stream_->PopPacketAtTimestamp(
@ -202,6 +215,9 @@ bool OutputStreamPollerImpl::Next(Packet* packet) {
CHECK_EQ(num_packets_dropped, 0)
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
num_packets_dropped, input_stream_->Name());
} else if (timestamp_bound_changed) {
*packet = Packet().At(min_timestamp.PreviousAllowedInStream());
}
return true;
}

View File

@ -143,7 +143,8 @@ class OutputStreamPollerImpl : public GraphOutputStream {
absl::Status Initialize(
const std::string& stream_name, const PacketType* packet_type,
std::function<void(InputStreamManager*, bool*)> queue_size_callback,
OutputStreamManager* output_stream_manager);
OutputStreamManager* output_stream_manager,
bool observe_timestamp_bounds = false);
void PrepareForRun(std::function<void()> notification_callback,
std::function<void(absl::Status)> error_callback) override;
@ -170,6 +171,7 @@ class OutputStreamPollerImpl : public GraphOutputStream {
absl::Mutex mutex_;
absl::CondVar handler_condvar_ ABSL_GUARDED_BY(mutex_);
bool graph_has_error_ ABSL_GUARDED_BY(mutex_);
Timestamp output_timestamp_ ABSL_GUARDED_BY(mutex_) = Timestamp::Min();
};
} // namespace internal

View File

@ -51,4 +51,17 @@ message NightLightCalculatorOptions {
// Format string used by string::Substitute to construct the output.
optional string format_string = 9;
message LightBundle {
optional string room_id = 1;
repeated NightLightCalculatorOptions room_lights = 2;
}
repeated LightBundle bundle = 10;
// The number of night-lights.
repeated int32 num_lights = 11;
// Options for nested night-lights.
optional NightLightCalculatorOptions sub_options = 12;
}

View File

@ -180,15 +180,66 @@ cc_library(
],
)
mediapipe_proto_library(
name = "field_data_proto",
srcs = ["field_data.proto"],
visibility = ["//visibility:public"],
deps = ["@com_google_protobuf//:any_proto"],
)
cc_library(
name = "options_field_util",
srcs = ["options_field_util.cc"],
hdrs = ["options_field_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":field_data_cc_proto",
":name_util",
":options_registry",
":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "options_syntax_util",
srcs = ["options_syntax_util.cc"],
hdrs = ["options_syntax_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":name_util",
":options_field_util",
":options_registry",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "options_util",
srcs = ["options_util.cc"],
hdrs = ["options_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":options_field_util",
":options_map",
":options_registry",
":options_syntax_util",
":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:collection",
"//mediapipe/framework:input_stream_shard",
"//mediapipe/framework:output_side_packet",
@ -199,7 +250,7 @@ cc_library(
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:type_util",
"//mediapipe/framework/tool:name_util",
"@com_google_absl//absl/strings",
],
)
@ -227,6 +278,8 @@ cc_library(
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:logging",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
],
)
@ -246,11 +299,13 @@ mediapipe_cc_test(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:validated_graph_config",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/testdata:night_light_calculator_options_lib",
"//mediapipe/framework/tool:node_chain_subgraph_options_lib",
"//mediapipe/framework/tool:options_syntax_util",
"//mediapipe/util:header_util",
],
)

View File

@ -0,0 +1,47 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from mediapipe/framework/tool/source.proto.
// The forked proto must remain identical to the original proto and should be
// ONLY used by mediapipe open source project.
syntax = "proto2";
package mediapipe;
// `MessageData`, like protobuf.Any, contains an arbitrary serialized protbuf
// along with a URL that describes the type of the serialized message.
message MessageData {
// A URL/resource name that identifies the type of serialized protbuf.
optional string type_url = 1;
// Must be a valid serialized protocol buffer of the above specified type.
optional bytes value = 2;
}
// Data for one Protobuf field or one MediaPipe packet.
message FieldData {
oneof value {
sint32 int32_value = 1;
sint64 int64_value = 2;
uint32 uint32_value = 3;
uint64 uint64_value = 4;
double double_value = 5;
float float_value = 6;
bool bool_value = 7;
sint32 enum_value = 8;
string string_value = 9;
MessageData message_value = 10;
}
}

View File

@ -0,0 +1,495 @@
#include "mediapipe/framework/tool/options_field_util.h"
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/proto_util_lite.h"
namespace mediapipe {
namespace tool {
namespace options_field_util {
using ::mediapipe::proto_ns::internal::WireFormatLite;
using FieldType = WireFormatLite::FieldType;
using ::mediapipe::proto_ns::io::ArrayInputStream;
using ::mediapipe::proto_ns::io::CodedInputStream;
using ::mediapipe::proto_ns::io::CodedOutputStream;
using ::mediapipe::proto_ns::io::StringOutputStream;
// Utility functions for OptionsFieldUtil.
namespace {
// Converts a FieldDescriptor::Type to the corresponding FieldType.
FieldType AsFieldType(proto_ns::FieldDescriptorProto::Type type) {
return static_cast<FieldType>(type);
}
absl::Status WriteValue(const FieldData& value, FieldType field_type,
std::string* field_bytes) {
StringOutputStream sos(field_bytes);
CodedOutputStream out(&sos);
switch (field_type) {
case WireFormatLite::TYPE_INT32:
WireFormatLite::WriteInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_SINT32:
WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_INT64:
WireFormatLite::WriteInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_SINT64:
WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_UINT32:
WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out);
break;
case WireFormatLite::TYPE_UINT64:
WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_DOUBLE:
WireFormatLite::WriteDoubleNoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_FLOAT:
WireFormatLite::WriteFloatNoTag(value.float_value(), &out);
break;
case WireFormatLite::TYPE_BOOL:
WireFormatLite::WriteBoolNoTag(value.bool_value(), &out);
break;
case WireFormatLite::TYPE_ENUM:
WireFormatLite::WriteEnumNoTag(value.enum_value(), &out);
break;
case WireFormatLite::TYPE_STRING:
out.WriteString(value.string_value());
break;
case WireFormatLite::TYPE_MESSAGE:
out.WriteString(value.message_value().value());
break;
default:
return absl::UnimplementedError(
absl::StrCat("Cannot write type: ", field_type));
}
return mediapipe::OkStatus();
}
// Serializes a packet value.
absl::Status WriteField(const FieldData& packet, const FieldDescriptor* field,
std::string* result) {
FieldType field_type = AsFieldType(field->type());
return WriteValue(packet, field_type, result);
}
template <typename ValueT, FieldType kFieldType>
static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) {
ArrayInputStream ais(field_bytes.data(), field_bytes.size());
CodedInputStream input(&ais);
ValueT result;
if (!WireFormatLite::ReadPrimitive<ValueT, kFieldType>(&input, &result)) {
status->Update(mediapipe::InvalidArgumentError(absl::StrCat(
"Bad serialized value: ", MediaPipeTypeStringOrDemangled<ValueT>(),
".")));
}
return result;
}
absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type,
absl::string_view message_type, FieldData* result) {
absl::Status status;
result->Clear();
switch (field_type) {
case WireFormatLite::TYPE_INT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_INT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_SINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_INT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_INT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_SINT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT32:
result->set_uint32_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT64:
result->set_uint64_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_DOUBLE:
result->set_double_value(
ReadValue<double, WireFormatLite::TYPE_DOUBLE>(field_bytes, &status));
break;
case WireFormatLite::TYPE_FLOAT:
result->set_float_value(
ReadValue<float, WireFormatLite::TYPE_FLOAT>(field_bytes, &status));
break;
case WireFormatLite::TYPE_BOOL:
result->set_bool_value(
ReadValue<bool, WireFormatLite::TYPE_BOOL>(field_bytes, &status));
break;
case WireFormatLite::TYPE_ENUM:
result->set_enum_value(
ReadValue<int32, WireFormatLite::TYPE_ENUM>(field_bytes, &status));
break;
case WireFormatLite::TYPE_STRING:
result->set_string_value(std::string(field_bytes));
break;
case WireFormatLite::TYPE_MESSAGE:
result->mutable_message_value()->set_value(std::string(field_bytes));
result->mutable_message_value()->set_type_url(TypeUrl(message_type));
break;
default:
status = absl::UnimplementedError(
absl::StrCat("Cannot read type: ", field_type));
break;
}
return status;
}
// Deserializes a packet from a protobuf field.
absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field,
FieldData* result) {
FieldType field_type = AsFieldType(field->type());
std::string message_type = (field_type == WireFormatLite::TYPE_MESSAGE)
? field->message_type()->full_name()
: "";
return ReadValue(bytes, field_type, message_type, result);
}
// Converts a chain of fields and indexes into field-numbers and indexes.
ProtoUtilLite::ProtoPath AsProtoPath(const FieldPath& field_path) {
ProtoUtilLite::ProtoPath result;
for (auto field : field_path) {
result.push_back({field.first->number(), field.second});
}
return result;
}
// Returns the options protobuf for a subgraph.
// TODO: Ensure that this works with multiple options protobufs.
absl::Status GetOptionsMessage(
const proto_ns::RepeatedPtrField<mediapipe::protobuf::Any>& options_any,
const proto_ns::MessageLite& options_ext, FieldData* result) {
// Read the "graph_options" or "node_options" field.
for (const auto& options : options_any) {
if (options.type_url().empty()) {
continue;
}
result->mutable_message_value()->set_type_url(options.type_url());
result->mutable_message_value()->set_value(std::string(options.value()));
return mediapipe::OkStatus();
}
// Read the "options" field.
FieldData message_data;
*message_data.mutable_message_value()->mutable_value() =
options_ext.SerializeAsString();
message_data.mutable_message_value()->set_type_url(options_ext.GetTypeName());
std::vector<const FieldDescriptor*> ext_fields;
OptionsRegistry::FindAllExtensions(options_ext.GetTypeName(), &ext_fields);
for (auto ext_field : ext_fields) {
absl::Status status = GetField({{ext_field, 0}}, message_data, result);
if (!status.ok()) {
return status;
}
if (result->has_message_value()) {
return status;
}
}
return mediapipe::OkStatus();
}
// Sets a protobuf in a repeated protobuf::Any field.
void SetOptionsMessage(
const FieldData& node_options,
proto_ns::RepeatedPtrField<mediapipe::protobuf::Any>* result) {
protobuf::Any* options_any = nullptr;
for (auto& any : *result) {
if (any.type_url() == node_options.message_value().type_url()) {
options_any = &any;
}
}
if (!options_any) {
options_any = result->Add();
options_any->set_type_url(node_options.message_value().type_url());
}
*options_any->mutable_value() = node_options.message_value().value();
}
} // anonymous namespace
// Deserializes a packet containing a MessageLite value.
absl::Status ReadMessage(const std::string& value, const std::string& type_name,
Packet* result) {
auto packet = packet_internal::PacketFromDynamicProto(type_name, value);
if (packet.ok()) {
*result = *packet;
}
return packet.status();
}
// Merge two options FieldData values.
absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over,
FieldData* result) {
absl::Status status;
if (over.value_case() == FieldData::VALUE_NOT_SET) {
*result = base;
return status;
}
if (base.value_case() == FieldData::VALUE_NOT_SET) {
*result = over;
return status;
}
if (over.value_case() != base.value_case()) {
return absl::InvalidArgumentError(absl::StrCat(
"Cannot merge field data with data types: ", base.value_case(), ", ",
over.value_case()));
}
if (over.message_value().type_url() != base.message_value().type_url()) {
return absl::InvalidArgumentError(
absl::StrCat("Cannot merge field data with message types: ",
base.message_value().type_url(), ", ",
over.message_value().type_url()));
}
absl::Cord merged_value;
merged_value.Append(base.message_value().value());
merged_value.Append(over.message_value().value());
result->mutable_message_value()->set_type_url(
base.message_value().type_url());
result->mutable_message_value()->set_value(std::string(merged_value));
return status;
}
// Writes a FieldData value into protobuf field.
absl::Status SetField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data) {
if (field_path.empty()) {
*message_data->mutable_message_value() = value.message_value();
return mediapipe::OkStatus();
}
ProtoUtilLite proto_util;
const FieldDescriptor* field = field_path.back().first;
FieldType field_type = AsFieldType(field->type());
std::string field_value;
MP_RETURN_IF_ERROR(WriteField(value, field, &field_value));
ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path);
std::string* message_bytes =
message_data->mutable_message_value()->mutable_value();
int field_count;
MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path,
field_type, &field_count));
MP_RETURN_IF_ERROR(
proto_util.ReplaceFieldRange(message_bytes, AsProtoPath(field_path),
field_count, field_type, {field_value}));
return mediapipe::OkStatus();
}
// Merges a packet value into nested protobuf Message.
absl::Status MergeField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data) {
absl::Status status;
FieldType field_type = field_path.empty()
? FieldType::TYPE_MESSAGE
: AsFieldType(field_path.back().first->type());
std::string message_type =
(value.has_message_value())
? ParseTypeUrl(std::string(value.message_value().type_url()))
: "";
FieldData v = value;
if (field_type == FieldType::TYPE_MESSAGE) {
FieldData b;
status.Update(GetField(field_path, *message_data, &b));
status.Update(MergeOptionsMessages(b, v, &v));
}
status.Update(SetField(field_path, v, message_data));
return status;
}
// Reads a packet value from a protobuf field.
absl::Status GetField(const FieldPath& field_path,
const FieldData& message_data, FieldData* result) {
if (field_path.empty()) {
*result->mutable_message_value() = message_data.message_value();
return mediapipe::OkStatus();
}
ProtoUtilLite proto_util;
const FieldDescriptor* field = field_path.back().first;
FieldType field_type = AsFieldType(field->type());
std::vector<std::string> field_values;
ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path);
const std::string& message_bytes = message_data.message_value().value();
int field_count;
MP_RETURN_IF_ERROR(proto_util.GetFieldCount(message_bytes, proto_path,
field_type, &field_count));
if (field_count == 0) {
return mediapipe::OkStatus();
}
MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1,
field_type, &field_values));
MP_RETURN_IF_ERROR(ReadField(field_values.front(), field, result));
return mediapipe::OkStatus();
}
// Returns the options protobuf for a graph.
absl::Status GetOptionsMessage(const CalculatorGraphConfig& config,
FieldData* result) {
return GetOptionsMessage(config.graph_options(), config.options(), result);
}
// Returns the options protobuf for a node.
absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node,
FieldData* result) {
return GetOptionsMessage(node.node_options(), node.options(), result);
}
// Sets the node_options field in a Node, and clears the options field.
void SetOptionsMessage(const FieldData& node_options,
CalculatorGraphConfig::Node* node) {
SetOptionsMessage(node_options, node->mutable_node_options());
node->clear_options();
}
// Represents a protobuf enum value stored in a Packet.
struct ProtoEnum {
ProtoEnum(int32 v) : value(v) {}
int32 value;
};
absl::Status AsPacket(const FieldData& data, Packet* result) {
switch (data.value_case()) {
case FieldData::ValueCase::kInt32Value:
*result = MakePacket<int32>(data.int32_value());
break;
case FieldData::ValueCase::kInt64Value:
*result = MakePacket<int64>(data.int64_value());
break;
case FieldData::ValueCase::kUint32Value:
*result = MakePacket<uint32>(data.uint32_value());
break;
case FieldData::ValueCase::kUint64Value:
*result = MakePacket<uint64>(data.uint64_value());
break;
case FieldData::ValueCase::kDoubleValue:
*result = MakePacket<double>(data.double_value());
break;
case FieldData::ValueCase::kFloatValue:
*result = MakePacket<float>(data.float_value());
break;
case FieldData::ValueCase::kBoolValue:
*result = MakePacket<bool>(data.bool_value());
break;
case FieldData::ValueCase::kEnumValue:
*result = MakePacket<ProtoEnum>(data.enum_value());
break;
case FieldData::ValueCase::kStringValue:
*result = MakePacket<std::string>(data.string_value());
break;
case FieldData::ValueCase::kMessageValue: {
auto r = packet_internal::PacketFromDynamicProto(
ParseTypeUrl(std::string(data.message_value().type_url())),
std::string(data.message_value().value()));
if (!r.ok()) {
return r.status();
}
*result = r.value();
break;
}
case FieldData::VALUE_NOT_SET:
*result = Packet();
}
return mediapipe::OkStatus();
}
absl::Status AsFieldData(Packet packet, FieldData* result) {
static const auto* kTypeIds = new std::map<size_t, int32>{
{tool::GetTypeHash<int32>(), WireFormatLite::CPPTYPE_INT32},
{tool::GetTypeHash<int64>(), WireFormatLite::CPPTYPE_INT64},
{tool::GetTypeHash<uint32>(), WireFormatLite::CPPTYPE_UINT32},
{tool::GetTypeHash<uint64>(), WireFormatLite::CPPTYPE_UINT64},
{tool::GetTypeHash<double>(), WireFormatLite::CPPTYPE_DOUBLE},
{tool::GetTypeHash<float>(), WireFormatLite::CPPTYPE_FLOAT},
{tool::GetTypeHash<bool>(), WireFormatLite::CPPTYPE_BOOL},
{tool::GetTypeHash<ProtoEnum>(), WireFormatLite::CPPTYPE_ENUM},
{tool::GetTypeHash<std::string>(), WireFormatLite::CPPTYPE_STRING},
};
if (packet.ValidateAsProtoMessageLite().ok()) {
result->mutable_message_value()->set_value(
packet.GetProtoMessageLite().SerializeAsString());
result->mutable_message_value()->set_type_url(
TypeUrl(packet.GetProtoMessageLite().GetTypeName()));
return mediapipe::OkStatus();
}
if (kTypeIds->count(packet.GetTypeId()) == 0) {
return absl::UnimplementedError(absl::StrCat(
"Cannot construct FieldData for: ", packet.DebugTypeName()));
}
switch (kTypeIds->at(packet.GetTypeId())) {
case WireFormatLite::CPPTYPE_INT32:
result->set_int32_value(packet.Get<int32>());
break;
case WireFormatLite::CPPTYPE_INT64:
result->set_int64_value(packet.Get<int64>());
break;
case WireFormatLite::CPPTYPE_UINT32:
result->set_uint32_value(packet.Get<uint32>());
break;
case WireFormatLite::CPPTYPE_UINT64:
result->set_uint64_value(packet.Get<uint64>());
break;
case WireFormatLite::CPPTYPE_DOUBLE:
result->set_double_value(packet.Get<double>());
break;
case WireFormatLite::CPPTYPE_FLOAT:
result->set_float_value(packet.Get<float>());
break;
case WireFormatLite::CPPTYPE_BOOL:
result->set_bool_value(packet.Get<bool>());
break;
case WireFormatLite::CPPTYPE_ENUM:
result->set_enum_value(packet.Get<ProtoEnum>().value);
break;
case WireFormatLite::CPPTYPE_STRING:
result->set_string_value(packet.Get<std::string>());
break;
}
return mediapipe::OkStatus();
}
std::string TypeUrl(absl::string_view type_name) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name));
}
std::string ParseTypeUrl(absl::string_view type_url) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
if (std::string(type_url).rfind(kTypeUrlPrefix, 0) == 0) {
return std::string(
type_url.substr(kTypeUrlPrefix.length(), std::string::npos));
}
return std::string(type_url);
}
} // namespace options_field_util
} // namespace tool
} // namespace mediapipe

View File

@ -0,0 +1,73 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_
#include <string>
#include <vector>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/tool/field_data.pb.h"
#include "mediapipe/framework/tool/options_registry.h"
namespace mediapipe {
namespace tool {
// Utility to read and write Packet data from protobuf fields.
namespace options_field_util {
// A chain of nested fields and indexes.
using FieldPath = std::vector<std::pair<const FieldDescriptor*, int>>;
// Writes a field value into protobuf field.
absl::Status SetField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data);
// Reads a field value from a protobuf field.
absl::Status GetField(const FieldPath& field_path,
const FieldData& message_data, FieldData* result);
// Merges a field value into nested protobuf Message.
absl::Status MergeField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data);
// Deserializes a packet containing a MessageLite value.
absl::Status ReadMessage(const std::string& value, const std::string& type_name,
Packet* result);
// Merge two options protobuf field values.
absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over,
FieldData* result);
// Returns the options protobuf for a graph.
absl::Status GetOptionsMessage(const CalculatorGraphConfig& config,
FieldData* result);
// Returns the options protobuf for a node.
absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node,
FieldData* result);
// Sets the node_options field in a Node, and clears the options field.
void SetOptionsMessage(const FieldData& node_options,
CalculatorGraphConfig::Node* node);
// Constructs a Packet for a FieldData proto.
absl::Status AsPacket(const FieldData& data, Packet* result);
// Constructs a FieldData proto for a Packet.
absl::Status AsFieldData(Packet packet, FieldData* result);
// Returns the protobuf type-url for a protobuf type-name.
std::string TypeUrl(absl::string_view type_name);
// Returns the protobuf type-name for a protobuf type-url.
std::string ParseTypeUrl(absl::string_view type_url);
} // namespace options_field_util
} // namespace tool
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_

View File

@ -1,47 +1,112 @@
#include "mediapipe/framework/tool/options_registry.h"
#include "absl/synchronization/mutex.h"
namespace mediapipe {
namespace tool {
proto_ns::DescriptorPool* OptionsRegistry::options_descriptor_pool() {
static proto_ns::DescriptorPool* result = new proto_ns::DescriptorPool();
return result;
namespace {
// Returns a canonical message type name, with any leading "." removed.
std::string CanonicalTypeName(const std::string& type_name) {
return (type_name.rfind('.', 0) == 0) ? type_name.substr(1) : type_name;
}
} // namespace
RegistrationToken OptionsRegistry::Register(
const proto_ns::FileDescriptorSet& files) {
absl::MutexLock lock(&mutex());
for (auto& file : files.file()) {
options_descriptor_pool()->BuildFile(file);
for (auto& message_type : file.message_type()) {
Register(message_type, file.package());
}
}
return RegistrationToken([]() {});
}
const proto_ns::Descriptor* OptionsRegistry::GetProtobufDescriptor(
const std::string& type_name) {
const proto_ns::Descriptor* result =
proto_ns::DescriptorPool::generated_pool()->FindMessageTypeByName(
type_name);
if (!result) {
result = options_descriptor_pool()->FindMessageTypeByName(type_name);
void OptionsRegistry::Register(const proto_ns::DescriptorProto& message_type,
const std::string& parent_name) {
auto full_name = absl::StrCat(parent_name, ".", message_type.name());
descriptors()[full_name] = Descriptor(message_type, full_name);
for (auto& nested : message_type.nested_type()) {
Register(nested, full_name);
}
return result;
for (auto& extension : message_type.extension()) {
extensions()[CanonicalTypeName(extension.extendee())].push_back(
FieldDescriptor(extension));
}
}
const Descriptor* OptionsRegistry::GetProtobufDescriptor(
const std::string& type_name) {
absl::ReaderMutexLock lock(&mutex());
auto it = descriptors().find(CanonicalTypeName(type_name));
return (it == descriptors().end()) ? nullptr : &it->second;
}
void OptionsRegistry::FindAllExtensions(
const proto_ns::Descriptor& extendee,
std::vector<const proto_ns::FieldDescriptor*>* result) {
using proto_ns::DescriptorPool;
std::vector<const proto_ns::FieldDescriptor*> extensions;
DescriptorPool::generated_pool()->FindAllExtensions(&extendee, &extensions);
options_descriptor_pool()->FindAllExtensions(&extendee, &extensions);
absl::flat_hash_set<int> numbers;
for (const proto_ns::FieldDescriptor* extension : extensions) {
bool inserted = numbers.insert(extension->number()).second;
if (inserted) {
result->push_back(extension);
absl::string_view extendee, std::vector<const FieldDescriptor*>* result) {
absl::ReaderMutexLock lock(&mutex());
result->clear();
if (extensions().count(extendee) > 0) {
for (const FieldDescriptor& field : extensions().at(extendee)) {
result->push_back(&field);
}
}
}
absl::flat_hash_map<std::string, Descriptor>& OptionsRegistry::descriptors() {
static auto* descriptors = new absl::flat_hash_map<std::string, Descriptor>();
return *descriptors;
}
absl::flat_hash_map<std::string, std::vector<FieldDescriptor>>&
OptionsRegistry::extensions() {
static auto* extensions =
new absl::flat_hash_map<std::string, std::vector<FieldDescriptor>>();
return *extensions;
}
absl::Mutex& OptionsRegistry::mutex() {
static auto* mutex = new absl::Mutex();
return *mutex;
}
Descriptor::Descriptor(const proto_ns::DescriptorProto& proto,
const std::string& full_name)
: full_name_(full_name) {
for (auto& field : proto.field()) {
fields_[field.name()] = FieldDescriptor(field);
}
}
const std::string& Descriptor::full_name() const { return full_name_; }
const FieldDescriptor* Descriptor::FindFieldByName(
const std::string& name) const {
auto it = fields_.find(name);
return (it != fields_.end()) ? &it->second : nullptr;
}
FieldDescriptor::FieldDescriptor(const proto_ns::FieldDescriptorProto& proto) {
name_ = proto.name();
message_type_ = CanonicalTypeName(proto.type_name());
type_ = proto.type();
number_ = proto.number();
}
const std::string& FieldDescriptor::name() const { return name_; }
int FieldDescriptor::number() const { return number_; }
proto_ns::FieldDescriptorProto::Type FieldDescriptor::type() const {
return type_;
}
const Descriptor* FieldDescriptor::message_type() const {
return OptionsRegistry::GetProtobufDescriptor(message_type_);
}
} // namespace tool
} // namespace mediapipe

View File

@ -1,12 +1,16 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
#include "absl/container/flat_hash_map.h"
#include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
namespace mediapipe {
namespace tool {
class Descriptor;
class FieldDescriptor;
// A static registry that stores descriptors for protobufs used in MediaPipe
// calculator options. Lite-proto builds do not normally include descriptors.
// These registered descriptors allow individual protobuf fields to be
@ -17,23 +21,60 @@ class OptionsRegistry {
static RegistrationToken Register(const proto_ns::FileDescriptorSet& files);
// Finds the descriptor for a protobuf.
static const proto_ns::Descriptor* GetProtobufDescriptor(
const std::string& type_name);
static const Descriptor* GetProtobufDescriptor(const std::string& type_name);
// Returns all known proto2 extensions to a type.
static void FindAllExtensions(
const proto_ns::Descriptor& extendee,
std::vector<const proto_ns::FieldDescriptor*>* result);
static void FindAllExtensions(absl::string_view extendee,
std::vector<const FieldDescriptor*>* result);
private:
// Stores the descriptors for each options protobuf type.
static proto_ns::DescriptorPool* options_descriptor_pool();
// Registers protobuf descriptors a MessageLite and nested types.
static void Register(const proto_ns::DescriptorProto& message_type,
const std::string& parent_name);
static absl::flat_hash_map<std::string, Descriptor>& descriptors();
static absl::flat_hash_map<std::string, std::vector<FieldDescriptor>>&
extensions();
static absl::Mutex& mutex();
// Registers the descriptors for each options protobuf type.
template <class MessageT>
static const RegistrationToken registration_token;
};
// A custom implementation proto_ns::Descriptor. This implementation
// avoids a code size problem introduced by proto_ns::FieldDescriptor.
class Descriptor {
public:
Descriptor() {}
Descriptor(const proto_ns::DescriptorProto& proto,
const std::string& full_name);
const std::string& full_name() const;
const FieldDescriptor* FindFieldByName(const std::string& name) const;
private:
std::string full_name_;
absl::flat_hash_map<std::string, FieldDescriptor> fields_;
};
// A custom implementation proto_ns::FieldDescriptor. This implementation
// avoids a code size problem introduced by proto_ns::FieldDescriptor.
class FieldDescriptor {
public:
FieldDescriptor() {}
FieldDescriptor(const proto_ns::FieldDescriptorProto& proto);
const std::string& name() const;
int number() const;
proto_ns::FieldDescriptorProto::Type type() const;
const Descriptor* message_type() const;
private:
std::string name_;
std::string message_type_;
proto_ns::FieldDescriptorProto::Type type_;
int number_;
};
} // namespace tool
} // namespace mediapipe

View File

@ -0,0 +1,143 @@
#include "mediapipe/framework/tool/options_syntax_util.h"
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/name_util.h"
namespace mediapipe {
namespace tool {
// Helper functions for parsing the graph options syntax.
class OptionsSyntaxUtil::OptionsSyntaxHelper {
public:
// The usual graph options syntax tokens.
OptionsSyntaxHelper() : syntax_{"OPTIONS", "options", "/"} {}
// Returns the tag name for an option protobuf field.
std::string OptionFieldTag(const std::string& name) { return name; }
// Returns the packet name for an option protobuf field.
absl::string_view OptionFieldPacket(absl::string_view name) { return name; }
// Returns the option protobuf field name for a tag or packet name.
absl::string_view OptionFieldName(absl::string_view name) { return name; }
// Returns the field-path for an option stream-tag.
FieldPath OptionFieldPath(const std::string& tag,
const Descriptor* descriptor) {
int prefix = syntax_.tag_name.length() + syntax_.separator.length();
std::string suffix = tag.substr(prefix);
std::vector<absl::string_view> name_tags =
absl::StrSplit(suffix, syntax_.separator);
FieldPath result;
for (absl::string_view name_tag : name_tags) {
if (name_tag.empty()) {
continue;
}
absl::string_view option_name = OptionFieldName(name_tag);
int index;
if (absl::SimpleAtoi(option_name, &index)) {
result.back().second = index;
} else {
auto field = descriptor->FindFieldByName(std::string(option_name));
descriptor = field ? field->message_type() : nullptr;
result.push_back({std::move(field), 0});
}
}
return result;
}
// Returns the option field name for a graph options packet name.
std::string GraphOptionFieldName(const std::string& graph_option_name) {
int prefix = syntax_.packet_name.length() + syntax_.separator.length();
std::string result = graph_option_name;
result.erase(0, prefix);
return result;
}
// Returns the graph options packet name for an option field name.
std::string GraphOptionName(const std::string& option_name) {
std::string packet_prefix =
syntax_.packet_name + absl::AsciiStrToLower(syntax_.separator);
return absl::StrCat(packet_prefix, option_name);
}
// Returns the tag name for a graph option.
std::string OptionTagName(const std::string& option_name) {
return absl::StrCat(syntax_.tag_name, syntax_.separator,
OptionFieldTag(option_name));
}
// Converts slash-separated field names into a tag name.
std::string OptionFieldsTag(const std::string& option_names) {
std::string tag_prefix = syntax_.tag_name + syntax_.separator;
std::vector<absl::string_view> names = absl::StrSplit(option_names, '/');
if (!names.empty() && names[0] == syntax_.tag_name) {
names.erase(names.begin());
}
if (!names.empty() && names[0] == syntax_.packet_name) {
names.erase(names.begin());
}
std::string result;
std::string sep = "";
for (absl::string_view v : names) {
absl::StrAppend(&result, sep, OptionFieldTag(std::string(v)));
sep = syntax_.separator;
}
result = tag_prefix + result;
return result;
}
// Token definitions for the graph options syntax.
struct OptionsSyntax {
// The tag name for an options protobuf.
std::string tag_name;
// The packet name for an options protobuf.
std::string packet_name;
// The separator between nested options fields.
std::string separator;
};
OptionsSyntax syntax_;
}; // class OptionsSyntaxHelper
OptionsSyntaxUtil::OptionsSyntaxUtil()
: syntax_helper_(std::make_unique<OptionsSyntaxHelper>()) {}
OptionsSyntaxUtil::OptionsSyntaxUtil(const std::string& tag_name)
: OptionsSyntaxUtil() {
syntax_helper_->syntax_.tag_name = tag_name;
}
OptionsSyntaxUtil::OptionsSyntaxUtil(const std::string& tag_name,
const std::string& packet_name,
const std::string& separator)
: OptionsSyntaxUtil() {
syntax_helper_->syntax_.tag_name = tag_name;
syntax_helper_->syntax_.packet_name = packet_name;
syntax_helper_->syntax_.separator = separator;
}
OptionsSyntaxUtil::~OptionsSyntaxUtil() {}
std::string OptionsSyntaxUtil::OptionFieldsTag(
const std::string& option_names) {
return syntax_helper_->OptionFieldsTag(option_names);
}
OptionsSyntaxUtil::FieldPath OptionsSyntaxUtil::OptionFieldPath(
const std::string& tag, const Descriptor* descriptor) {
return syntax_helper_->OptionFieldPath(tag, descriptor);
}
} // namespace tool
} // namespace mediapipe

View File

@ -0,0 +1,45 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/tool/options_field_util.h"
#include "mediapipe/framework/tool/options_registry.h"
namespace mediapipe {
namespace tool {
// Utility to parse the graph options syntax used in "option_value",
// "side_packet", and "stream".
class OptionsSyntaxUtil {
public:
using FieldPath = options_field_util::FieldPath;
OptionsSyntaxUtil();
OptionsSyntaxUtil(const std::string& tag_name);
OptionsSyntaxUtil(const std::string& tag_name, const std::string& packet_name,
const std::string& separator);
~OptionsSyntaxUtil();
// Converts slash-separated field names into a tag name.
std::string OptionFieldsTag(const std::string& option_names);
// Returns the field-path for an option stream-tag.
FieldPath OptionFieldPath(const std::string& tag,
const Descriptor* descriptor);
private:
class OptionsSyntaxHelper;
std::unique_ptr<OptionsSyntaxHelper> syntax_helper_;
};
} // namespace tool
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_

View File

@ -1,16 +1,82 @@
#include "mediapipe/framework/tool/options_util.h"
#include "mediapipe/framework/port/proto_ns.h"
#include <memory>
#include <string>
#include <variant>
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/input_stream_shard.h"
#include "mediapipe/framework/output_side_packet.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/options_field_util.h"
#include "mediapipe/framework/tool/options_registry.h"
#include "mediapipe/framework/tool/options_syntax_util.h"
#include "mediapipe/framework/tool/proto_util_lite.h"
namespace mediapipe {
namespace tool {
// TODO: Return registered protobuf Descriptors when available.
const proto_ns::Descriptor* GetProtobufDescriptor(
const std::string& type_name) {
return proto_ns::DescriptorPool::generated_pool()->FindMessageTypeByName(
type_name);
// Copy literal options from graph_options to node_options.
absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
CalculatorGraphConfig* config) {
Status status;
FieldData config_options, parent_node_options, graph_options;
status.Update(
options_field_util::GetOptionsMessage(*config, &config_options));
status.Update(
options_field_util::GetOptionsMessage(parent_node, &parent_node_options));
status.Update(options_field_util::MergeOptionsMessages(
config_options, parent_node_options, &graph_options));
const Descriptor* options_descriptor =
OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl(
std::string(graph_options.message_value().type_url())));
if (!options_descriptor) {
return status;
}
OptionsSyntaxUtil syntax_util;
for (auto& node : *config->mutable_node()) {
FieldData node_data;
status.Update(options_field_util::GetOptionsMessage(node, &node_data));
if (!node_data.has_message_value() || node.option_value_size() == 0) {
continue;
}
const Descriptor* node_options_descriptor =
OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl(
std::string(node_data.message_value().type_url())));
if (!node_options_descriptor) {
continue;
}
for (const std::string& option_def : node.option_value()) {
std::vector<std::string> tag_and_name = absl::StrSplit(option_def, ':');
std::string graph_tag = syntax_util.OptionFieldsTag(tag_and_name[1]);
std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]);
FieldData packet_data;
status.Update(options_field_util::GetField(
syntax_util.OptionFieldPath(graph_tag, options_descriptor),
graph_options, &packet_data));
status.Update(options_field_util::MergeField(
syntax_util.OptionFieldPath(node_tag, node_options_descriptor),
packet_data, &node_data));
}
options_field_util::SetOptionsMessage(node_data, &node);
}
return status;
}
// Makes all configuration modifications needed for graph options.
absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node,
CalculatorGraphConfig* config) {
MP_RETURN_IF_ERROR(CopyLiteralOptions(parent_node, config));
return mediapipe::OkStatus();
}
} // namespace tool

View File

@ -21,7 +21,6 @@
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/framework/tool/type_util.h"
namespace mediapipe {
@ -75,8 +74,9 @@ inline T RetrieveOptions(const T& base, const InputStreamShardSet& stream_set,
return base;
}
// Finds the descriptor for a protobuf.
const proto_ns::Descriptor* GetProtobufDescriptor(const std::string& type_name);
// Copy literal options from enclosing graphs.
absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node,
CalculatorGraphConfig* config);
} // namespace tool
} // namespace mediapipe

View File

@ -16,15 +16,41 @@
#include <vector>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/testdata/night_light_calculator.pb.h"
#include "mediapipe/framework/tool/node_chain_subgraph.pb.h"
#include "mediapipe/framework/tool/options_registry.h"
#include "mediapipe/framework/tool/options_syntax_util.h"
namespace mediapipe {
namespace {
using ::mediapipe::proto_ns::FieldDescriptorProto;
using FieldType = ::mediapipe::proto_ns::FieldDescriptorProto::Type;
// A test Calculator using DeclareOptions and DefineOptions.
class NightLightCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
return mediapipe::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
return mediapipe::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
return mediapipe::OkStatus();
}
private:
NightLightCalculatorOptions options_;
};
REGISTER_CALCULATOR(NightLightCalculator);
// Tests for calculator and graph options.
//
class OptionsUtilTest : public ::testing::Test {
@ -35,21 +61,108 @@ class OptionsUtilTest : public ::testing::Test {
// Retrieves the description of a protobuf.
TEST_F(OptionsUtilTest, GetProtobufDescriptor) {
const proto_ns::Descriptor* descriptor =
tool::GetProtobufDescriptor("mediapipe.CalculatorGraphConfig");
#ifndef MEDIAPIPE_MOBILE
const tool::Descriptor* descriptor =
tool::OptionsRegistry::GetProtobufDescriptor(
"mediapipe.CalculatorGraphConfig");
EXPECT_NE(nullptr, descriptor);
#else
EXPECT_EQ(nullptr, descriptor);
#endif
}
// Retrieves the description of a protobuf from the OptionsRegistry.
// Shows a calculator node deriving options from graph options.
// The subgraph specifies "graph_options" as "NodeChainSubgraphOptions".
// The calculator specifies "node_options as "NightLightCalculatorOptions".
TEST_F(OptionsUtilTest, CopyLiteralOptions) {
CalculatorGraphConfig subgraph_config;
auto node = subgraph_config.add_node();
*node->mutable_calculator() = "NightLightCalculator";
*node->add_option_value() = "num_lights:options/chain_length";
// The options framework requires at least an empty options protobuf
// as an indication the options protobuf type expected by the node.
NightLightCalculatorOptions node_options;
node->add_node_options()->PackFrom(node_options);
NodeChainSubgraphOptions options;
options.set_chain_length(8);
subgraph_config.add_graph_options()->PackFrom(options);
subgraph_config.set_type("NightSubgraph");
CalculatorGraphConfig graph_config;
node = graph_config.add_node();
*node->mutable_calculator() = "NightSubgraph";
CalculatorGraph graph;
graph_config.set_num_threads(4);
MP_EXPECT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {}));
CalculatorGraphConfig expanded_config = graph.Config();
expanded_config.clear_executor();
CalculatorGraphConfig::Node actual_node;
actual_node = expanded_config.node(0);
CalculatorGraphConfig::Node expected_node;
expected_node.set_name("nightsubgraph__NightLightCalculator");
expected_node.set_calculator("NightLightCalculator");
NightLightCalculatorOptions expected_node_options;
expected_node_options.add_num_lights(8);
expected_node.add_node_options()->PackFrom(expected_node_options);
*expected_node.add_option_value() = "num_lights:options/chain_length";
EXPECT_THAT(actual_node, EqualsProto(expected_node));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
// Ensure static protobuf packet registration.
MakePacket<NodeChainSubgraphOptions>();
MakePacket<NightLightCalculatorOptions>();
}
// Retrieves the description of a protobuf message and a nested protobuf message
// from the OptionsRegistry.
TEST_F(OptionsUtilTest, GetProtobufDescriptorRegistered) {
const proto_ns::Descriptor* descriptor =
const tool::Descriptor* options_descriptor =
tool::OptionsRegistry::GetProtobufDescriptor(
"mediapipe.NightLightCalculatorOptions");
EXPECT_NE(nullptr, descriptor);
EXPECT_NE(nullptr, options_descriptor);
const tool::Descriptor* bundle_descriptor =
tool::OptionsRegistry::GetProtobufDescriptor(
"mediapipe.NightLightCalculatorOptions.LightBundle");
EXPECT_NE(nullptr, bundle_descriptor);
EXPECT_EQ(options_descriptor->full_name(),
"mediapipe.NightLightCalculatorOptions");
const tool::FieldDescriptor* bundle_field =
options_descriptor->FindFieldByName("bundle");
EXPECT_EQ(bundle_field->message_type(), bundle_descriptor);
}
// Constructs the FieldPath for a nested node-option.
TEST_F(OptionsUtilTest, OptionsSyntaxUtil) {
const tool::Descriptor* descriptor =
tool::OptionsRegistry::GetProtobufDescriptor(
"mediapipe.NightLightCalculatorOptions");
std::string tag;
tool::OptionsSyntaxUtil::FieldPath field_path;
{
// The default tag syntax.
tool::OptionsSyntaxUtil syntax_util;
tag = syntax_util.OptionFieldsTag("options/sub_options/num_lights");
EXPECT_EQ(tag, "OPTIONS/sub_options/num_lights");
field_path = syntax_util.OptionFieldPath(tag, descriptor);
EXPECT_EQ(field_path.size(), 2);
EXPECT_EQ(field_path[0].first->name(), "sub_options");
EXPECT_EQ(field_path[1].first->name(), "num_lights");
}
{
// A tag syntax with a text-coded separator.
tool::OptionsSyntaxUtil syntax_util("OPTIONS", "options", "_Z0Z_");
tag = syntax_util.OptionFieldsTag("options/sub_options/num_lights");
EXPECT_EQ(tag, "OPTIONS_Z0Z_sub_options_Z0Z_num_lights");
field_path = syntax_util.OptionFieldPath(tag, descriptor);
EXPECT_EQ(field_path.size(), 2);
EXPECT_EQ(field_path[0].first->name(), "sub_options");
EXPECT_EQ(field_path[1].first->name(), "num_lights");
}
}
} // namespace

View File

@ -196,6 +196,27 @@ absl::Status ProtoUtilLite::GetFieldRange(
return absl::OkStatus();
}
// Returns the number of field values in a repeated protobuf field.
absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message,
ProtoPath proto_path,
FieldType field_type,
int* field_count) {
int field_id, index;
std::tie(field_id, index) = proto_path.back();
proto_path.pop_back();
std::vector<std::string> parent;
if (proto_path.empty()) {
parent.push_back(std::string(message));
} else {
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
}
FieldAccess access(field_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(parent[0]));
*field_count = access.mutable_field_values()->size();
return absl::OkStatus();
}
// If ok, returns OkStatus, otherwise returns InvalidArgumentError.
template <typename T>
absl::Status SyntaxStatus(bool ok, const std::string& text, T* result) {

View File

@ -75,6 +75,11 @@ class ProtoUtilLite {
FieldType field_type,
std::vector<FieldValue>* field_values);
// Returns the number of field values in a repeated protobuf field.
static absl::Status GetFieldCount(const FieldValue& message,
ProtoPath proto_path, FieldType field_type,
int* field_count);
// Serialize one or more protobuf field values from text.
static absl::Status Serialize(const std::vector<std::string>& text_values,
FieldType field_type,

View File

@ -278,6 +278,8 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
graph_registry =
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
RET_CHECK(config);
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(
CalculatorGraphConfig::Node(), config));
auto* nodes = config->mutable_node();
while (1) {
auto subgraph_nodes_start = std::stable_partition(
@ -297,6 +299,7 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
ASSIGN_OR_RETURN(auto subgraph, graph_registry->CreateByName(
config->package(), node.calculator(),
&subgraph_context));
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(node, &subgraph));
MP_RETURN_IF_ERROR(PrefixNames(node_name, &subgraph));
MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph));
subgraphs.push_back(subgraph);

View File

@ -99,6 +99,26 @@ const GLchar* const kScaledVertexShader = VERTEX_PREAMBLE _STRINGIFY(
sample_coordinate = texture_coordinate.xy;
});
const GLchar* const kTransformedVertexShader = VERTEX_PREAMBLE _STRINGIFY(
in vec4 position; in mediump vec4 texture_coordinate;
out mediump vec2 sample_coordinate; uniform mat3 transform;
uniform vec2 viewport_size;
void main() {
// switch from clip to viewport aspect ratio in order to properly
// apply transformation
vec2 half_viewport_size = viewport_size * 0.5;
vec3 pos = vec3(position.xy * half_viewport_size, 1);
// apply transform
pos = transform * pos;
// switch back to clip space
gl_Position = vec4(pos.xy / half_viewport_size, 0, 1);
sample_coordinate = texture_coordinate.xy;
});
const GLchar* const kBasicTexturedFragmentShader = FRAGMENT_PREAMBLE _STRINGIFY(
DEFAULT_PRECISION(mediump, float)

View File

@ -38,6 +38,17 @@ extern const GLchar* const kBasicVertexShader;
// vec2 sample_coordinate - texture coordinate for shader
extern const GLchar* const kScaledVertexShader;
// Applies an affine transformation to the vertex and leaves texture coordinates
// as is. Input attributes:
// vec4 position - vertex position
// vec4 texture_coordinate - texture coordinate
// Input uniform:
// mat3 homogeneous affine transform - transformation matrix for vertices
// vec2 viewport_size - size of the viewport
// Output varying:
// vec2 sample_coordinate - texture coordinate for shader
extern const GLchar* const kTransformedVertexShader;
// Outputs the texture as it is.
// Input varying:
// vec2 sample_coordinate - texture coordinate

View File

@ -77,16 +77,24 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
}},
{GpuBufferFormat::kOneComponent8,
{
// This should be GL_RED, but it would change the output for existing
// shaders. It would not be a good representation of a grayscale texture,
// unless we use texture swizzling. We could add swizzle parameters (e.g.
// GL_TEXTURE_SWIZZLE_R) in GLES 3 and desktop GL, and use GL_LUMINANCE
// in GLES 2. Or we could just punt and make it a red texture.
// {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
// This format is like RGBA grayscale: GL_LUMINANCE replicates
// the single channel texel values to RGB channels, and set alpha
// to 1.0. If it is desired to see only the texel values in the R
// channel, use kOneComponent8Red instead.
#if !TARGET_OS_OSX
{GL_LUMINANCE, GL_LUMINANCE, GL_UNSIGNED_BYTE, 1},
#else
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
#endif // TARGET_OS_OSX
}},
{GpuBufferFormat::kOneComponent8Red,
{
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
}},
{GpuBufferFormat::kTwoComponent8,
{
{GL_RG8, GL_RG, GL_UNSIGNED_BYTE, 1},
}},
#ifdef __APPLE__
// TODO: figure out GL_RED_EXT etc. on Android.
{GpuBufferFormat::kBiPlanar420YpCbCr8VideoRange,
@ -195,6 +203,8 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
case GpuBufferFormat::kTwoComponentFloat32:
return ImageFormat::VEC32F2;
case GpuBufferFormat::kGrayHalf16:
case GpuBufferFormat::kOneComponent8Red:
case GpuBufferFormat::kTwoComponent8:
case GpuBufferFormat::kTwoComponentHalf16:
case GpuBufferFormat::kRGBAHalf64:
case GpuBufferFormat::kRGBAFloat128:

View File

@ -38,6 +38,8 @@ enum class GpuBufferFormat : uint32_t {
kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'),
kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'),
kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'),
kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'),
kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'),
kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'),
kTwoComponentFloat32 = MEDIAPIPE_FOURCC('2', 'C', '0', 'f'),
kBiPlanar420YpCbCr8VideoRange = MEDIAPIPE_FOURCC('4', '2', '0', 'v'),
@ -82,6 +84,10 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
return kCVPixelFormatType_OneComponent32Float;
case GpuBufferFormat::kOneComponent8:
return kCVPixelFormatType_OneComponent8;
case GpuBufferFormat::kOneComponent8Red:
return -1;
case GpuBufferFormat::kTwoComponent8:
return kCVPixelFormatType_TwoComponent8;
case GpuBufferFormat::kTwoComponentHalf16:
return kCVPixelFormatType_TwoComponent16Half;
case GpuBufferFormat::kTwoComponentFloat32:
@ -114,6 +120,8 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) {
return GpuBufferFormat::kGrayFloat32;
case kCVPixelFormatType_OneComponent8:
return GpuBufferFormat::kOneComponent8;
case kCVPixelFormatType_TwoComponent8:
return GpuBufferFormat::kTwoComponent8;
case kCVPixelFormatType_TwoComponent16Half:
return GpuBufferFormat::kTwoComponentHalf16;
case kCVPixelFormatType_TwoComponent32Float:

View File

@ -19,6 +19,7 @@ import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.TextureFrame;
import java.util.ArrayList;
import java.util.List;
/**
@ -32,6 +33,9 @@ public class ImageSolutionResult implements SolutionResult {
private Bitmap cachedBitmap;
// A list of the output image packets produced by the graph.
protected List<Packet> imageResultPackets;
// The cached texture frames.
protected List<TextureFrame> imageResultTextureFrames;
private TextureFrame cachedTextureFrame;
// Result timestamp, which is set to the timestamp of the corresponding input image. May return
// Long.MIN_VALUE if the input image is not associated with a timestamp.
@ -61,8 +65,38 @@ public class ImageSolutionResult implements SolutionResult {
return PacketGetter.getTextureFrame(imagePacket);
}
// Returns the cached input image as a {@link TextureFrame}.
public TextureFrame getCachedInputTextureFrame() {
return cachedTextureFrame;
}
// Produces all texture frames from image packets and caches them for further use. The caller must
// release the cached {@link TextureFrame}s after using.
void produceAllTextureFrames() {
cachedTextureFrame = acquireInputTextureFrame();
if (imageResultPackets == null) {
return;
}
imageResultTextureFrames = new ArrayList<>();
for (Packet p : imageResultPackets) {
imageResultTextureFrames.add(PacketGetter.getTextureFrame(p));
}
}
// Releases all cached {@link TextureFrame}s.
void releaseCachedTextureFrames() {
if (cachedTextureFrame != null) {
cachedTextureFrame.release();
}
if (imageResultTextureFrames != null) {
for (TextureFrame textureFrame : imageResultTextureFrames) {
textureFrame.release();
}
}
}
// Releases image packet and the underlying data.
void releaseImagePacket() {
void releaseImagePackets() {
imagePacket.release();
if (imageResultPackets != null) {
for (Packet p : imageResultPackets) {

View File

@ -79,7 +79,7 @@ public class OutputHandler<T extends SolutionResult> {
}
if (solutionResult instanceof ImageSolutionResult) {
ImageSolutionResult imageSolutionResult = (ImageSolutionResult) solutionResult;
imageSolutionResult.releaseImagePacket();
imageSolutionResult.releaseImagePackets();
}
}
}

View File

@ -50,13 +50,25 @@ public class SolutionGlSurfaceView<T extends ImageSolutionResult> extends GLSurf
}
/**
* Sets the next textureframe and solution result to render.
* Sets the next input {@link TextureFrame} and solution result to render.
*
* @param solutionResult a solution result object that contains the solution outputs and a
* textureframe.
*/
public void setRenderData(T solutionResult) {
renderer.setRenderData(solutionResult);
renderer.setRenderData(solutionResult, false);
}
/**
* Sets the next input {@link TextureFrame} and solution result to render.
*
* @param solutionResult a solution result object that contains the solution outputs and a {@link
* TextureFrame}.
* @param produceTextureFrames whether to produce and cache all the {@link TextureFrame}s for
* further use.
*/
public void setRenderData(T solutionResult, boolean produceTextureFrames) {
renderer.setRenderData(solutionResult, produceTextureFrames);
}
/** Sets if the input image needs to be rendered. Default to true. */

View File

@ -27,9 +27,10 @@ import javax.microedition.khronos.opengles.GL10;
* MediaPipe Solution's GlSurfaceViewRenderer.
*
* <p>Users can provide a custom {@link ResultGlRenderer} for rendering MediaPipe solution results.
* For setting the latest solution result, call {@link #setRenderData(ImageSolutionResult)}. By
* default, the renderer renders the input images. Call {@link #setRenderInputImage(boolean)} to
* explicitly set whether the input images should be rendered or not.
* For setting the latest solution result, call {@link #setRenderData(ImageSolutionResult,
* boolean)}. By default, the renderer renders the input images. Call {@link
* #setRenderInputImage(boolean)} to explicitly set whether the input images should be rendered or
* not.
*/
public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
extends GlSurfaceViewRenderer {
@ -49,16 +50,24 @@ public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
}
/**
* Sets the next textureframe and solution result to render.
* Sets the next input {@link TextureFrame} and solution result to render.
*
* @param solutionResult a solution result object that contains the solution outputs and a
* textureframe.
* @param produceTextureFrames whether to produce and cache all the {@link TextureFrame}s for
* further use.
*/
public void setRenderData(T solutionResult) {
public void setRenderData(T solutionResult, boolean produceTextureFrames) {
TextureFrame frame = solutionResult.acquireInputTextureFrame();
setFrameSize(frame.getWidth(), frame.getHeight());
setNextFrame(frame);
nextSolutionResult.getAndSet(solutionResult);
if (produceTextureFrames) {
solutionResult.produceAllTextureFrames();
}
T oldSolutionResult = nextSolutionResult.getAndSet(solutionResult);
if (oldSolutionResult != null) {
oldSolutionResult.releaseCachedTextureFrames();
}
}
@Override
@ -78,8 +87,9 @@ public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
GLES20.glActiveTexture(GLES20.GL_TEXTURE0);
ShaderUtil.checkGlError("glActiveTexture");
}
T solutionResult = null;
if (nextSolutionResult != null) {
T solutionResult = nextSolutionResult.getAndSet(null);
solutionResult = nextSolutionResult.getAndSet(null);
float[] textureBoundary = calculateTextureBoundary();
// Scales the values from [0, 1] to [-1, 1].
ResultGlBoundary resultGlBoundary =
@ -91,6 +101,9 @@ public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
resultGlRenderer.renderResult(solutionResult, resultGlBoundary);
}
flush(frame);
if (solutionResult != null) {
solutionResult.releaseCachedTextureFrames();
}
}
@Override

View File

@ -64,9 +64,10 @@ node {
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite"
delegate { xnnpack {} }
delegate {
xnnpack {}
}
}
#
}
}

View File

@ -50,9 +50,10 @@ node {
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
delegate { xnnpack {} }
delegate {
xnnpack {}
}
}
#
}
}

View File

@ -72,9 +72,10 @@ node {
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/pose_detection/pose_detection.tflite"
delegate { xnnpack {} }
delegate {
xnnpack {}
}
}
#
}
}

View File

@ -245,7 +245,7 @@ node {
output_stream: "enabled_segmentation_tensor"
options: {
[mediapipe.GateCalculatorOptions.ext] {
allow: true
allow: false
}
}
}

View File

@ -96,9 +96,10 @@ node {
input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver"
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
delegate { xnnpack {} }
delegate {
xnnpack {}
}
}
#
}
}

View File

@ -288,7 +288,10 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
calculator_graph.def(
"wait_until_done",
[](CalculatorGraph* self) { RaisePyErrorIfNotOk(self->WaitUntilDone()); },
[](CalculatorGraph* self) {
py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true);
},
R"doc(Wait for the current run to finish.
A blocking call to wait for the current run to finish (block the current
@ -313,7 +316,10 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
calculator_graph.def(
"wait_until_idle",
[](CalculatorGraph* self) { RaisePyErrorIfNotOk(self->WaitUntilIdle()); },
[](CalculatorGraph* self) {
py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk(self->WaitUntilIdle(), /**acquire_gil=*/true);
},
R"doc(Wait until the running graph is in the idle mode.
Wait until the running graph is in the idle mode, which is when nothing can
@ -399,12 +405,9 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
stream_name,
[callback_fn, stream_name](const Packet& packet) {
absl::MutexLock lock(&callback_mutex);
py::gil_scoped_release gil_release;
{
// Acquires GIL before calling Python callback.
py::gil_scoped_acquire gil_acquire;
callback_fn(stream_name, packet);
}
return absl::OkStatus();
},
observe_timestamp_bounds));
@ -439,7 +442,8 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
"close",
[](CalculatorGraph* self) {
RaisePyErrorIfNotOk(self->CloseAllPacketSources());
RaisePyErrorIfNotOk(self->WaitUntilDone());
py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true);
},
R"doc(Close all the input sources and shutdown the graph.)doc");

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h"
#include "pybind11/gil.h"
#include "pybind11/pybind11.h"
namespace mediapipe {
@ -45,10 +46,17 @@ inline PyObject* StatusCodeToPyError(const ::absl::StatusCode& code) {
}
}
inline void RaisePyErrorIfNotOk(const absl::Status& status) {
inline void RaisePyErrorIfNotOk(const absl::Status& status,
bool acquire_gil = false) {
if (!status.ok()) {
if (acquire_gil) {
py::gil_scoped_acquire acquire;
throw RaisePyError(StatusCodeToPyError(status.code()),
status.message().data());
} else {
throw RaisePyError(StatusCodeToPyError(status.code()),
status.message().data());
}
}
}

View File

@ -441,7 +441,7 @@ class SolutionBase:
else:
field_label = calculator_options.DESCRIPTOR.fields_by_name[
field_name].label
if field_label is descriptor.FieldDescriptor.LABEL_REPEATED:
if field_label == descriptor.FieldDescriptor.LABEL_REPEATED:
if not isinstance(field_value, Iterable):
raise ValueError(
f'{field_name} is a repeated proto field but the value '

View File

@ -111,7 +111,7 @@ def draw_detection(
image_rows)
rect_end_point = _normalized_to_pixel_coordinates(
relative_bounding_box.xmin + relative_bounding_box.width,
relative_bounding_box.ymin + +relative_bounding_box.height, image_cols,
relative_bounding_box.ymin + relative_bounding_box.height, image_cols,
image_rows)
cv2.rectangle(image, rect_start_point, rect_end_point,
bbox_drawing_spec.color, bbox_drawing_spec.thickness)

View File

@ -175,8 +175,9 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":resource_util_custom",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:singleton",
"//mediapipe/framework/port:status",

View File

@ -38,10 +38,17 @@ workspace_file="$( cd "$(dirname "$0")" ; pwd -P )"/WORKSPACE
if [ -z "$1" ]
then
echo "Installing OpenCV from source"
if [[ -x "$(command -v apt)" ]]; then
sudo apt update && sudo apt install build-essential git
sudo apt install cmake ffmpeg libavformat-dev libdc1394-22-dev libgtk2.0-dev \
libjpeg-dev libpng-dev libswscale-dev libtbb2 libtbb-dev \
libtiff-dev
elif [[ -x "$(command -v dnf)" ]]; then
sudo dnf update && sudo dnf install cmake gcc gcc-c git
sudo dnf install ffmpeg-devel libdc1394-devel gtk2-devel \
libjpeg-turbo-devel libpng-devel tbb-devel \
libtiff-devel
fi
rm -rf /tmp/build_opencv
mkdir /tmp/build_opencv
cd /tmp/build_opencv