Project import generated by Copybara.

GitOrigin-RevId: f4b1fe3f15810450fb6539e733f6a260d3ee082c
This commit is contained in:
MediaPipe Team 2021-09-01 13:49:12 -07:00 committed by jqtang
parent 710fb3de58
commit 6abec128ed
64 changed files with 2384 additions and 161 deletions

View File

@ -157,11 +157,11 @@ http_archive(
http_archive( http_archive(
name = "pybind11", name = "pybind11",
urls = [ urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz", "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.7.1.tar.gz",
"https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz", "https://github.com/pybind/pybind11/archive/v2.7.1.tar.gz",
], ],
sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", sha256 = "616d1c42e4cf14fa27b2a4ff759d7d7b33006fdc5ad8fd603bb2c22622f27020",
strip_prefix = "pybind11-2.4.3", strip_prefix = "pybind11-2.7.1",
build_file = "@pybind11_bazel//:pybind11.BUILD", build_file = "@pybind11_bazel//:pybind11.BUILD",
) )

View File

@ -113,6 +113,10 @@ bazel to build the iOS application. The content of the
5. `Main.storyboard` and `Launch.storyboard` 5. `Main.storyboard` and `Launch.storyboard`
6. `Assets.xcassets` directory. 6. `Assets.xcassets` directory.
Note: In newer versions of Xcode, you may see additional files `SceneDelegate.h`
and `SceneDelegate.m`. Make sure to copy them too and add them to the `BUILD`
file mentioned below.
Copy these files to a directory named `HelloWorld` to a location that can access Copy these files to a directory named `HelloWorld` to a location that can access
the MediaPipe source code. For example, the source code of the application that the MediaPipe source code. For example, the source code of the application that
we will build in this tutorial is located in we will build in this tutorial is located in
@ -247,6 +251,12 @@ We need to get frames from the `_cameraSource` into our application
`MPPInputSourceDelegate`. So our application `ViewController` can be a delegate `MPPInputSourceDelegate`. So our application `ViewController` can be a delegate
of `_cameraSource`. of `_cameraSource`.
Update the interface definition of `ViewController` accordingly:
```
@interface ViewController () <MPPInputSourceDelegate>
```
To handle camera setup and process incoming frames, we should use a queue To handle camera setup and process incoming frames, we should use a queue
different from the main queue. Add the following to the implementation block of different from the main queue. Add the following to the implementation block of
the `ViewController`: the `ViewController`:
@ -288,6 +298,12 @@ utility called `MPPLayerRenderer` to display images on the screen. This utility
can be used to display `CVPixelBufferRef` objects, which is the type of the can be used to display `CVPixelBufferRef` objects, which is the type of the
images provided by `MPPCameraInputSource` to its delegates. images provided by `MPPCameraInputSource` to its delegates.
In `ViewController.m`, add the following import line:
```
#import "mediapipe/objc/MPPLayerRenderer.h"
```
To display images of the screen, we need to add a new `UIView` object called To display images of the screen, we need to add a new `UIView` object called
`_liveView` to the `ViewController`. `_liveView` to the `ViewController`.
@ -411,6 +427,12 @@ Objective-C++.
### Use the graph in `ViewController` ### Use the graph in `ViewController`
In `ViewController.m`, add the following import line:
```
#import "mediapipe/objc/MPPGraph.h"
```
Declare a static constant with the name of the graph, the input stream and the Declare a static constant with the name of the graph, the input stream and the
output stream: output stream:
@ -549,6 +571,12 @@ method to receive packets on this output stream and display them on the screen:
} }
``` ```
Update the interface definition of `ViewController` with `MPPGraphDelegate`:
```
@interface ViewController () <MPPGraphDelegate, MPPInputSourceDelegate>
```
And that is all! Build and run the app on your iOS device. You should see the And that is all! Build and run the app on your iOS device. You should see the
results of running the edge detection graph on a live video feed. Congrats! results of running the edge detection graph on a live video feed. Congrats!
@ -560,5 +588,5 @@ appropriate `BUILD` file dependencies for the edge detection graph.
[Bazel]:https://bazel.build/ [Bazel]:https://bazel.build/
[`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt [`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt
[common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common) [common]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common
[helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld) [helloworld]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld

View File

@ -796,7 +796,7 @@ This will use a Docker image that will isolate mediapipe's installation from the
```bash ```bash
$ docker run -it --name mediapipe mediapipe:latest $ docker run -it --name mediapipe mediapipe:latest
root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazelisk run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world
# Should print: # Should print:
# Hello World! # Hello World!

View File

@ -529,7 +529,7 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http
> ``` > ```
> and then run > and then run
> >
> ```build > ```bash
> bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] > bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR]
> ``` > ```
> INPUT_DIR should be the folder with initial asset .obj files to be processed, > INPUT_DIR should be the folder with initial asset .obj files to be processed,

View File

@ -141,7 +141,7 @@ Optionally, MediaPipe Pose can predicts a full-body
Please find more detail in the Please find more detail in the
[BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), [BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html),
this [paper](https://arxiv.org/abs/2006.10204), this [paper](https://arxiv.org/abs/2006.10204),
[the model card](./models.md#pose) and the [Output](#Output) section below. [the model card](./models.md#pose) and the [Output](#output) section below.
## Solution APIs ## Solution APIs
@ -281,8 +281,8 @@ with mp_pose.Pose(
continue continue
print( print(
f'Nose coordinates: (' f'Nose coordinates: ('
f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, '
f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})'
) )
annotated_image = image.copy() annotated_image = image.copy()
@ -369,6 +369,7 @@ Supported configuration options:
<div class="container"> <div class="container">
<video class="input_video"></video> <video class="input_video"></video>
<canvas class="output_canvas" width="1280px" height="720px"></canvas> <canvas class="output_canvas" width="1280px" height="720px"></canvas>
<div class="landmark-grid-container"></div>
</div> </div>
</body> </body>
</html> </html>

View File

@ -262,7 +262,7 @@ to visualize its associated subgraphs, please see
[(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1DoeyGzMmWUsjfVgZfGGecrn7GKzYcEAo/view?usp=sharing) [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1DoeyGzMmWUsjfVgZfGGecrn7GKzYcEAo/view?usp=sharing)
[`mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu:selfiesegmentationgpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu/BUILD) [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu:selfiesegmentationgpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu/BUILD)
* iOS target: * iOS target:
[`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](http:/mediapipe/examples/ios/selfiesegmentationgpu/BUILD) [`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/selfiesegmentationgpu/BUILD)
### Desktop ### Desktop

View File

@ -13,6 +13,9 @@ has_toc: false
{:toc} {:toc}
--- ---
MediaPipe offers open source cross-platform, customizable ML solutions for live
and streaming media.
<!-- []() in the first cell is needed to preserve table formatting in GitHub Pages. --> <!-- []() in the first cell is needed to preserve table formatting in GitHub Pages. -->
<!-- Whenever this table is updated, paste a copy to ../external_index.md. --> <!-- Whenever this table is updated, paste a copy to ../external_index.md. -->

View File

@ -42,4 +42,9 @@ REGISTER_CALCULATOR(BeginLoopDetectionCalculator);
typedef BeginLoopCalculator<std::vector<Matrix>> BeginLoopMatrixCalculator; typedef BeginLoopCalculator<std::vector<Matrix>> BeginLoopMatrixCalculator;
REGISTER_CALCULATOR(BeginLoopMatrixCalculator); REGISTER_CALCULATOR(BeginLoopMatrixCalculator);
// A calculator to process std::vector<std::vector<Matrix>>.
typedef BeginLoopCalculator<std::vector<std::vector<Matrix>>>
BeginLoopMatrixVectorCalculator;
REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -73,6 +73,7 @@ class InferenceCalculatorCpuImpl
private: private:
absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is. // TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_; Packet<TfLiteModelPtr> model_packet_;
@ -91,8 +92,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract(
absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadModel(cc)); MP_RETURN_IF_ERROR(LoadModel(cc));
MP_RETURN_IF_ERROR(LoadDelegate(cc)); return LoadDelegateAndAllocateTensors(cc);
return absl::OkStatus();
} }
absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
@ -156,11 +156,19 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) {
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread()); cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
return absl::OkStatus();
}
absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors. // TODO: Support quantized tensors.
CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != RET_CHECK_NE(
kTfLiteAffineQuantization); interpreter_->tensor(interpreter_->inputs()[0])->quantization.type,
kTfLiteAffineQuantization);
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -53,6 +53,7 @@ class InferenceCalculatorGlImpl
absl::Status WriteKernelsToFile(); absl::Status WriteKernelsToFile();
absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitTFLiteGPURunner(CalculatorContext* cc); absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is. // TfLite requires us to keep the model alive as long as the interpreter is.
@ -119,10 +120,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
} }
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, MP_RETURN_IF_ERROR(
&cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc); return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
})); : LoadDelegateAndAllocateTensors(cc);
}));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -324,11 +326,19 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread()); cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors. // TODO: Support quantized tensors.
CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != RET_CHECK_NE(
kTfLiteAffineQuantization); interpreter_->tensor(interpreter_->inputs()[0])->quantization.type,
kTfLiteAffineQuantization);
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -92,6 +92,7 @@ class InferenceCalculatorMetalImpl
private: private:
absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is. // TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_; Packet<TfLiteModelPtr> model_packet_;
@ -130,8 +131,7 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_); RET_CHECK(gpu_helper_);
MP_RETURN_IF_ERROR(LoadDelegate(cc)); return LoadDelegateAndAllocateTensors(cc);
return absl::OkStatus();
} }
absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
@ -212,11 +212,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) {
interpreter_->SetNumThreads( interpreter_->SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread()); cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
return absl::OkStatus();
}
absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors. // TODO: Support quantized tensors.
CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != RET_CHECK_NE(
kTfLiteAffineQuantization); interpreter_->tensor(interpreter_->inputs()[0])->quantization.type,
kTfLiteAffineQuantization);
return absl::OkStatus(); return absl::OkStatus();
} }
@ -236,6 +244,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk); kTfLiteOk);
id<MTLDevice> device = gpu_helper_.mtlDevice; id<MTLDevice> device = gpu_helper_.mtlDevice;
// Get input image sizes. // Get input image sizes.

View File

@ -670,7 +670,8 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
detection_scores[i], detection_classes[i], options_.flip_vertically()); detection_scores[i], detection_classes[i], options_.flip_vertically());
const auto& bbox = detection.location_data().relative_bounding_box(); const auto& bbox = detection.location_data().relative_bounding_box();
if (bbox.width() < 0 || bbox.height() < 0) { if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
std::isnan(bbox.height())) {
// Decoded detection boxes could have negative values for width/height due // Decoded detection boxes could have negative values for width/height due
// to model prediction. Filter out those boxes since some downstream // to model prediction. Filter out those boxes since some downstream
// calculators may assume non-negative values. (b/171391719) // calculators may assume non-negative values. (b/171391719)

View File

@ -138,7 +138,6 @@ using ::tflite::gpu::gl::GlShader;
// } // }
// } // }
// //
// Currently only OpenGLES 3.1 and CPU backends supported.
// TODO Refactor and add support for other backends/platforms. // TODO Refactor and add support for other backends/platforms.
// //
class TensorsToSegmentationCalculator : public CalculatorBase { class TensorsToSegmentationCalculator : public CalculatorBase {

View File

@ -56,6 +56,8 @@ constexpr char kBboxTag[] = "BBOX";
constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
constexpr char kImageTag[] = "IMAGE"; constexpr char kImageTag[] = "IMAGE";
constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER";
constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST";
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
class UnpackMediaSequenceCalculatorTest : public ::testing::Test { class UnpackMediaSequenceCalculatorTest : public ::testing::Test {

View File

@ -175,7 +175,8 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
auto* text = label_annotation->mutable_text(); auto* text = label_annotation->mutable_text();
std::string display_text = labels[i]; std::string display_text = labels[i];
if (cc->Inputs().HasTag(kScoresTag)) { if (cc->Inputs().HasTag(kScoresTag) ||
options_.display_classification_score()) {
absl::StrAppend(&display_text, ":", scores[i]); absl::StrAppend(&display_text, ":", scores[i]);
} }
text->set_display_text(display_text); text->set_display_text(display_text);

View File

@ -62,4 +62,7 @@ message LabelsToRenderDataCalculatorOptions {
// Uses Classification.display_name field instead of Classification.label. // Uses Classification.display_name field instead of Classification.label.
optional bool use_display_name = 9 [default = false]; optional bool use_display_name = 9 [default = false];
// Displays Classification score if enabled.
optional bool display_classification_score = 10 [default = false];
} }

View File

@ -223,24 +223,23 @@ class SubgraphImpl : public Subgraph, public Intf {
// This macro is used to register a calculator that does not use automatic // This macro is used to register a calculator that does not use automatic
// registration. Deprecated. // registration. Deprecated.
#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \ #define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \ static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(calculator_registration, __LINE__)( \ REGISTRY_STATIC_VAR(calculator_registration, \
mediapipe::CalculatorBaseRegistry::Register( \ __LINE__)(mediapipe::CalculatorBaseRegistry::Register( \
Impl::kCalculatorName, \ Impl::kCalculatorName, \
absl::make_unique< \ absl::make_unique<mediapipe::internal::CalculatorBaseFactoryFor<Impl>>))
mediapipe::internal::CalculatorBaseFactoryFor<Impl>>))
// This macro is used to register a non-split-contract calculator. Deprecated. // This macro is used to register a non-split-contract calculator. Deprecated.
#define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name) #define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name)
// This macro is used to define a subgraph that does not use automatic // This macro is used to define a subgraph that does not use automatic
// registration. Deprecated. // registration. Deprecated.
#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \ #define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \
static mediapipe::NoDestructor<mediapipe::RegistrationToken> \ static mediapipe::NoDestructor<mediapipe::RegistrationToken> \
REGISTRY_STATIC_VAR(subgraph_registration, \ REGISTRY_STATIC_VAR(subgraph_registration, \
__LINE__)(mediapipe::SubgraphRegistry::Register( \ __LINE__)(mediapipe::SubgraphRegistry::Register( \
Impl::kCalculatorName, absl::make_unique<Impl>)) Impl::kCalculatorName, absl::make_unique<Impl>))
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -454,12 +454,12 @@ class OutputShardAccessBase {
if (output_) output_->SetNextTimestampBound(timestamp); if (output_) output_->SetNextTimestampBound(timestamp);
} }
bool IsClosed() { return output_ ? output_->IsClosed() : true; } bool IsClosed() const { return output_ ? output_->IsClosed() : true; }
void Close() { void Close() {
if (output_) output_->Close(); if (output_) output_->Close();
} }
bool IsConnected() { return output_ != nullptr; } bool IsConnected() const { return output_ != nullptr; }
protected: protected:
const CalculatorContext& context_; const CalculatorContext& context_;
@ -559,7 +559,7 @@ class InputShardAccess : public Packet<T> {
PacketBase packet() const&& { return *this; } PacketBase packet() const&& { return *this; }
bool IsDone() const { return stream_->IsDone(); } bool IsDone() const { return stream_->IsDone(); }
bool IsConnected() { return stream_ != nullptr; } bool IsConnected() const { return stream_ != nullptr; }
PacketBase Header() const { return FromOldPacket(stream_->Header()); } PacketBase Header() const { return FromOldPacket(stream_->Header()); }
@ -619,7 +619,7 @@ class InputSidePacketAccess : public Packet<T> {
const PacketBase& packet() const& { return *this; } const PacketBase& packet() const& { return *this; }
PacketBase packet() const&& { return *this; } PacketBase packet() const&& { return *this; }
bool IsConnected() { return connected_; } bool IsConnected() const { return connected_; }
private: private:
InputSidePacketAccess(const mediapipe::Packet* packet) InputSidePacketAccess(const mediapipe::Packet* packet)
@ -639,8 +639,8 @@ class InputShardOrSideAccess : public Packet<T> {
PacketBase packet() const&& { return *this; } PacketBase packet() const&& { return *this; }
bool IsDone() const { return stream_->IsDone(); } bool IsDone() const { return stream_->IsDone(); }
bool IsConnected() { return connected_; } bool IsConnected() const { return connected_; }
bool IsStream() { return stream_ != nullptr; } bool IsStream() const { return stream_ != nullptr; }
PacketBase Header() const { return FromOldPacket(stream_->Header()); } PacketBase Header() const { return FromOldPacket(stream_->Header()); }
@ -662,7 +662,7 @@ class InputShardOrSideAccess : public Packet<T> {
class PacketTypeAccess { class PacketTypeAccess {
public: public:
bool IsConnected() { return packet_type_ != nullptr; } bool IsConnected() const { return packet_type_ != nullptr; }
protected: protected:
PacketTypeAccess(PacketType* pt) : packet_type_(pt) {} PacketTypeAccess(PacketType* pt) : packet_type_(pt) {}
@ -675,7 +675,7 @@ class PacketTypeAccess {
class PacketTypeAccessFallback : public PacketTypeAccess { class PacketTypeAccessFallback : public PacketTypeAccess {
public: public:
bool IsStream() { return is_stream_; } bool IsStream() const { return is_stream_; }
private: private:
PacketTypeAccessFallback(PacketType* pt, bool is_stream) PacketTypeAccessFallback(PacketType* pt, bool is_stream)

View File

@ -321,6 +321,8 @@ message CalculatorGraphConfig {
// The maximum number of invocations that can be executed in parallel. // The maximum number of invocations that can be executed in parallel.
// If not specified, the limit is one invocation. // If not specified, the limit is one invocation.
int32 max_in_flight = 16; int32 max_in_flight = 16;
// Defines an option value for this Node from graph options or packets.
repeated string option_value = 17;
// DEPRECATED: For backwards compatibility we allow users to // DEPRECATED: For backwards compatibility we allow users to
// specify the old name for "input_side_packet" in proto configs. // specify the old name for "input_side_packet" in proto configs.
// These are automatically converted to input_side_packets during // These are automatically converted to input_side_packets during

View File

@ -465,7 +465,7 @@ absl::Status CalculatorGraph::ObserveOutputStream(
} }
absl::StatusOr<OutputStreamPoller> CalculatorGraph::AddOutputStreamPoller( absl::StatusOr<OutputStreamPoller> CalculatorGraph::AddOutputStreamPoller(
const std::string& stream_name) { const std::string& stream_name, bool observe_timestamp_bounds) {
RET_CHECK(initialized_).SetNoLogging() RET_CHECK(initialized_).SetNoLogging()
<< "CalculatorGraph is not initialized."; << "CalculatorGraph is not initialized.";
int output_stream_index = validated_graph_->OutputStreamIndex(stream_name); int output_stream_index = validated_graph_->OutputStreamIndex(stream_name);
@ -479,7 +479,7 @@ absl::StatusOr<OutputStreamPoller> CalculatorGraph::AddOutputStreamPoller(
stream_name, &any_packet_type_, stream_name, &any_packet_type_,
std::bind(&CalculatorGraph::UpdateThrottledNodes, this, std::bind(&CalculatorGraph::UpdateThrottledNodes, this,
std::placeholders::_1, std::placeholders::_2), std::placeholders::_1, std::placeholders::_2),
&output_stream_managers_[output_stream_index])); &output_stream_managers_[output_stream_index], observe_timestamp_bounds));
OutputStreamPoller poller(internal_poller); OutputStreamPoller poller(internal_poller);
graph_output_streams_.push_back(std::move(internal_poller)); graph_output_streams_.push_back(std::move(internal_poller));
return std::move(poller); return std::move(poller);

View File

@ -164,7 +164,8 @@ class CalculatorGraph {
// polling API for accessing a stream's output. Should only be called before // polling API for accessing a stream's output. Should only be called before
// Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See // Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See
// also the helpers in tool/sink.h. // also the helpers in tool/sink.h.
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name); StatusOrPoller AddOutputStreamPoller(const std::string& stream_name,
bool observe_timestamp_bounds = false);
// Gets output side packet by name after the graph is done. However, base // Gets output side packet by name after the graph is done. However, base
// packets (generated by PacketGenerators) can be retrieved before // packets (generated by PacketGenerators) can be retrieved before

View File

@ -4348,5 +4348,349 @@ TEST(CalculatorGraph, GraphInputStreamWithTag) {
ASSERT_EQ(5, packet_dump.size()); ASSERT_EQ(5, packet_dump.size());
} }
TEST(CalculatorGraph, GraphInputStreamBeforeStartRun) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "VIDEO_METADATA:video_metadata"
input_stream: "max_count"
node {
calculator: "PassThroughCalculator"
input_stream: "FIRST_INPUT:video_metadata"
input_stream: "max_count"
output_stream: "FIRST_INPUT:output_0"
output_stream: "output_1"
}
)pb");
std::vector<Packet> packet_dump;
tool::AddVectorSink("output_0", &config, &packet_dump);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
ASSERT_EQ(graph
.AddPacketToInputStream("video_metadata",
MakePacket<int>(0).At(Timestamp(0)))
.code(),
absl::StatusCode::kFailedPrecondition);
}
// Returns the first packet of the input stream.
class FirstPacketFilterCalculator : public CalculatorBase {
public:
FirstPacketFilterCalculator() {}
~FirstPacketFilterCalculator() override {}
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
if (!seen_first_packet_) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
cc->Outputs().Index(0).Close();
seen_first_packet_ = true;
}
return absl::OkStatus();
}
private:
bool seen_first_packet_ = false;
};
REGISTER_CALCULATOR(FirstPacketFilterCalculator);
constexpr int kDefaultMaxCount = 1000;
TEST(CalculatorGraph, TestPollPacket) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node = config.add_node();
node->set_calculator("CountingSourceCalculator");
node->add_output_stream("output");
node->add_input_side_packet("MAX_COUNT:max_count");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller = graph.AddOutputStreamPoller("output");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet;
int num_packets = 0;
while (poller.Next(&packet)) {
EXPECT_EQ(num_packets, packet.Get<int>());
++num_packets;
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller.Next(&packet));
EXPECT_EQ(kDefaultMaxCount, num_packets);
}
TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node = config.add_node();
node->set_calculator("CountingSourceCalculator");
node->add_output_stream("output");
node->add_input_side_packet("MAX_COUNT:max_count");
for (int queue_size = 1; queue_size < 10; ++queue_size) {
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller = graph.AddOutputStreamPoller("output");
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
poller.SetMaxQueueSize(queue_size);
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet;
int num_packets = 0;
while (poller.Next(&packet)) {
EXPECT_EQ(num_packets, packet.Get<int>());
++num_packets;
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller.Next(&packet));
EXPECT_EQ(kDefaultMaxCount, num_packets);
}
}
TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) {
CalculatorGraphConfig config;
CalculatorGraphConfig::Node* node1 = config.add_node();
node1->set_calculator("CountingSourceCalculator");
node1->add_output_stream("stream1");
node1->add_input_side_packet("MAX_COUNT:max_count");
CalculatorGraphConfig::Node* node2 = config.add_node();
node2->set_calculator("PassThroughCalculator");
node2->add_input_stream("stream1");
node2->add_output_stream("stream2");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
auto status_or_poller1 = graph.AddOutputStreamPoller("stream1");
ASSERT_TRUE(status_or_poller1.ok());
OutputStreamPoller poller1 = std::move(status_or_poller1.value());
auto status_or_poller2 = graph.AddOutputStreamPoller("stream2");
ASSERT_TRUE(status_or_poller2.ok());
OutputStreamPoller poller2 = std::move(status_or_poller2.value());
MP_ASSERT_OK(
graph.StartRun({{"max_count", MakePacket<int>(kDefaultMaxCount)}}));
Packet packet1;
Packet packet2;
int num_packets1 = 0;
int num_packets2 = 0;
int running_pollers = 2;
while (running_pollers > 0) {
if (poller1.Next(&packet1)) {
EXPECT_EQ(num_packets1++, packet1.Get<int>());
} else {
--running_pollers;
}
if (poller2.Next(&packet2)) {
EXPECT_EQ(num_packets2++, packet2.Get<int>());
} else {
--running_pollers;
}
}
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_FALSE(poller1.Next(&packet1));
EXPECT_FALSE(poller2.Next(&packet2));
EXPECT_EQ(kDefaultMaxCount, num_packets1);
EXPECT_EQ(kDefaultMaxCount, num_packets2);
}
class TimestampBoundTestCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
absl::Status Process(CalculatorContext* cc) final {
if (count_ % 50 == 1) {
// Outputs packets at t10 and t60.
cc->Outputs().Index(0).AddPacket(
MakePacket<int>(count_).At(Timestamp(count_)));
} else if (count_ % 15 == 7) {
cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_));
}
absl::SleepFor(absl::Milliseconds(3));
++count_;
if (count_ == 110) {
return tool::StatusStop();
}
return absl::OkStatus();
}
private:
int count_ = 0;
};
REGISTER_CALCULATOR(TimestampBoundTestCalculator);
TEST(CalculatorGraph, TestPollPacketsWithTimestampNotification) {
std::string config_str = R"(
node {
calculator: "TimestampBoundTestCalculator"
output_stream: "foo"
}
)";
CalculatorGraphConfig graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
auto status_or_poller =
graph.AddOutputStreamPoller("foo", /*observe_timestamp_bounds=*/true);
ASSERT_TRUE(status_or_poller.ok());
OutputStreamPoller poller = std::move(status_or_poller.value());
Packet packet;
std::vector<int> timestamps;
std::vector<int> values;
MP_ASSERT_OK(graph.StartRun({}));
while (poller.Next(&packet)) {
if (packet.IsEmpty()) {
timestamps.push_back(packet.Timestamp().Value());
} else {
values.push_back(packet.Get<int>());
}
}
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_FALSE(poller.Next(&packet));
ASSERT_FALSE(timestamps.empty());
int prev_t = 0;
for (auto t : timestamps) {
EXPECT_TRUE(t > prev_t && t < 110);
prev_t = t;
}
ASSERT_EQ(3, values.size());
EXPECT_EQ(1, values[0]);
EXPECT_EQ(51, values[1]);
EXPECT_EQ(101, values[2]);
}
// Ensure that when a custom input stream handler is used to handle packets from
// input streams, an error message is outputted with the appropriate link to
// resolve the issue when the calculator doesn't handle inputs in monotonically
// increasing order of timestamps.
TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) {
CalculatorGraph graph;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'input0'
input_stream: 'input1'
node {
calculator: 'SimpleMuxCalculator'
input_stream: 'input0'
input_stream: 'input1'
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
output_stream: 'output'
}
)pb");
std::vector<Packet> packet_dump;
tool::AddVectorSink("output", &config, &packet_dump);
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
// Send packets to input stream "input0" at timestamps 0 and 1 consecutively.
Timestamp input0_timestamp = Timestamp(0);
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1).At(input0_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
EXPECT_EQ(1, packet_dump[0].Get<int>());
++input0_timestamp;
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(3).At(input0_timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(2, packet_dump.size());
EXPECT_EQ(3, packet_dump[1].Get<int>());
// Send a packet to input stream "input1" at timestamp 0 after sending two
// packets at timestamps 0 and 1 to input stream "input0". This will result
// in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle
// inputs from all streams in monotonically increasing order of timestamps.
Timestamp input1_timestamp = Timestamp(0);
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(2).At(input1_timestamp)));
absl::Status run_status = graph.WaitUntilIdle();
EXPECT_THAT(
run_status.ToString(),
testing::AllOf(
// The core problem.
testing::HasSubstr("timestamp mismatch on a calculator"),
testing::HasSubstr(
"timestamps that are not strictly monotonically increasing"),
// Link to the possible solution.
testing::HasSubstr("ImmediateInputStreamHandler class comment")));
}
void DoTestMultipleGraphRuns(absl::string_view input_stream_handler,
bool select_packet) {
std::string graph_proto = absl::StrFormat(R"(
input_stream: 'input'
input_stream: 'select'
node {
calculator: 'PassThroughCalculator'
input_stream: 'input'
input_stream: 'select'
input_stream_handler {
input_stream_handler: "%s"
}
output_stream: 'output'
output_stream: 'select_out'
}
)",
input_stream_handler.data());
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
std::vector<Packet> packet_dump;
tool::AddVectorSink("output", &config, &packet_dump);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
struct Run {
Timestamp timestamp;
int value;
};
std::vector<Run> runs = {{.timestamp = Timestamp(2000), .value = 2},
{.timestamp = Timestamp(1000), .value = 1}};
for (const Run& run : runs) {
MP_ASSERT_OK(graph.StartRun({}));
if (select_packet) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(run.timestamp)));
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"input", MakePacket<int>(run.value).At(run.timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
EXPECT_EQ(run.value, packet_dump[0].Get<int>());
EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
packet_dump.clear();
}
}
TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) {
DoTestMultipleGraphRuns("BarrierInputStreamHandler", true);
DoTestMultipleGraphRuns("DefaultInputStreamHandler", true);
DoTestMultipleGraphRuns("EarlyCloseInputStreamHandler", true);
DoTestMultipleGraphRuns("FixedSizeInputStreamHandler", true);
DoTestMultipleGraphRuns("ImmediateInputStreamHandler", false);
DoTestMultipleGraphRuns("MuxInputStreamHandler", true);
DoTestMultipleGraphRuns("SyncSetInputStreamHandler", true);
DoTestMultipleGraphRuns("TimestampAlignInputStreamHandler", true);
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -45,7 +45,7 @@ std::string JoinPathImpl(bool honor_abs,
// This size calculation is worst-case: it assumes one extra "/" for every // This size calculation is worst-case: it assumes one extra "/" for every
// path other than the first. // path other than the first.
size_t total_size = paths.size() - 1; size_t total_size = paths.size() - 1;
for (const absl::string_view path : paths) total_size += path.size(); for (const absl::string_view& path : paths) total_size += path.size();
result.resize(total_size); result.resize(total_size);
auto begin = result.begin(); auto begin = result.begin();

View File

@ -81,6 +81,12 @@ mediapipe_proto_library(
deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"],
) )
mediapipe_proto_library(
name = "affine_transform_data_proto",
srcs = ["affine_transform_data.proto"],
visibility = ["//visibility:public"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "time_series_header_proto", name = "time_series_header_proto",
srcs = ["time_series_header.proto"], srcs = ["time_series_header.proto"],
@ -119,6 +125,31 @@ cc_library(
], ],
) )
cc_library(
name = "affine_transform",
srcs = ["affine_transform.cc"],
hdrs = ["affine_transform.h"],
visibility = [
"//visibility:public",
],
deps = [
"//mediapipe/framework:port",
"//mediapipe/framework:type_map",
"//mediapipe/framework/formats:affine_transform_data_cc_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:point",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/tool:status_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf",
],
)
cc_library( cc_library(
name = "image_frame", name = "image_frame",
srcs = ["image_frame.cc"], srcs = ["image_frame.cc"],

View File

@ -0,0 +1,228 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/formats/affine_transform.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/point2.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/type_map.h"
namespace mediapipe {
using ::mediapipe::AffineTransformData;
AffineTransform::AffineTransform() { SetScale(Point2_f(1, 1)); }
AffineTransform::AffineTransform(
const AffineTransformData& affine_transform_data)
: affine_transform_data_(affine_transform_data), is_dirty_(true) {
// make sure scale is set to default (1, 1) when none provided
if (!affine_transform_data_.has_scale()) {
SetScale(Point2_f(1, 1));
}
}
AffineTransform AffineTransform::Create(const Point2_f& translation,
const Point2_f& scale, float rotation,
const Point2_f& shear) {
AffineTransformData affine_transform_data;
auto* t = affine_transform_data.mutable_translation();
t->set_x(translation.x());
t->set_y(translation.y());
auto* s = affine_transform_data.mutable_scale();
s->set_x(scale.x());
s->set_y(scale.y());
s = affine_transform_data.mutable_shear();
s->set_x(shear.x());
s->set_y(shear.y());
affine_transform_data.set_rotation(rotation);
return AffineTransform(affine_transform_data);
}
// Accessor for the composition matrix
std::vector<float> AffineTransform::GetCompositionMatrix() {
float r = affine_transform_data_.rotation();
const auto t = affine_transform_data_.translation();
const auto sc = affine_transform_data_.scale();
const auto sh = affine_transform_data_.shear();
if (is_dirty_) {
// Composition matrix M = T*R*Sh*Sc
// Column based to match GL matrix store order
float cos_r = std::cos(r);
float sin_r = std::sin(r);
matrix_[0] = (cos_r + sin_r * -sh.y()) * sc.x();
matrix_[1] = (-sin_r + cos_r * -sh.y()) * sc.x();
matrix_[2] = 0;
matrix_[3] = (cos_r * -sh.x() + sin_r) * sc.y();
matrix_[4] = (-sin_r * -sh.x() + cos_r) * sc.y();
matrix_[5] = 0;
matrix_[6] = t.x();
matrix_[7] = -t.y();
matrix_[8] = 1;
is_dirty_ = false;
}
return matrix_;
}
Point2_f AffineTransform::GetScale() const {
return Point2_f(affine_transform_data_.scale().x(),
affine_transform_data_.scale().y());
}
Point2_f AffineTransform::GetTranslation() const {
return Point2_f(affine_transform_data_.translation().x(),
affine_transform_data_.translation().y());
}
Point2_f AffineTransform::GetShear() const {
return Point2_f(affine_transform_data_.shear().x(),
affine_transform_data_.shear().y());
}
float AffineTransform::GetRotation() const {
return affine_transform_data_.rotation();
}
void AffineTransform::SetScale(const Point2_f& scale) {
auto* s = affine_transform_data_.mutable_scale();
s->set_x(scale.x());
s->set_y(scale.y());
is_dirty_ = true;
}
void AffineTransform::SetTranslation(const Point2_f& translation) {
auto* t = affine_transform_data_.mutable_translation();
t->set_x(translation.x());
t->set_y(translation.y());
is_dirty_ = true;
}
void AffineTransform::SetShear(const Point2_f& shear) {
auto* s = affine_transform_data_.mutable_shear();
s->set_x(shear.x());
s->set_y(shear.y());
is_dirty_ = true;
}
void AffineTransform::SetRotation(float rotationInRadians) {
affine_transform_data_.set_rotation(rotationInRadians);
is_dirty_ = true;
}
void AffineTransform::AddScale(const Point2_f& scale) {
auto* s = affine_transform_data_.mutable_scale();
s->set_x(s->x() + scale.x());
s->set_y(s->y() + scale.y());
is_dirty_ = true;
}
void AffineTransform::AddTranslation(const Point2_f& translation) {
auto* t = affine_transform_data_.mutable_translation();
t->set_x(t->x() + translation.x());
t->set_y(t->y() + translation.y());
is_dirty_ = true;
}
void AffineTransform::AddShear(const Point2_f& shear) {
auto* s = affine_transform_data_.mutable_shear();
s->set_x(s->x() + shear.x());
s->set_y(s->y() + shear.y());
is_dirty_ = true;
}
void AffineTransform::AddRotation(float rotationInRadians) {
affine_transform_data_.set_rotation(affine_transform_data_.rotation() +
rotationInRadians);
is_dirty_ = true;
}
void AffineTransform::SetFromProto(const AffineTransformData& proto) {
affine_transform_data_ = proto;
}
void AffineTransform::ConvertToProto(AffineTransformData* proto) const {
*proto = affine_transform_data_;
}
AffineTransformData AffineTransform::ConvertToProto() const {
AffineTransformData affine_transform_data;
ConvertToProto(&affine_transform_data);
return affine_transform_data;
}
bool compare(float lhs, float rhs, float epsilon = 0.001f) {
return std::fabs(lhs - rhs) < epsilon;
}
bool AffineTransform::Equals(const AffineTransform& other,
float epsilon) const {
auto trans1 = GetTranslation();
auto trans2 = other.GetTranslation();
if (!(compare(trans1.x(), trans2.x(), epsilon) &&
compare(trans1.y(), trans2.y(), epsilon)))
return false;
auto scale1 = GetScale();
auto scale2 = other.GetScale();
if (!(compare(scale1.x(), scale2.x(), epsilon) &&
compare(scale1.y(), scale2.y(), epsilon)))
return false;
auto shear1 = GetShear();
auto shear2 = other.GetShear();
if (!(compare(shear1.x(), shear2.x(), epsilon) &&
compare(shear1.y(), shear2.y(), epsilon)))
return false;
auto rot1 = GetRotation();
auto rot2 = other.GetRotation();
if (!compare(rot1, rot2, epsilon)) {
return false;
}
return true;
}
bool AffineTransform::Equal(const AffineTransform& lhs,
const AffineTransform& rhs, float epsilon) {
return lhs.Equals(rhs, epsilon);
}
MEDIAPIPE_REGISTER_TYPE(mediapipe::AffineTransform,
"::mediapipe::AffineTransform", nullptr, nullptr);
} // namespace mediapipe

View File

@ -0,0 +1,86 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// A container for affine transform data
// This wrapper provides two functionalities:
// 1. Factory methods for creation of Transform objects and thus
// AffineTransformData protocol buffers. These methods guarantee a valid
// affine transform data and are the preferred way of creating such.
// 2. Accessors which allow for access of the data and the convertion to proto
// format
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_
#include <memory>
#include <vector>
#include "mediapipe/framework/formats/affine_transform_data.pb.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/point2.h"
namespace mediapipe {
class AffineTransform {
public:
// CREATION METHODS.
AffineTransform();
// Constructs a affine transform wrapping the specified affine transform data.
// Checks the validity of the input and crashes upon failure.
explicit AffineTransform(const AffineTransformData& transform_data);
static AffineTransform Create(const Point2_f& translation = Point2_f(0, 0),
const Point2_f& scale = Point2_f(1, 1),
float rotation = 0,
const Point2_f& shear = Point2_f(0, 0));
// ACCESSORS
// Accessor for the composition matrix
std::vector<float> GetCompositionMatrix();
Point2_f GetScale() const;
Point2_f GetTranslation() const;
Point2_f GetShear() const;
float GetRotation() const;
void SetScale(const Point2_f& scale);
void SetTranslation(const Point2_f& translation);
void SetShear(const Point2_f& shear);
void SetRotation(float rotation);
void AddScale(const Point2_f& scale);
void AddTranslation(const Point2_f& translation);
void AddShear(const Point2_f& shear);
void AddRotation(float rotation);
// Serializes and deserializes the affine transform object.
void ConvertToProto(AffineTransformData* proto) const;
AffineTransformData ConvertToProto() const;
void SetFromProto(const AffineTransformData& proto);
bool Equals(const AffineTransform& other, float epsilon = 0.001f) const;
static bool Equal(const AffineTransform& lhs, const AffineTransform& rhs,
float epsilon = 0.001f);
private:
// The wrapped transform data.
AffineTransformData affine_transform_data_;
std::vector<float> matrix_ = {1, 0, 0, 0, 1, 0, 0, 0, 1};
bool is_dirty_ = false;
};
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_

View File

@ -0,0 +1,33 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
option objc_class_prefix = "MediaPipe";
// Proto for serializing Vector2 data
message Vector2Data {
optional float x = 1;
optional float y = 2;
}
// Proto for serializing Affine Transform data.
message AffineTransformData {
optional Vector2Data translation = 1;
optional Vector2Data scale = 2;
optional Vector2Data shear = 3;
optional float rotation = 4; // in radians
}

View File

@ -0,0 +1,94 @@
#include "mediapipe/framework/formats/affine_transform.h"
#include <string>
#include "base/logging.h"
#include "mediapipe/framework/formats/affine_transform_data.pb.h"
#include "mediapipe/framework/port/point2.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
namespace mediapipe {
TEST(AffineTransformTest, TraslationTest) {
AffineTransform transform;
transform.SetTranslation(Point2_f(10, -3));
auto trans = transform.GetTranslation();
EXPECT_FLOAT_EQ(10, trans.x());
EXPECT_FLOAT_EQ(-3, trans.y());
transform.AddTranslation(Point2_f(-10, 3));
trans = transform.GetTranslation();
EXPECT_FLOAT_EQ(0, trans.x());
EXPECT_FLOAT_EQ(0, trans.y());
}
TEST(AffineTransformTest, ScaleTest) {
AffineTransform transform;
transform.SetScale(Point2_f(10, -3));
auto scale = transform.GetScale();
EXPECT_FLOAT_EQ(10, scale.x());
EXPECT_FLOAT_EQ(-3, scale.y());
transform.AddScale(Point2_f(-10, 3));
scale = transform.GetScale();
EXPECT_FLOAT_EQ(0, scale.x());
EXPECT_FLOAT_EQ(0, scale.y());
}
TEST(AffineTransformTest, RotationTest) {
AffineTransform transform;
transform.SetRotation(0.7);
float rot = transform.GetRotation();
EXPECT_FLOAT_EQ(0.7, rot);
transform.AddRotation(-0.7);
rot = transform.GetRotation();
EXPECT_FLOAT_EQ(0, rot);
}
TEST(AffineTransformTest, ShearTest) {
AffineTransform transform;
transform.SetShear(Point2_f(10, -3));
auto shear = transform.GetShear();
EXPECT_FLOAT_EQ(10, shear.x());
EXPECT_FLOAT_EQ(-3, shear.y());
transform.AddShear(Point2_f(-10, 3));
shear = transform.GetShear();
EXPECT_FLOAT_EQ(0, shear.x());
EXPECT_FLOAT_EQ(0, shear.y());
}
TEST(AffineTransformTest, TransformTest) {
AffineTransform transform1;
transform1 = AffineTransform::Create(Point2_f(0.1, -0.2), Point2_f(0.3, -0.4),
0.5, Point2_f(0.6, -0.7));
AffineTransform transform2;
transform2 = AffineTransform::Create(Point2_f(0.1, -0.2), Point2_f(0.3, -0.4),
0.5, Point2_f(0.6, -0.7));
EXPECT_THAT(true, transform1.Equals(transform2));
EXPECT_THAT(true, AffineTransform::Equal(transform1, transform2));
transform1 = AffineTransform::Create(Point2_f(0.00001, -0.00002),
Point2_f(0.00003, -0.00004), 0.00005,
Point2_f(0.00006, -0.00007));
transform2 = AffineTransform::Create(Point2_f(0.00001, -0.00002),
Point2_f(0.00003, -0.00004), 0.00005,
Point2_f(0.00006, -0.00007));
EXPECT_THAT(true, transform1.Equals(transform2, 0.000001));
EXPECT_THAT(true, AffineTransform::Equal(transform1, transform2, 0.000001));
}
} // namespace mediapipe

View File

@ -125,9 +125,10 @@ absl::Status OutputStreamObserver::Notify() {
absl::Status OutputStreamPollerImpl::Initialize( absl::Status OutputStreamPollerImpl::Initialize(
const std::string& stream_name, const PacketType* packet_type, const std::string& stream_name, const PacketType* packet_type,
std::function<void(InputStreamManager*, bool*)> queue_size_callback, std::function<void(InputStreamManager*, bool*)> queue_size_callback,
OutputStreamManager* output_stream_manager) { OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) {
MP_RETURN_IF_ERROR(GraphOutputStream::Initialize(stream_name, packet_type, MP_RETURN_IF_ERROR(GraphOutputStream::Initialize(stream_name, packet_type,
output_stream_manager)); output_stream_manager,
observe_timestamp_bounds));
input_stream_handler_->SetQueueSizeCallbacks(queue_size_callback, input_stream_handler_->SetQueueSizeCallbacks(queue_size_callback,
queue_size_callback); queue_size_callback);
return absl::OkStatus(); return absl::OkStatus();
@ -176,11 +177,17 @@ void OutputStreamPollerImpl::NotifyError() {
bool OutputStreamPollerImpl::Next(Packet* packet) { bool OutputStreamPollerImpl::Next(Packet* packet) {
CHECK(packet); CHECK(packet);
bool empty_queue = true; bool empty_queue = true;
bool timestamp_bound_changed = false;
Timestamp min_timestamp = Timestamp::Unset(); Timestamp min_timestamp = Timestamp::Unset();
mutex_.Lock(); mutex_.Lock();
while (true) { while (true) {
min_timestamp = input_stream_->MinTimestampOrBound(&empty_queue); min_timestamp = input_stream_->MinTimestampOrBound(&empty_queue);
if (graph_has_error_ || !empty_queue || if (empty_queue) {
timestamp_bound_changed =
input_stream_handler_->ProcessTimestampBounds() &&
output_timestamp_ < min_timestamp.PreviousAllowedInStream();
}
if (graph_has_error_ || !empty_queue || timestamp_bound_changed ||
min_timestamp == Timestamp::Done()) { min_timestamp == Timestamp::Done()) {
break; break;
} else { } else {
@ -191,17 +198,26 @@ bool OutputStreamPollerImpl::Next(Packet* packet) {
mutex_.Unlock(); mutex_.Unlock();
return false; return false;
} }
if (empty_queue) {
output_timestamp_ = min_timestamp.PreviousAllowedInStream();
} else {
output_timestamp_ = min_timestamp;
}
mutex_.Unlock(); mutex_.Unlock();
if (min_timestamp == Timestamp::Done()) { if (min_timestamp == Timestamp::Done()) {
return false; return false;
} }
int num_packets_dropped = 0; if (!empty_queue) {
bool stream_is_done = false; int num_packets_dropped = 0;
*packet = input_stream_->PopPacketAtTimestamp( bool stream_is_done = false;
min_timestamp, &num_packets_dropped, &stream_is_done); *packet = input_stream_->PopPacketAtTimestamp(
CHECK_EQ(num_packets_dropped, 0) min_timestamp, &num_packets_dropped, &stream_is_done);
<< absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", CHECK_EQ(num_packets_dropped, 0)
num_packets_dropped, input_stream_->Name()); << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".",
num_packets_dropped, input_stream_->Name());
} else if (timestamp_bound_changed) {
*packet = Packet().At(min_timestamp.PreviousAllowedInStream());
}
return true; return true;
} }

View File

@ -143,7 +143,8 @@ class OutputStreamPollerImpl : public GraphOutputStream {
absl::Status Initialize( absl::Status Initialize(
const std::string& stream_name, const PacketType* packet_type, const std::string& stream_name, const PacketType* packet_type,
std::function<void(InputStreamManager*, bool*)> queue_size_callback, std::function<void(InputStreamManager*, bool*)> queue_size_callback,
OutputStreamManager* output_stream_manager); OutputStreamManager* output_stream_manager,
bool observe_timestamp_bounds = false);
void PrepareForRun(std::function<void()> notification_callback, void PrepareForRun(std::function<void()> notification_callback,
std::function<void(absl::Status)> error_callback) override; std::function<void(absl::Status)> error_callback) override;
@ -170,6 +171,7 @@ class OutputStreamPollerImpl : public GraphOutputStream {
absl::Mutex mutex_; absl::Mutex mutex_;
absl::CondVar handler_condvar_ ABSL_GUARDED_BY(mutex_); absl::CondVar handler_condvar_ ABSL_GUARDED_BY(mutex_);
bool graph_has_error_ ABSL_GUARDED_BY(mutex_); bool graph_has_error_ ABSL_GUARDED_BY(mutex_);
Timestamp output_timestamp_ ABSL_GUARDED_BY(mutex_) = Timestamp::Min();
}; };
} // namespace internal } // namespace internal

View File

@ -51,4 +51,17 @@ message NightLightCalculatorOptions {
// Format string used by string::Substitute to construct the output. // Format string used by string::Substitute to construct the output.
optional string format_string = 9; optional string format_string = 9;
message LightBundle {
optional string room_id = 1;
repeated NightLightCalculatorOptions room_lights = 2;
}
repeated LightBundle bundle = 10;
// The number of night-lights.
repeated int32 num_lights = 11;
// Options for nested night-lights.
optional NightLightCalculatorOptions sub_options = 12;
} }

View File

@ -180,15 +180,66 @@ cc_library(
], ],
) )
mediapipe_proto_library(
name = "field_data_proto",
srcs = ["field_data.proto"],
visibility = ["//visibility:public"],
deps = ["@com_google_protobuf//:any_proto"],
)
cc_library(
name = "options_field_util",
srcs = ["options_field_util.cc"],
hdrs = ["options_field_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":field_data_cc_proto",
":name_util",
":options_registry",
":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "options_syntax_util",
srcs = ["options_syntax_util.cc"],
hdrs = ["options_syntax_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [
":name_util",
":options_field_util",
":options_registry",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_type",
"//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "options_util", name = "options_util",
srcs = ["options_util.cc"], srcs = ["options_util.cc"],
hdrs = ["options_util.h"], hdrs = ["options_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [ deps = [
":options_field_util",
":options_map", ":options_map",
":options_registry",
":options_syntax_util",
":proto_util_lite", ":proto_util_lite",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:collection", "//mediapipe/framework:collection",
"//mediapipe/framework:input_stream_shard", "//mediapipe/framework:input_stream_shard",
"//mediapipe/framework:output_side_packet", "//mediapipe/framework:output_side_packet",
@ -199,7 +250,7 @@ cc_library(
"//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:any_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/tool:type_util", "//mediapipe/framework/tool:name_util",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
@ -227,6 +278,8 @@ cc_library(
"//mediapipe/framework/deps:registration", "//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
], ],
) )
@ -246,11 +299,13 @@ mediapipe_cc_test(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework:validated_graph_config", "//mediapipe/framework:validated_graph_config",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/testdata:night_light_calculator_options_lib",
"//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/framework/tool:node_chain_subgraph_options_lib",
"//mediapipe/framework/tool:options_syntax_util",
"//mediapipe/util:header_util", "//mediapipe/util:header_util",
], ],
) )

View File

@ -0,0 +1,47 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from mediapipe/framework/tool/source.proto.
// The forked proto must remain identical to the original proto and should be
// ONLY used by mediapipe open source project.
syntax = "proto2";
package mediapipe;
// `MessageData`, like protobuf.Any, contains an arbitrary serialized protbuf
// along with a URL that describes the type of the serialized message.
message MessageData {
// A URL/resource name that identifies the type of serialized protbuf.
optional string type_url = 1;
// Must be a valid serialized protocol buffer of the above specified type.
optional bytes value = 2;
}
// Data for one Protobuf field or one MediaPipe packet.
message FieldData {
oneof value {
sint32 int32_value = 1;
sint64 int64_value = 2;
uint32 uint32_value = 3;
uint64 uint64_value = 4;
double double_value = 5;
float float_value = 6;
bool bool_value = 7;
sint32 enum_value = 8;
string string_value = 9;
MessageData message_value = 10;
}
}

View File

@ -0,0 +1,495 @@
#include "mediapipe/framework/tool/options_field_util.h"
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/proto_util_lite.h"
namespace mediapipe {
namespace tool {
namespace options_field_util {
using ::mediapipe::proto_ns::internal::WireFormatLite;
using FieldType = WireFormatLite::FieldType;
using ::mediapipe::proto_ns::io::ArrayInputStream;
using ::mediapipe::proto_ns::io::CodedInputStream;
using ::mediapipe::proto_ns::io::CodedOutputStream;
using ::mediapipe::proto_ns::io::StringOutputStream;
// Utility functions for OptionsFieldUtil.
namespace {
// Converts a FieldDescriptor::Type to the corresponding FieldType.
FieldType AsFieldType(proto_ns::FieldDescriptorProto::Type type) {
return static_cast<FieldType>(type);
}
absl::Status WriteValue(const FieldData& value, FieldType field_type,
std::string* field_bytes) {
StringOutputStream sos(field_bytes);
CodedOutputStream out(&sos);
switch (field_type) {
case WireFormatLite::TYPE_INT32:
WireFormatLite::WriteInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_SINT32:
WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_INT64:
WireFormatLite::WriteInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_SINT64:
WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_UINT32:
WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out);
break;
case WireFormatLite::TYPE_UINT64:
WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_DOUBLE:
WireFormatLite::WriteDoubleNoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_FLOAT:
WireFormatLite::WriteFloatNoTag(value.float_value(), &out);
break;
case WireFormatLite::TYPE_BOOL:
WireFormatLite::WriteBoolNoTag(value.bool_value(), &out);
break;
case WireFormatLite::TYPE_ENUM:
WireFormatLite::WriteEnumNoTag(value.enum_value(), &out);
break;
case WireFormatLite::TYPE_STRING:
out.WriteString(value.string_value());
break;
case WireFormatLite::TYPE_MESSAGE:
out.WriteString(value.message_value().value());
break;
default:
return absl::UnimplementedError(
absl::StrCat("Cannot write type: ", field_type));
}
return mediapipe::OkStatus();
}
// Serializes a packet value.
absl::Status WriteField(const FieldData& packet, const FieldDescriptor* field,
std::string* result) {
FieldType field_type = AsFieldType(field->type());
return WriteValue(packet, field_type, result);
}
template <typename ValueT, FieldType kFieldType>
static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) {
ArrayInputStream ais(field_bytes.data(), field_bytes.size());
CodedInputStream input(&ais);
ValueT result;
if (!WireFormatLite::ReadPrimitive<ValueT, kFieldType>(&input, &result)) {
status->Update(mediapipe::InvalidArgumentError(absl::StrCat(
"Bad serialized value: ", MediaPipeTypeStringOrDemangled<ValueT>(),
".")));
}
return result;
}
absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type,
absl::string_view message_type, FieldData* result) {
absl::Status status;
result->Clear();
switch (field_type) {
case WireFormatLite::TYPE_INT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_INT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_SINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_INT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_INT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_SINT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT32:
result->set_uint32_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT64:
result->set_uint64_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_DOUBLE:
result->set_double_value(
ReadValue<double, WireFormatLite::TYPE_DOUBLE>(field_bytes, &status));
break;
case WireFormatLite::TYPE_FLOAT:
result->set_float_value(
ReadValue<float, WireFormatLite::TYPE_FLOAT>(field_bytes, &status));
break;
case WireFormatLite::TYPE_BOOL:
result->set_bool_value(
ReadValue<bool, WireFormatLite::TYPE_BOOL>(field_bytes, &status));
break;
case WireFormatLite::TYPE_ENUM:
result->set_enum_value(
ReadValue<int32, WireFormatLite::TYPE_ENUM>(field_bytes, &status));
break;
case WireFormatLite::TYPE_STRING:
result->set_string_value(std::string(field_bytes));
break;
case WireFormatLite::TYPE_MESSAGE:
result->mutable_message_value()->set_value(std::string(field_bytes));
result->mutable_message_value()->set_type_url(TypeUrl(message_type));
break;
default:
status = absl::UnimplementedError(
absl::StrCat("Cannot read type: ", field_type));
break;
}
return status;
}
// Deserializes a packet from a protobuf field.
absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field,
FieldData* result) {
FieldType field_type = AsFieldType(field->type());
std::string message_type = (field_type == WireFormatLite::TYPE_MESSAGE)
? field->message_type()->full_name()
: "";
return ReadValue(bytes, field_type, message_type, result);
}
// Converts a chain of fields and indexes into field-numbers and indexes.
ProtoUtilLite::ProtoPath AsProtoPath(const FieldPath& field_path) {
ProtoUtilLite::ProtoPath result;
for (auto field : field_path) {
result.push_back({field.first->number(), field.second});
}
return result;
}
// Returns the options protobuf for a subgraph.
// TODO: Ensure that this works with multiple options protobufs.
absl::Status GetOptionsMessage(
const proto_ns::RepeatedPtrField<mediapipe::protobuf::Any>& options_any,
const proto_ns::MessageLite& options_ext, FieldData* result) {
// Read the "graph_options" or "node_options" field.
for (const auto& options : options_any) {
if (options.type_url().empty()) {
continue;
}
result->mutable_message_value()->set_type_url(options.type_url());
result->mutable_message_value()->set_value(std::string(options.value()));
return mediapipe::OkStatus();
}
// Read the "options" field.
FieldData message_data;
*message_data.mutable_message_value()->mutable_value() =
options_ext.SerializeAsString();
message_data.mutable_message_value()->set_type_url(options_ext.GetTypeName());
std::vector<const FieldDescriptor*> ext_fields;
OptionsRegistry::FindAllExtensions(options_ext.GetTypeName(), &ext_fields);
for (auto ext_field : ext_fields) {
absl::Status status = GetField({{ext_field, 0}}, message_data, result);
if (!status.ok()) {
return status;
}
if (result->has_message_value()) {
return status;
}
}
return mediapipe::OkStatus();
}
// Sets a protobuf in a repeated protobuf::Any field.
void SetOptionsMessage(
const FieldData& node_options,
proto_ns::RepeatedPtrField<mediapipe::protobuf::Any>* result) {
protobuf::Any* options_any = nullptr;
for (auto& any : *result) {
if (any.type_url() == node_options.message_value().type_url()) {
options_any = &any;
}
}
if (!options_any) {
options_any = result->Add();
options_any->set_type_url(node_options.message_value().type_url());
}
*options_any->mutable_value() = node_options.message_value().value();
}
} // anonymous namespace
// Deserializes a packet containing a MessageLite value.
absl::Status ReadMessage(const std::string& value, const std::string& type_name,
Packet* result) {
auto packet = packet_internal::PacketFromDynamicProto(type_name, value);
if (packet.ok()) {
*result = *packet;
}
return packet.status();
}
// Merge two options FieldData values.
absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over,
FieldData* result) {
absl::Status status;
if (over.value_case() == FieldData::VALUE_NOT_SET) {
*result = base;
return status;
}
if (base.value_case() == FieldData::VALUE_NOT_SET) {
*result = over;
return status;
}
if (over.value_case() != base.value_case()) {
return absl::InvalidArgumentError(absl::StrCat(
"Cannot merge field data with data types: ", base.value_case(), ", ",
over.value_case()));
}
if (over.message_value().type_url() != base.message_value().type_url()) {
return absl::InvalidArgumentError(
absl::StrCat("Cannot merge field data with message types: ",
base.message_value().type_url(), ", ",
over.message_value().type_url()));
}
absl::Cord merged_value;
merged_value.Append(base.message_value().value());
merged_value.Append(over.message_value().value());
result->mutable_message_value()->set_type_url(
base.message_value().type_url());
result->mutable_message_value()->set_value(std::string(merged_value));
return status;
}
// Writes a FieldData value into protobuf field.
absl::Status SetField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data) {
if (field_path.empty()) {
*message_data->mutable_message_value() = value.message_value();
return mediapipe::OkStatus();
}
ProtoUtilLite proto_util;
const FieldDescriptor* field = field_path.back().first;
FieldType field_type = AsFieldType(field->type());
std::string field_value;
MP_RETURN_IF_ERROR(WriteField(value, field, &field_value));
ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path);
std::string* message_bytes =
message_data->mutable_message_value()->mutable_value();
int field_count;
MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path,
field_type, &field_count));
MP_RETURN_IF_ERROR(
proto_util.ReplaceFieldRange(message_bytes, AsProtoPath(field_path),
field_count, field_type, {field_value}));
return mediapipe::OkStatus();
}
// Merges a packet value into nested protobuf Message.
absl::Status MergeField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data) {
absl::Status status;
FieldType field_type = field_path.empty()
? FieldType::TYPE_MESSAGE
: AsFieldType(field_path.back().first->type());
std::string message_type =
(value.has_message_value())
? ParseTypeUrl(std::string(value.message_value().type_url()))
: "";
FieldData v = value;
if (field_type == FieldType::TYPE_MESSAGE) {
FieldData b;
status.Update(GetField(field_path, *message_data, &b));
status.Update(MergeOptionsMessages(b, v, &v));
}
status.Update(SetField(field_path, v, message_data));
return status;
}
// Reads a packet value from a protobuf field.
absl::Status GetField(const FieldPath& field_path,
const FieldData& message_data, FieldData* result) {
if (field_path.empty()) {
*result->mutable_message_value() = message_data.message_value();
return mediapipe::OkStatus();
}
ProtoUtilLite proto_util;
const FieldDescriptor* field = field_path.back().first;
FieldType field_type = AsFieldType(field->type());
std::vector<std::string> field_values;
ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path);
const std::string& message_bytes = message_data.message_value().value();
int field_count;
MP_RETURN_IF_ERROR(proto_util.GetFieldCount(message_bytes, proto_path,
field_type, &field_count));
if (field_count == 0) {
return mediapipe::OkStatus();
}
MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1,
field_type, &field_values));
MP_RETURN_IF_ERROR(ReadField(field_values.front(), field, result));
return mediapipe::OkStatus();
}
// Returns the options protobuf for a graph.
absl::Status GetOptionsMessage(const CalculatorGraphConfig& config,
FieldData* result) {
return GetOptionsMessage(config.graph_options(), config.options(), result);
}
// Returns the options protobuf for a node.
absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node,
FieldData* result) {
return GetOptionsMessage(node.node_options(), node.options(), result);
}
// Sets the node_options field in a Node, and clears the options field.
void SetOptionsMessage(const FieldData& node_options,
CalculatorGraphConfig::Node* node) {
SetOptionsMessage(node_options, node->mutable_node_options());
node->clear_options();
}
// Represents a protobuf enum value stored in a Packet.
struct ProtoEnum {
ProtoEnum(int32 v) : value(v) {}
int32 value;
};
absl::Status AsPacket(const FieldData& data, Packet* result) {
switch (data.value_case()) {
case FieldData::ValueCase::kInt32Value:
*result = MakePacket<int32>(data.int32_value());
break;
case FieldData::ValueCase::kInt64Value:
*result = MakePacket<int64>(data.int64_value());
break;
case FieldData::ValueCase::kUint32Value:
*result = MakePacket<uint32>(data.uint32_value());
break;
case FieldData::ValueCase::kUint64Value:
*result = MakePacket<uint64>(data.uint64_value());
break;
case FieldData::ValueCase::kDoubleValue:
*result = MakePacket<double>(data.double_value());
break;
case FieldData::ValueCase::kFloatValue:
*result = MakePacket<float>(data.float_value());
break;
case FieldData::ValueCase::kBoolValue:
*result = MakePacket<bool>(data.bool_value());
break;
case FieldData::ValueCase::kEnumValue:
*result = MakePacket<ProtoEnum>(data.enum_value());
break;
case FieldData::ValueCase::kStringValue:
*result = MakePacket<std::string>(data.string_value());
break;
case FieldData::ValueCase::kMessageValue: {
auto r = packet_internal::PacketFromDynamicProto(
ParseTypeUrl(std::string(data.message_value().type_url())),
std::string(data.message_value().value()));
if (!r.ok()) {
return r.status();
}
*result = r.value();
break;
}
case FieldData::VALUE_NOT_SET:
*result = Packet();
}
return mediapipe::OkStatus();
}
absl::Status AsFieldData(Packet packet, FieldData* result) {
static const auto* kTypeIds = new std::map<size_t, int32>{
{tool::GetTypeHash<int32>(), WireFormatLite::CPPTYPE_INT32},
{tool::GetTypeHash<int64>(), WireFormatLite::CPPTYPE_INT64},
{tool::GetTypeHash<uint32>(), WireFormatLite::CPPTYPE_UINT32},
{tool::GetTypeHash<uint64>(), WireFormatLite::CPPTYPE_UINT64},
{tool::GetTypeHash<double>(), WireFormatLite::CPPTYPE_DOUBLE},
{tool::GetTypeHash<float>(), WireFormatLite::CPPTYPE_FLOAT},
{tool::GetTypeHash<bool>(), WireFormatLite::CPPTYPE_BOOL},
{tool::GetTypeHash<ProtoEnum>(), WireFormatLite::CPPTYPE_ENUM},
{tool::GetTypeHash<std::string>(), WireFormatLite::CPPTYPE_STRING},
};
if (packet.ValidateAsProtoMessageLite().ok()) {
result->mutable_message_value()->set_value(
packet.GetProtoMessageLite().SerializeAsString());
result->mutable_message_value()->set_type_url(
TypeUrl(packet.GetProtoMessageLite().GetTypeName()));
return mediapipe::OkStatus();
}
if (kTypeIds->count(packet.GetTypeId()) == 0) {
return absl::UnimplementedError(absl::StrCat(
"Cannot construct FieldData for: ", packet.DebugTypeName()));
}
switch (kTypeIds->at(packet.GetTypeId())) {
case WireFormatLite::CPPTYPE_INT32:
result->set_int32_value(packet.Get<int32>());
break;
case WireFormatLite::CPPTYPE_INT64:
result->set_int64_value(packet.Get<int64>());
break;
case WireFormatLite::CPPTYPE_UINT32:
result->set_uint32_value(packet.Get<uint32>());
break;
case WireFormatLite::CPPTYPE_UINT64:
result->set_uint64_value(packet.Get<uint64>());
break;
case WireFormatLite::CPPTYPE_DOUBLE:
result->set_double_value(packet.Get<double>());
break;
case WireFormatLite::CPPTYPE_FLOAT:
result->set_float_value(packet.Get<float>());
break;
case WireFormatLite::CPPTYPE_BOOL:
result->set_bool_value(packet.Get<bool>());
break;
case WireFormatLite::CPPTYPE_ENUM:
result->set_enum_value(packet.Get<ProtoEnum>().value);
break;
case WireFormatLite::CPPTYPE_STRING:
result->set_string_value(packet.Get<std::string>());
break;
}
return mediapipe::OkStatus();
}
std::string TypeUrl(absl::string_view type_name) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name));
}
std::string ParseTypeUrl(absl::string_view type_url) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
if (std::string(type_url).rfind(kTypeUrlPrefix, 0) == 0) {
return std::string(
type_url.substr(kTypeUrlPrefix.length(), std::string::npos));
}
return std::string(type_url);
}
} // namespace options_field_util
} // namespace tool
} // namespace mediapipe

View File

@ -0,0 +1,73 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_
#include <string>
#include <vector>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/tool/field_data.pb.h"
#include "mediapipe/framework/tool/options_registry.h"
namespace mediapipe {
namespace tool {
// Utility to read and write Packet data from protobuf fields.
namespace options_field_util {
// A chain of nested fields and indexes.
using FieldPath = std::vector<std::pair<const FieldDescriptor*, int>>;
// Writes a field value into protobuf field.
absl::Status SetField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data);
// Reads a field value from a protobuf field.
absl::Status GetField(const FieldPath& field_path,
const FieldData& message_data, FieldData* result);
// Merges a field value into nested protobuf Message.
absl::Status MergeField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data);
// Deserializes a packet containing a MessageLite value.
absl::Status ReadMessage(const std::string& value, const std::string& type_name,
Packet* result);
// Merge two options protobuf field values.
absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over,
FieldData* result);
// Returns the options protobuf for a graph.
absl::Status GetOptionsMessage(const CalculatorGraphConfig& config,
FieldData* result);
// Returns the options protobuf for a node.
absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node,
FieldData* result);
// Sets the node_options field in a Node, and clears the options field.
void SetOptionsMessage(const FieldData& node_options,
CalculatorGraphConfig::Node* node);
// Constructs a Packet for a FieldData proto.
absl::Status AsPacket(const FieldData& data, Packet* result);
// Constructs a FieldData proto for a Packet.
absl::Status AsFieldData(Packet packet, FieldData* result);
// Returns the protobuf type-url for a protobuf type-name.
std::string TypeUrl(absl::string_view type_name);
// Returns the protobuf type-name for a protobuf type-url.
std::string ParseTypeUrl(absl::string_view type_url);
} // namespace options_field_util
} // namespace tool
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_

View File

@ -1,47 +1,112 @@
#include "mediapipe/framework/tool/options_registry.h" #include "mediapipe/framework/tool/options_registry.h"
#include "absl/synchronization/mutex.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
proto_ns::DescriptorPool* OptionsRegistry::options_descriptor_pool() { namespace {
static proto_ns::DescriptorPool* result = new proto_ns::DescriptorPool();
return result; // Returns a canonical message type name, with any leading "." removed.
std::string CanonicalTypeName(const std::string& type_name) {
return (type_name.rfind('.', 0) == 0) ? type_name.substr(1) : type_name;
} }
} // namespace
RegistrationToken OptionsRegistry::Register( RegistrationToken OptionsRegistry::Register(
const proto_ns::FileDescriptorSet& files) { const proto_ns::FileDescriptorSet& files) {
absl::MutexLock lock(&mutex());
for (auto& file : files.file()) { for (auto& file : files.file()) {
options_descriptor_pool()->BuildFile(file); for (auto& message_type : file.message_type()) {
Register(message_type, file.package());
}
} }
return RegistrationToken([]() {}); return RegistrationToken([]() {});
} }
const proto_ns::Descriptor* OptionsRegistry::GetProtobufDescriptor( void OptionsRegistry::Register(const proto_ns::DescriptorProto& message_type,
const std::string& type_name) { const std::string& parent_name) {
const proto_ns::Descriptor* result = auto full_name = absl::StrCat(parent_name, ".", message_type.name());
proto_ns::DescriptorPool::generated_pool()->FindMessageTypeByName( descriptors()[full_name] = Descriptor(message_type, full_name);
type_name); for (auto& nested : message_type.nested_type()) {
if (!result) { Register(nested, full_name);
result = options_descriptor_pool()->FindMessageTypeByName(type_name);
} }
return result; for (auto& extension : message_type.extension()) {
extensions()[CanonicalTypeName(extension.extendee())].push_back(
FieldDescriptor(extension));
}
}
const Descriptor* OptionsRegistry::GetProtobufDescriptor(
const std::string& type_name) {
absl::ReaderMutexLock lock(&mutex());
auto it = descriptors().find(CanonicalTypeName(type_name));
return (it == descriptors().end()) ? nullptr : &it->second;
} }
void OptionsRegistry::FindAllExtensions( void OptionsRegistry::FindAllExtensions(
const proto_ns::Descriptor& extendee, absl::string_view extendee, std::vector<const FieldDescriptor*>* result) {
std::vector<const proto_ns::FieldDescriptor*>* result) { absl::ReaderMutexLock lock(&mutex());
using proto_ns::DescriptorPool; result->clear();
std::vector<const proto_ns::FieldDescriptor*> extensions; if (extensions().count(extendee) > 0) {
DescriptorPool::generated_pool()->FindAllExtensions(&extendee, &extensions); for (const FieldDescriptor& field : extensions().at(extendee)) {
options_descriptor_pool()->FindAllExtensions(&extendee, &extensions); result->push_back(&field);
absl::flat_hash_set<int> numbers;
for (const proto_ns::FieldDescriptor* extension : extensions) {
bool inserted = numbers.insert(extension->number()).second;
if (inserted) {
result->push_back(extension);
} }
} }
} }
absl::flat_hash_map<std::string, Descriptor>& OptionsRegistry::descriptors() {
static auto* descriptors = new absl::flat_hash_map<std::string, Descriptor>();
return *descriptors;
}
absl::flat_hash_map<std::string, std::vector<FieldDescriptor>>&
OptionsRegistry::extensions() {
static auto* extensions =
new absl::flat_hash_map<std::string, std::vector<FieldDescriptor>>();
return *extensions;
}
absl::Mutex& OptionsRegistry::mutex() {
static auto* mutex = new absl::Mutex();
return *mutex;
}
Descriptor::Descriptor(const proto_ns::DescriptorProto& proto,
const std::string& full_name)
: full_name_(full_name) {
for (auto& field : proto.field()) {
fields_[field.name()] = FieldDescriptor(field);
}
}
const std::string& Descriptor::full_name() const { return full_name_; }
const FieldDescriptor* Descriptor::FindFieldByName(
const std::string& name) const {
auto it = fields_.find(name);
return (it != fields_.end()) ? &it->second : nullptr;
}
FieldDescriptor::FieldDescriptor(const proto_ns::FieldDescriptorProto& proto) {
name_ = proto.name();
message_type_ = CanonicalTypeName(proto.type_name());
type_ = proto.type();
number_ = proto.number();
}
const std::string& FieldDescriptor::name() const { return name_; }
int FieldDescriptor::number() const { return number_; }
proto_ns::FieldDescriptorProto::Type FieldDescriptor::type() const {
return type_;
}
const Descriptor* FieldDescriptor::message_type() const {
return OptionsRegistry::GetProtobufDescriptor(message_type_);
}
} // namespace tool } // namespace tool
} // namespace mediapipe } // namespace mediapipe

View File

@ -1,12 +1,16 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ #ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
#include "absl/container/flat_hash_map.h"
#include "mediapipe/framework/deps/registration.h" #include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/advanced_proto_inc.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
class Descriptor;
class FieldDescriptor;
// A static registry that stores descriptors for protobufs used in MediaPipe // A static registry that stores descriptors for protobufs used in MediaPipe
// calculator options. Lite-proto builds do not normally include descriptors. // calculator options. Lite-proto builds do not normally include descriptors.
// These registered descriptors allow individual protobuf fields to be // These registered descriptors allow individual protobuf fields to be
@ -17,23 +21,60 @@ class OptionsRegistry {
static RegistrationToken Register(const proto_ns::FileDescriptorSet& files); static RegistrationToken Register(const proto_ns::FileDescriptorSet& files);
// Finds the descriptor for a protobuf. // Finds the descriptor for a protobuf.
static const proto_ns::Descriptor* GetProtobufDescriptor( static const Descriptor* GetProtobufDescriptor(const std::string& type_name);
const std::string& type_name);
// Returns all known proto2 extensions to a type. // Returns all known proto2 extensions to a type.
static void FindAllExtensions( static void FindAllExtensions(absl::string_view extendee,
const proto_ns::Descriptor& extendee, std::vector<const FieldDescriptor*>* result);
std::vector<const proto_ns::FieldDescriptor*>* result);
private: private:
// Stores the descriptors for each options protobuf type. // Registers protobuf descriptors a MessageLite and nested types.
static proto_ns::DescriptorPool* options_descriptor_pool(); static void Register(const proto_ns::DescriptorProto& message_type,
const std::string& parent_name);
static absl::flat_hash_map<std::string, Descriptor>& descriptors();
static absl::flat_hash_map<std::string, std::vector<FieldDescriptor>>&
extensions();
static absl::Mutex& mutex();
// Registers the descriptors for each options protobuf type. // Registers the descriptors for each options protobuf type.
template <class MessageT> template <class MessageT>
static const RegistrationToken registration_token; static const RegistrationToken registration_token;
}; };
// A custom implementation proto_ns::Descriptor. This implementation
// avoids a code size problem introduced by proto_ns::FieldDescriptor.
class Descriptor {
public:
Descriptor() {}
Descriptor(const proto_ns::DescriptorProto& proto,
const std::string& full_name);
const std::string& full_name() const;
const FieldDescriptor* FindFieldByName(const std::string& name) const;
private:
std::string full_name_;
absl::flat_hash_map<std::string, FieldDescriptor> fields_;
};
// A custom implementation proto_ns::FieldDescriptor. This implementation
// avoids a code size problem introduced by proto_ns::FieldDescriptor.
class FieldDescriptor {
public:
FieldDescriptor() {}
FieldDescriptor(const proto_ns::FieldDescriptorProto& proto);
const std::string& name() const;
int number() const;
proto_ns::FieldDescriptorProto::Type type() const;
const Descriptor* message_type() const;
private:
std::string name_;
std::string message_type_;
proto_ns::FieldDescriptorProto::Type type_;
int number_;
};
} // namespace tool } // namespace tool
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,143 @@
#include "mediapipe/framework/tool/options_syntax_util.h"
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/name_util.h"
namespace mediapipe {
namespace tool {
// Helper functions for parsing the graph options syntax.
class OptionsSyntaxUtil::OptionsSyntaxHelper {
public:
// The usual graph options syntax tokens.
OptionsSyntaxHelper() : syntax_{"OPTIONS", "options", "/"} {}
// Returns the tag name for an option protobuf field.
std::string OptionFieldTag(const std::string& name) { return name; }
// Returns the packet name for an option protobuf field.
absl::string_view OptionFieldPacket(absl::string_view name) { return name; }
// Returns the option protobuf field name for a tag or packet name.
absl::string_view OptionFieldName(absl::string_view name) { return name; }
// Returns the field-path for an option stream-tag.
FieldPath OptionFieldPath(const std::string& tag,
const Descriptor* descriptor) {
int prefix = syntax_.tag_name.length() + syntax_.separator.length();
std::string suffix = tag.substr(prefix);
std::vector<absl::string_view> name_tags =
absl::StrSplit(suffix, syntax_.separator);
FieldPath result;
for (absl::string_view name_tag : name_tags) {
if (name_tag.empty()) {
continue;
}
absl::string_view option_name = OptionFieldName(name_tag);
int index;
if (absl::SimpleAtoi(option_name, &index)) {
result.back().second = index;
} else {
auto field = descriptor->FindFieldByName(std::string(option_name));
descriptor = field ? field->message_type() : nullptr;
result.push_back({std::move(field), 0});
}
}
return result;
}
// Returns the option field name for a graph options packet name.
std::string GraphOptionFieldName(const std::string& graph_option_name) {
int prefix = syntax_.packet_name.length() + syntax_.separator.length();
std::string result = graph_option_name;
result.erase(0, prefix);
return result;
}
// Returns the graph options packet name for an option field name.
std::string GraphOptionName(const std::string& option_name) {
std::string packet_prefix =
syntax_.packet_name + absl::AsciiStrToLower(syntax_.separator);
return absl::StrCat(packet_prefix, option_name);
}
// Returns the tag name for a graph option.
std::string OptionTagName(const std::string& option_name) {
return absl::StrCat(syntax_.tag_name, syntax_.separator,
OptionFieldTag(option_name));
}
// Converts slash-separated field names into a tag name.
std::string OptionFieldsTag(const std::string& option_names) {
std::string tag_prefix = syntax_.tag_name + syntax_.separator;
std::vector<absl::string_view> names = absl::StrSplit(option_names, '/');
if (!names.empty() && names[0] == syntax_.tag_name) {
names.erase(names.begin());
}
if (!names.empty() && names[0] == syntax_.packet_name) {
names.erase(names.begin());
}
std::string result;
std::string sep = "";
for (absl::string_view v : names) {
absl::StrAppend(&result, sep, OptionFieldTag(std::string(v)));
sep = syntax_.separator;
}
result = tag_prefix + result;
return result;
}
// Token definitions for the graph options syntax.
struct OptionsSyntax {
// The tag name for an options protobuf.
std::string tag_name;
// The packet name for an options protobuf.
std::string packet_name;
// The separator between nested options fields.
std::string separator;
};
OptionsSyntax syntax_;
}; // class OptionsSyntaxHelper
OptionsSyntaxUtil::OptionsSyntaxUtil()
: syntax_helper_(std::make_unique<OptionsSyntaxHelper>()) {}
OptionsSyntaxUtil::OptionsSyntaxUtil(const std::string& tag_name)
: OptionsSyntaxUtil() {
syntax_helper_->syntax_.tag_name = tag_name;
}
OptionsSyntaxUtil::OptionsSyntaxUtil(const std::string& tag_name,
const std::string& packet_name,
const std::string& separator)
: OptionsSyntaxUtil() {
syntax_helper_->syntax_.tag_name = tag_name;
syntax_helper_->syntax_.packet_name = packet_name;
syntax_helper_->syntax_.separator = separator;
}
OptionsSyntaxUtil::~OptionsSyntaxUtil() {}
std::string OptionsSyntaxUtil::OptionFieldsTag(
const std::string& option_names) {
return syntax_helper_->OptionFieldsTag(option_names);
}
OptionsSyntaxUtil::FieldPath OptionsSyntaxUtil::OptionFieldPath(
const std::string& tag, const Descriptor* descriptor) {
return syntax_helper_->OptionFieldPath(tag, descriptor);
}
} // namespace tool
} // namespace mediapipe

View File

@ -0,0 +1,45 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_
#include <memory>
#include <string>
#include <vector>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/tool/options_field_util.h"
#include "mediapipe/framework/tool/options_registry.h"
namespace mediapipe {
namespace tool {
// Utility to parse the graph options syntax used in "option_value",
// "side_packet", and "stream".
class OptionsSyntaxUtil {
public:
using FieldPath = options_field_util::FieldPath;
OptionsSyntaxUtil();
OptionsSyntaxUtil(const std::string& tag_name);
OptionsSyntaxUtil(const std::string& tag_name, const std::string& packet_name,
const std::string& separator);
~OptionsSyntaxUtil();
// Converts slash-separated field names into a tag name.
std::string OptionFieldsTag(const std::string& option_names);
// Returns the field-path for an option stream-tag.
FieldPath OptionFieldPath(const std::string& tag,
const Descriptor* descriptor);
private:
class OptionsSyntaxHelper;
std::unique_ptr<OptionsSyntaxHelper> syntax_helper_;
};
} // namespace tool
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_

View File

@ -1,16 +1,82 @@
#include "mediapipe/framework/tool/options_util.h" #include "mediapipe/framework/tool/options_util.h"
#include "mediapipe/framework/port/proto_ns.h" #include <memory>
#include <string>
#include <variant>
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/input_stream_shard.h"
#include "mediapipe/framework/output_side_packet.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/options_field_util.h"
#include "mediapipe/framework/tool/options_registry.h"
#include "mediapipe/framework/tool/options_syntax_util.h"
#include "mediapipe/framework/tool/proto_util_lite.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
// TODO: Return registered protobuf Descriptors when available. // Copy literal options from graph_options to node_options.
const proto_ns::Descriptor* GetProtobufDescriptor( absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
const std::string& type_name) { CalculatorGraphConfig* config) {
return proto_ns::DescriptorPool::generated_pool()->FindMessageTypeByName( Status status;
type_name); FieldData config_options, parent_node_options, graph_options;
status.Update(
options_field_util::GetOptionsMessage(*config, &config_options));
status.Update(
options_field_util::GetOptionsMessage(parent_node, &parent_node_options));
status.Update(options_field_util::MergeOptionsMessages(
config_options, parent_node_options, &graph_options));
const Descriptor* options_descriptor =
OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl(
std::string(graph_options.message_value().type_url())));
if (!options_descriptor) {
return status;
}
OptionsSyntaxUtil syntax_util;
for (auto& node : *config->mutable_node()) {
FieldData node_data;
status.Update(options_field_util::GetOptionsMessage(node, &node_data));
if (!node_data.has_message_value() || node.option_value_size() == 0) {
continue;
}
const Descriptor* node_options_descriptor =
OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl(
std::string(node_data.message_value().type_url())));
if (!node_options_descriptor) {
continue;
}
for (const std::string& option_def : node.option_value()) {
std::vector<std::string> tag_and_name = absl::StrSplit(option_def, ':');
std::string graph_tag = syntax_util.OptionFieldsTag(tag_and_name[1]);
std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]);
FieldData packet_data;
status.Update(options_field_util::GetField(
syntax_util.OptionFieldPath(graph_tag, options_descriptor),
graph_options, &packet_data));
status.Update(options_field_util::MergeField(
syntax_util.OptionFieldPath(node_tag, node_options_descriptor),
packet_data, &node_data));
}
options_field_util::SetOptionsMessage(node_data, &node);
}
return status;
}
// Makes all configuration modifications needed for graph options.
absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node,
CalculatorGraphConfig* config) {
MP_RETURN_IF_ERROR(CopyLiteralOptions(parent_node, config));
return mediapipe::OkStatus();
} }
} // namespace tool } // namespace tool

View File

@ -21,7 +21,6 @@
#include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/port/any_proto.h" #include "mediapipe/framework/port/any_proto.h"
#include "mediapipe/framework/tool/options_map.h" #include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/framework/tool/type_util.h"
namespace mediapipe { namespace mediapipe {
@ -75,8 +74,9 @@ inline T RetrieveOptions(const T& base, const InputStreamShardSet& stream_set,
return base; return base;
} }
// Finds the descriptor for a protobuf. // Copy literal options from enclosing graphs.
const proto_ns::Descriptor* GetProtobufDescriptor(const std::string& type_name); absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node,
CalculatorGraphConfig* config);
} // namespace tool } // namespace tool
} // namespace mediapipe } // namespace mediapipe

View File

@ -16,15 +16,41 @@
#include <vector> #include <vector>
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/message_matchers.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/testdata/night_light_calculator.pb.h" #include "mediapipe/framework/testdata/night_light_calculator.pb.h"
#include "mediapipe/framework/tool/node_chain_subgraph.pb.h"
#include "mediapipe/framework/tool/options_registry.h" #include "mediapipe/framework/tool/options_registry.h"
#include "mediapipe/framework/tool/options_syntax_util.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
using ::mediapipe::proto_ns::FieldDescriptorProto;
using FieldType = ::mediapipe::proto_ns::FieldDescriptorProto::Type;
// A test Calculator using DeclareOptions and DefineOptions.
class NightLightCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
return mediapipe::OkStatus();
}
absl::Status Open(CalculatorContext* cc) final {
return mediapipe::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final {
return mediapipe::OkStatus();
}
private:
NightLightCalculatorOptions options_;
};
REGISTER_CALCULATOR(NightLightCalculator);
// Tests for calculator and graph options. // Tests for calculator and graph options.
// //
class OptionsUtilTest : public ::testing::Test { class OptionsUtilTest : public ::testing::Test {
@ -35,21 +61,108 @@ class OptionsUtilTest : public ::testing::Test {
// Retrieves the description of a protobuf. // Retrieves the description of a protobuf.
TEST_F(OptionsUtilTest, GetProtobufDescriptor) { TEST_F(OptionsUtilTest, GetProtobufDescriptor) {
const proto_ns::Descriptor* descriptor = const tool::Descriptor* descriptor =
tool::GetProtobufDescriptor("mediapipe.CalculatorGraphConfig"); tool::OptionsRegistry::GetProtobufDescriptor(
#ifndef MEDIAPIPE_MOBILE "mediapipe.CalculatorGraphConfig");
EXPECT_NE(nullptr, descriptor); EXPECT_NE(nullptr, descriptor);
#else
EXPECT_EQ(nullptr, descriptor);
#endif
} }
// Retrieves the description of a protobuf from the OptionsRegistry. // Shows a calculator node deriving options from graph options.
// The subgraph specifies "graph_options" as "NodeChainSubgraphOptions".
// The calculator specifies "node_options as "NightLightCalculatorOptions".
TEST_F(OptionsUtilTest, CopyLiteralOptions) {
CalculatorGraphConfig subgraph_config;
auto node = subgraph_config.add_node();
*node->mutable_calculator() = "NightLightCalculator";
*node->add_option_value() = "num_lights:options/chain_length";
// The options framework requires at least an empty options protobuf
// as an indication the options protobuf type expected by the node.
NightLightCalculatorOptions node_options;
node->add_node_options()->PackFrom(node_options);
NodeChainSubgraphOptions options;
options.set_chain_length(8);
subgraph_config.add_graph_options()->PackFrom(options);
subgraph_config.set_type("NightSubgraph");
CalculatorGraphConfig graph_config;
node = graph_config.add_node();
*node->mutable_calculator() = "NightSubgraph";
CalculatorGraph graph;
graph_config.set_num_threads(4);
MP_EXPECT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {}));
CalculatorGraphConfig expanded_config = graph.Config();
expanded_config.clear_executor();
CalculatorGraphConfig::Node actual_node;
actual_node = expanded_config.node(0);
CalculatorGraphConfig::Node expected_node;
expected_node.set_name("nightsubgraph__NightLightCalculator");
expected_node.set_calculator("NightLightCalculator");
NightLightCalculatorOptions expected_node_options;
expected_node_options.add_num_lights(8);
expected_node.add_node_options()->PackFrom(expected_node_options);
*expected_node.add_option_value() = "num_lights:options/chain_length";
EXPECT_THAT(actual_node, EqualsProto(expected_node));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
// Ensure static protobuf packet registration.
MakePacket<NodeChainSubgraphOptions>();
MakePacket<NightLightCalculatorOptions>();
}
// Retrieves the description of a protobuf message and a nested protobuf message
// from the OptionsRegistry.
TEST_F(OptionsUtilTest, GetProtobufDescriptorRegistered) { TEST_F(OptionsUtilTest, GetProtobufDescriptorRegistered) {
const proto_ns::Descriptor* descriptor = const tool::Descriptor* options_descriptor =
tool::OptionsRegistry::GetProtobufDescriptor( tool::OptionsRegistry::GetProtobufDescriptor(
"mediapipe.NightLightCalculatorOptions"); "mediapipe.NightLightCalculatorOptions");
EXPECT_NE(nullptr, descriptor); EXPECT_NE(nullptr, options_descriptor);
const tool::Descriptor* bundle_descriptor =
tool::OptionsRegistry::GetProtobufDescriptor(
"mediapipe.NightLightCalculatorOptions.LightBundle");
EXPECT_NE(nullptr, bundle_descriptor);
EXPECT_EQ(options_descriptor->full_name(),
"mediapipe.NightLightCalculatorOptions");
const tool::FieldDescriptor* bundle_field =
options_descriptor->FindFieldByName("bundle");
EXPECT_EQ(bundle_field->message_type(), bundle_descriptor);
}
// Constructs the FieldPath for a nested node-option.
TEST_F(OptionsUtilTest, OptionsSyntaxUtil) {
const tool::Descriptor* descriptor =
tool::OptionsRegistry::GetProtobufDescriptor(
"mediapipe.NightLightCalculatorOptions");
std::string tag;
tool::OptionsSyntaxUtil::FieldPath field_path;
{
// The default tag syntax.
tool::OptionsSyntaxUtil syntax_util;
tag = syntax_util.OptionFieldsTag("options/sub_options/num_lights");
EXPECT_EQ(tag, "OPTIONS/sub_options/num_lights");
field_path = syntax_util.OptionFieldPath(tag, descriptor);
EXPECT_EQ(field_path.size(), 2);
EXPECT_EQ(field_path[0].first->name(), "sub_options");
EXPECT_EQ(field_path[1].first->name(), "num_lights");
}
{
// A tag syntax with a text-coded separator.
tool::OptionsSyntaxUtil syntax_util("OPTIONS", "options", "_Z0Z_");
tag = syntax_util.OptionFieldsTag("options/sub_options/num_lights");
EXPECT_EQ(tag, "OPTIONS_Z0Z_sub_options_Z0Z_num_lights");
field_path = syntax_util.OptionFieldPath(tag, descriptor);
EXPECT_EQ(field_path.size(), 2);
EXPECT_EQ(field_path[0].first->name(), "sub_options");
EXPECT_EQ(field_path[1].first->name(), "num_lights");
}
} }
} // namespace } // namespace

View File

@ -196,6 +196,27 @@ absl::Status ProtoUtilLite::GetFieldRange(
return absl::OkStatus(); return absl::OkStatus();
} }
// Returns the number of field values in a repeated protobuf field.
absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message,
ProtoPath proto_path,
FieldType field_type,
int* field_count) {
int field_id, index;
std::tie(field_id, index) = proto_path.back();
proto_path.pop_back();
std::vector<std::string> parent;
if (proto_path.empty()) {
parent.push_back(std::string(message));
} else {
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent));
}
FieldAccess access(field_id, field_type);
MP_RETURN_IF_ERROR(access.SetMessage(parent[0]));
*field_count = access.mutable_field_values()->size();
return absl::OkStatus();
}
// If ok, returns OkStatus, otherwise returns InvalidArgumentError. // If ok, returns OkStatus, otherwise returns InvalidArgumentError.
template <typename T> template <typename T>
absl::Status SyntaxStatus(bool ok, const std::string& text, T* result) { absl::Status SyntaxStatus(bool ok, const std::string& text, T* result) {

View File

@ -75,6 +75,11 @@ class ProtoUtilLite {
FieldType field_type, FieldType field_type,
std::vector<FieldValue>* field_values); std::vector<FieldValue>* field_values);
// Returns the number of field values in a repeated protobuf field.
static absl::Status GetFieldCount(const FieldValue& message,
ProtoPath proto_path, FieldType field_type,
int* field_count);
// Serialize one or more protobuf field values from text. // Serialize one or more protobuf field values from text.
static absl::Status Serialize(const std::vector<std::string>& text_values, static absl::Status Serialize(const std::vector<std::string>& text_values,
FieldType field_type, FieldType field_type,

View File

@ -278,6 +278,8 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
graph_registry = graph_registry =
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
RET_CHECK(config); RET_CHECK(config);
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(
CalculatorGraphConfig::Node(), config));
auto* nodes = config->mutable_node(); auto* nodes = config->mutable_node();
while (1) { while (1) {
auto subgraph_nodes_start = std::stable_partition( auto subgraph_nodes_start = std::stable_partition(
@ -297,6 +299,7 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
ASSIGN_OR_RETURN(auto subgraph, graph_registry->CreateByName( ASSIGN_OR_RETURN(auto subgraph, graph_registry->CreateByName(
config->package(), node.calculator(), config->package(), node.calculator(),
&subgraph_context)); &subgraph_context));
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(node, &subgraph));
MP_RETURN_IF_ERROR(PrefixNames(node_name, &subgraph)); MP_RETURN_IF_ERROR(PrefixNames(node_name, &subgraph));
MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph)); MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph));
subgraphs.push_back(subgraph); subgraphs.push_back(subgraph);

View File

@ -99,6 +99,26 @@ const GLchar* const kScaledVertexShader = VERTEX_PREAMBLE _STRINGIFY(
sample_coordinate = texture_coordinate.xy; sample_coordinate = texture_coordinate.xy;
}); });
const GLchar* const kTransformedVertexShader = VERTEX_PREAMBLE _STRINGIFY(
in vec4 position; in mediump vec4 texture_coordinate;
out mediump vec2 sample_coordinate; uniform mat3 transform;
uniform vec2 viewport_size;
void main() {
// switch from clip to viewport aspect ratio in order to properly
// apply transformation
vec2 half_viewport_size = viewport_size * 0.5;
vec3 pos = vec3(position.xy * half_viewport_size, 1);
// apply transform
pos = transform * pos;
// switch back to clip space
gl_Position = vec4(pos.xy / half_viewport_size, 0, 1);
sample_coordinate = texture_coordinate.xy;
});
const GLchar* const kBasicTexturedFragmentShader = FRAGMENT_PREAMBLE _STRINGIFY( const GLchar* const kBasicTexturedFragmentShader = FRAGMENT_PREAMBLE _STRINGIFY(
DEFAULT_PRECISION(mediump, float) DEFAULT_PRECISION(mediump, float)

View File

@ -38,6 +38,17 @@ extern const GLchar* const kBasicVertexShader;
// vec2 sample_coordinate - texture coordinate for shader // vec2 sample_coordinate - texture coordinate for shader
extern const GLchar* const kScaledVertexShader; extern const GLchar* const kScaledVertexShader;
// Applies an affine transformation to the vertex and leaves texture coordinates
// as is. Input attributes:
// vec4 position - vertex position
// vec4 texture_coordinate - texture coordinate
// Input uniform:
// mat3 homogeneous affine transform - transformation matrix for vertices
// vec2 viewport_size - size of the viewport
// Output varying:
// vec2 sample_coordinate - texture coordinate for shader
extern const GLchar* const kTransformedVertexShader;
// Outputs the texture as it is. // Outputs the texture as it is.
// Input varying: // Input varying:
// vec2 sample_coordinate - texture coordinate // vec2 sample_coordinate - texture coordinate

View File

@ -77,16 +77,24 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
}}, }},
{GpuBufferFormat::kOneComponent8, {GpuBufferFormat::kOneComponent8,
{ {
// This should be GL_RED, but it would change the output for existing // This format is like RGBA grayscale: GL_LUMINANCE replicates
// shaders. It would not be a good representation of a grayscale texture, // the single channel texel values to RGB channels, and set alpha
// unless we use texture swizzling. We could add swizzle parameters (e.g. // to 1.0. If it is desired to see only the texel values in the R
// GL_TEXTURE_SWIZZLE_R) in GLES 3 and desktop GL, and use GL_LUMINANCE // channel, use kOneComponent8Red instead.
// in GLES 2. Or we could just punt and make it a red texture.
// {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
#if !TARGET_OS_OSX #if !TARGET_OS_OSX
{GL_LUMINANCE, GL_LUMINANCE, GL_UNSIGNED_BYTE, 1}, {GL_LUMINANCE, GL_LUMINANCE, GL_UNSIGNED_BYTE, 1},
#else
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
#endif // TARGET_OS_OSX #endif // TARGET_OS_OSX
}}, }},
{GpuBufferFormat::kOneComponent8Red,
{
{GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1},
}},
{GpuBufferFormat::kTwoComponent8,
{
{GL_RG8, GL_RG, GL_UNSIGNED_BYTE, 1},
}},
#ifdef __APPLE__ #ifdef __APPLE__
// TODO: figure out GL_RED_EXT etc. on Android. // TODO: figure out GL_RED_EXT etc. on Android.
{GpuBufferFormat::kBiPlanar420YpCbCr8VideoRange, {GpuBufferFormat::kBiPlanar420YpCbCr8VideoRange,
@ -195,6 +203,8 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
case GpuBufferFormat::kTwoComponentFloat32: case GpuBufferFormat::kTwoComponentFloat32:
return ImageFormat::VEC32F2; return ImageFormat::VEC32F2;
case GpuBufferFormat::kGrayHalf16: case GpuBufferFormat::kGrayHalf16:
case GpuBufferFormat::kOneComponent8Red:
case GpuBufferFormat::kTwoComponent8:
case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kTwoComponentHalf16:
case GpuBufferFormat::kRGBAHalf64: case GpuBufferFormat::kRGBAHalf64:
case GpuBufferFormat::kRGBAFloat128: case GpuBufferFormat::kRGBAFloat128:

View File

@ -38,6 +38,8 @@ enum class GpuBufferFormat : uint32_t {
kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'), kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'),
kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'), kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'),
kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'), kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'),
kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'),
kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'),
kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'), kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'),
kTwoComponentFloat32 = MEDIAPIPE_FOURCC('2', 'C', '0', 'f'), kTwoComponentFloat32 = MEDIAPIPE_FOURCC('2', 'C', '0', 'f'),
kBiPlanar420YpCbCr8VideoRange = MEDIAPIPE_FOURCC('4', '2', '0', 'v'), kBiPlanar420YpCbCr8VideoRange = MEDIAPIPE_FOURCC('4', '2', '0', 'v'),
@ -82,6 +84,10 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
return kCVPixelFormatType_OneComponent32Float; return kCVPixelFormatType_OneComponent32Float;
case GpuBufferFormat::kOneComponent8: case GpuBufferFormat::kOneComponent8:
return kCVPixelFormatType_OneComponent8; return kCVPixelFormatType_OneComponent8;
case GpuBufferFormat::kOneComponent8Red:
return -1;
case GpuBufferFormat::kTwoComponent8:
return kCVPixelFormatType_TwoComponent8;
case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kTwoComponentHalf16:
return kCVPixelFormatType_TwoComponent16Half; return kCVPixelFormatType_TwoComponent16Half;
case GpuBufferFormat::kTwoComponentFloat32: case GpuBufferFormat::kTwoComponentFloat32:
@ -114,6 +120,8 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) {
return GpuBufferFormat::kGrayFloat32; return GpuBufferFormat::kGrayFloat32;
case kCVPixelFormatType_OneComponent8: case kCVPixelFormatType_OneComponent8:
return GpuBufferFormat::kOneComponent8; return GpuBufferFormat::kOneComponent8;
case kCVPixelFormatType_TwoComponent8:
return GpuBufferFormat::kTwoComponent8;
case kCVPixelFormatType_TwoComponent16Half: case kCVPixelFormatType_TwoComponent16Half:
return GpuBufferFormat::kTwoComponentHalf16; return GpuBufferFormat::kTwoComponentHalf16;
case kCVPixelFormatType_TwoComponent32Float: case kCVPixelFormatType_TwoComponent32Float:

View File

@ -19,6 +19,7 @@ import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.TextureFrame; import com.google.mediapipe.framework.TextureFrame;
import java.util.ArrayList;
import java.util.List; import java.util.List;
/** /**
@ -32,6 +33,9 @@ public class ImageSolutionResult implements SolutionResult {
private Bitmap cachedBitmap; private Bitmap cachedBitmap;
// A list of the output image packets produced by the graph. // A list of the output image packets produced by the graph.
protected List<Packet> imageResultPackets; protected List<Packet> imageResultPackets;
// The cached texture frames.
protected List<TextureFrame> imageResultTextureFrames;
private TextureFrame cachedTextureFrame;
// Result timestamp, which is set to the timestamp of the corresponding input image. May return // Result timestamp, which is set to the timestamp of the corresponding input image. May return
// Long.MIN_VALUE if the input image is not associated with a timestamp. // Long.MIN_VALUE if the input image is not associated with a timestamp.
@ -61,8 +65,38 @@ public class ImageSolutionResult implements SolutionResult {
return PacketGetter.getTextureFrame(imagePacket); return PacketGetter.getTextureFrame(imagePacket);
} }
// Returns the cached input image as a {@link TextureFrame}.
public TextureFrame getCachedInputTextureFrame() {
return cachedTextureFrame;
}
// Produces all texture frames from image packets and caches them for further use. The caller must
// release the cached {@link TextureFrame}s after using.
void produceAllTextureFrames() {
cachedTextureFrame = acquireInputTextureFrame();
if (imageResultPackets == null) {
return;
}
imageResultTextureFrames = new ArrayList<>();
for (Packet p : imageResultPackets) {
imageResultTextureFrames.add(PacketGetter.getTextureFrame(p));
}
}
// Releases all cached {@link TextureFrame}s.
void releaseCachedTextureFrames() {
if (cachedTextureFrame != null) {
cachedTextureFrame.release();
}
if (imageResultTextureFrames != null) {
for (TextureFrame textureFrame : imageResultTextureFrames) {
textureFrame.release();
}
}
}
// Releases image packet and the underlying data. // Releases image packet and the underlying data.
void releaseImagePacket() { void releaseImagePackets() {
imagePacket.release(); imagePacket.release();
if (imageResultPackets != null) { if (imageResultPackets != null) {
for (Packet p : imageResultPackets) { for (Packet p : imageResultPackets) {

View File

@ -79,7 +79,7 @@ public class OutputHandler<T extends SolutionResult> {
} }
if (solutionResult instanceof ImageSolutionResult) { if (solutionResult instanceof ImageSolutionResult) {
ImageSolutionResult imageSolutionResult = (ImageSolutionResult) solutionResult; ImageSolutionResult imageSolutionResult = (ImageSolutionResult) solutionResult;
imageSolutionResult.releaseImagePacket(); imageSolutionResult.releaseImagePackets();
} }
} }
} }

View File

@ -50,13 +50,25 @@ public class SolutionGlSurfaceView<T extends ImageSolutionResult> extends GLSurf
} }
/** /**
* Sets the next textureframe and solution result to render. * Sets the next input {@link TextureFrame} and solution result to render.
* *
* @param solutionResult a solution result object that contains the solution outputs and a * @param solutionResult a solution result object that contains the solution outputs and a
* textureframe. * textureframe.
*/ */
public void setRenderData(T solutionResult) { public void setRenderData(T solutionResult) {
renderer.setRenderData(solutionResult); renderer.setRenderData(solutionResult, false);
}
/**
* Sets the next input {@link TextureFrame} and solution result to render.
*
* @param solutionResult a solution result object that contains the solution outputs and a {@link
* TextureFrame}.
* @param produceTextureFrames whether to produce and cache all the {@link TextureFrame}s for
* further use.
*/
public void setRenderData(T solutionResult, boolean produceTextureFrames) {
renderer.setRenderData(solutionResult, produceTextureFrames);
} }
/** Sets if the input image needs to be rendered. Default to true. */ /** Sets if the input image needs to be rendered. Default to true. */

View File

@ -27,9 +27,10 @@ import javax.microedition.khronos.opengles.GL10;
* MediaPipe Solution's GlSurfaceViewRenderer. * MediaPipe Solution's GlSurfaceViewRenderer.
* *
* <p>Users can provide a custom {@link ResultGlRenderer} for rendering MediaPipe solution results. * <p>Users can provide a custom {@link ResultGlRenderer} for rendering MediaPipe solution results.
* For setting the latest solution result, call {@link #setRenderData(ImageSolutionResult)}. By * For setting the latest solution result, call {@link #setRenderData(ImageSolutionResult,
* default, the renderer renders the input images. Call {@link #setRenderInputImage(boolean)} to * boolean)}. By default, the renderer renders the input images. Call {@link
* explicitly set whether the input images should be rendered or not. * #setRenderInputImage(boolean)} to explicitly set whether the input images should be rendered or
* not.
*/ */
public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult> public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
extends GlSurfaceViewRenderer { extends GlSurfaceViewRenderer {
@ -49,16 +50,24 @@ public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
} }
/** /**
* Sets the next textureframe and solution result to render. * Sets the next input {@link TextureFrame} and solution result to render.
* *
* @param solutionResult a solution result object that contains the solution outputs and a * @param solutionResult a solution result object that contains the solution outputs and a
* textureframe. * textureframe.
* @param produceTextureFrames whether to produce and cache all the {@link TextureFrame}s for
* further use.
*/ */
public void setRenderData(T solutionResult) { public void setRenderData(T solutionResult, boolean produceTextureFrames) {
TextureFrame frame = solutionResult.acquireInputTextureFrame(); TextureFrame frame = solutionResult.acquireInputTextureFrame();
setFrameSize(frame.getWidth(), frame.getHeight()); setFrameSize(frame.getWidth(), frame.getHeight());
setNextFrame(frame); setNextFrame(frame);
nextSolutionResult.getAndSet(solutionResult); if (produceTextureFrames) {
solutionResult.produceAllTextureFrames();
}
T oldSolutionResult = nextSolutionResult.getAndSet(solutionResult);
if (oldSolutionResult != null) {
oldSolutionResult.releaseCachedTextureFrames();
}
} }
@Override @Override
@ -78,8 +87,9 @@ public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
GLES20.glActiveTexture(GLES20.GL_TEXTURE0); GLES20.glActiveTexture(GLES20.GL_TEXTURE0);
ShaderUtil.checkGlError("glActiveTexture"); ShaderUtil.checkGlError("glActiveTexture");
} }
T solutionResult = null;
if (nextSolutionResult != null) { if (nextSolutionResult != null) {
T solutionResult = nextSolutionResult.getAndSet(null); solutionResult = nextSolutionResult.getAndSet(null);
float[] textureBoundary = calculateTextureBoundary(); float[] textureBoundary = calculateTextureBoundary();
// Scales the values from [0, 1] to [-1, 1]. // Scales the values from [0, 1] to [-1, 1].
ResultGlBoundary resultGlBoundary = ResultGlBoundary resultGlBoundary =
@ -91,6 +101,9 @@ public class SolutionGlSurfaceViewRenderer<T extends ImageSolutionResult>
resultGlRenderer.renderResult(solutionResult, resultGlBoundary); resultGlRenderer.renderResult(solutionResult, resultGlBoundary);
} }
flush(frame); flush(frame);
if (solutionResult != null) {
solutionResult.releaseCachedTextureFrames();
}
} }
@Override @Override

View File

@ -64,9 +64,10 @@ node {
options: { options: {
[mediapipe.InferenceCalculatorOptions.ext] { [mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite"
delegate { xnnpack {} } delegate {
xnnpack {}
}
} }
#
} }
} }

View File

@ -50,9 +50,10 @@ node {
options: { options: {
[mediapipe.InferenceCalculatorOptions.ext] { [mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite" model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
delegate { xnnpack {} } delegate {
xnnpack {}
}
} }
#
} }
} }

View File

@ -72,9 +72,10 @@ node {
options: { options: {
[mediapipe.InferenceCalculatorOptions.ext] { [mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/pose_detection/pose_detection.tflite" model_path: "mediapipe/modules/pose_detection/pose_detection.tflite"
delegate { xnnpack {} } delegate {
xnnpack {}
}
} }
#
} }
} }

View File

@ -245,7 +245,7 @@ node {
output_stream: "enabled_segmentation_tensor" output_stream: "enabled_segmentation_tensor"
options: { options: {
[mediapipe.GateCalculatorOptions.ext] { [mediapipe.GateCalculatorOptions.ext] {
allow: true allow: false
} }
} }
} }

View File

@ -96,9 +96,10 @@ node {
input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver"
options: { options: {
[mediapipe.InferenceCalculatorOptions.ext] { [mediapipe.InferenceCalculatorOptions.ext] {
delegate { xnnpack {} } delegate {
xnnpack {}
}
} }
#
} }
} }

View File

@ -288,7 +288,10 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
calculator_graph.def( calculator_graph.def(
"wait_until_done", "wait_until_done",
[](CalculatorGraph* self) { RaisePyErrorIfNotOk(self->WaitUntilDone()); }, [](CalculatorGraph* self) {
py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true);
},
R"doc(Wait for the current run to finish. R"doc(Wait for the current run to finish.
A blocking call to wait for the current run to finish (block the current A blocking call to wait for the current run to finish (block the current
@ -313,7 +316,10 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
calculator_graph.def( calculator_graph.def(
"wait_until_idle", "wait_until_idle",
[](CalculatorGraph* self) { RaisePyErrorIfNotOk(self->WaitUntilIdle()); }, [](CalculatorGraph* self) {
py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk(self->WaitUntilIdle(), /**acquire_gil=*/true);
},
R"doc(Wait until the running graph is in the idle mode. R"doc(Wait until the running graph is in the idle mode.
Wait until the running graph is in the idle mode, which is when nothing can Wait until the running graph is in the idle mode, which is when nothing can
@ -399,12 +405,9 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
stream_name, stream_name,
[callback_fn, stream_name](const Packet& packet) { [callback_fn, stream_name](const Packet& packet) {
absl::MutexLock lock(&callback_mutex); absl::MutexLock lock(&callback_mutex);
py::gil_scoped_release gil_release; // Acquires GIL before calling Python callback.
{ py::gil_scoped_acquire gil_acquire;
// Acquires GIL before calling Python callback. callback_fn(stream_name, packet);
py::gil_scoped_acquire gil_acquire;
callback_fn(stream_name, packet);
}
return absl::OkStatus(); return absl::OkStatus();
}, },
observe_timestamp_bounds)); observe_timestamp_bounds));
@ -439,7 +442,8 @@ void CalculatorGraphSubmodule(pybind11::module* module) {
"close", "close",
[](CalculatorGraph* self) { [](CalculatorGraph* self) {
RaisePyErrorIfNotOk(self->CloseAllPacketSources()); RaisePyErrorIfNotOk(self->CloseAllPacketSources());
RaisePyErrorIfNotOk(self->WaitUntilDone()); py::gil_scoped_release gil_release;
RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true);
}, },
R"doc(Close all the input sources and shutdown the graph.)doc"); R"doc(Close all the input sources and shutdown the graph.)doc");

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "pybind11/gil.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
namespace mediapipe { namespace mediapipe {
@ -45,10 +46,17 @@ inline PyObject* StatusCodeToPyError(const ::absl::StatusCode& code) {
} }
} }
inline void RaisePyErrorIfNotOk(const absl::Status& status) { inline void RaisePyErrorIfNotOk(const absl::Status& status,
bool acquire_gil = false) {
if (!status.ok()) { if (!status.ok()) {
throw RaisePyError(StatusCodeToPyError(status.code()), if (acquire_gil) {
status.message().data()); py::gil_scoped_acquire acquire;
throw RaisePyError(StatusCodeToPyError(status.code()),
status.message().data());
} else {
throw RaisePyError(StatusCodeToPyError(status.code()),
status.message().data());
}
} }
} }

View File

@ -441,7 +441,7 @@ class SolutionBase:
else: else:
field_label = calculator_options.DESCRIPTOR.fields_by_name[ field_label = calculator_options.DESCRIPTOR.fields_by_name[
field_name].label field_name].label
if field_label is descriptor.FieldDescriptor.LABEL_REPEATED: if field_label == descriptor.FieldDescriptor.LABEL_REPEATED:
if not isinstance(field_value, Iterable): if not isinstance(field_value, Iterable):
raise ValueError( raise ValueError(
f'{field_name} is a repeated proto field but the value ' f'{field_name} is a repeated proto field but the value '

View File

@ -111,7 +111,7 @@ def draw_detection(
image_rows) image_rows)
rect_end_point = _normalized_to_pixel_coordinates( rect_end_point = _normalized_to_pixel_coordinates(
relative_bounding_box.xmin + relative_bounding_box.width, relative_bounding_box.xmin + relative_bounding_box.width,
relative_bounding_box.ymin + +relative_bounding_box.height, image_cols, relative_bounding_box.ymin + relative_bounding_box.height, image_cols,
image_rows) image_rows)
cv2.rectangle(image, rect_start_point, rect_end_point, cv2.rectangle(image, rect_start_point, rect_end_point,
bbox_drawing_spec.color, bbox_drawing_spec.thickness) bbox_drawing_spec.color, bbox_drawing_spec.thickness)

View File

@ -175,8 +175,9 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":resource_util_custom", ":resource_util_custom",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:singleton", "//mediapipe/framework/port:singleton",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",

View File

@ -38,10 +38,17 @@ workspace_file="$( cd "$(dirname "$0")" ; pwd -P )"/WORKSPACE
if [ -z "$1" ] if [ -z "$1" ]
then then
echo "Installing OpenCV from source" echo "Installing OpenCV from source"
sudo apt update && sudo apt install build-essential git if [[ -x "$(command -v apt)" ]]; then
sudo apt install cmake ffmpeg libavformat-dev libdc1394-22-dev libgtk2.0-dev \ sudo apt update && sudo apt install build-essential git
libjpeg-dev libpng-dev libswscale-dev libtbb2 libtbb-dev \ sudo apt install cmake ffmpeg libavformat-dev libdc1394-22-dev libgtk2.0-dev \
libtiff-dev libjpeg-dev libpng-dev libswscale-dev libtbb2 libtbb-dev \
libtiff-dev
elif [[ -x "$(command -v dnf)" ]]; then
sudo dnf update && sudo dnf install cmake gcc gcc-c git
sudo dnf install ffmpeg-devel libdc1394-devel gtk2-devel \
libjpeg-turbo-devel libpng-devel tbb-devel \
libtiff-devel
fi
rm -rf /tmp/build_opencv rm -rf /tmp/build_opencv
mkdir /tmp/build_opencv mkdir /tmp/build_opencv
cd /tmp/build_opencv cd /tmp/build_opencv