Project import generated by Copybara.

GitOrigin-RevId: bbbbcb4f5174dea33525729ede47c770069157cd
This commit is contained in:
MediaPipe Team 2021-10-18 12:39:29 -07:00 committed by chuoling
parent 33d683c671
commit 1faeaae7e5
75 changed files with 1944 additions and 560 deletions

View File

@ -120,7 +120,7 @@ just 86.22%.
### Hand Landmark Model ### Hand Landmark Model
After the palm detection over the whole image our subsequent hand landmark After the palm detection over the whole image our subsequent hand landmark
[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite) [model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite)
performs precise keypoint localization of 21 3D hand-knuckle coordinates inside performs precise keypoint localization of 21 3D hand-knuckle coordinates inside
the detected hand regions via regression, that is direct coordinate prediction. the detected hand regions via regression, that is direct coordinate prediction.
The model learns a consistent internal hand pose representation and is robust The model learns a consistent internal hand pose representation and is robust
@ -163,6 +163,11 @@ unrelated, images. Default to `false`.
Maximum number of hands to detect. Default to `2`. Maximum number of hands to detect. Default to `2`.
#### model_complexity
Complexity of the hand landmark model: `0` or `1`. Landmark accuracy as well as
inference latency generally go up with the model complexity. Default to `1`.
#### min_detection_confidence #### min_detection_confidence
Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the
@ -212,6 +217,7 @@ Supported configuration options:
* [static_image_mode](#static_image_mode) * [static_image_mode](#static_image_mode)
* [max_num_hands](#max_num_hands) * [max_num_hands](#max_num_hands)
* [model_complexity](#model_complexity)
* [min_detection_confidence](#min_detection_confidence) * [min_detection_confidence](#min_detection_confidence)
* [min_tracking_confidence](#min_tracking_confidence) * [min_tracking_confidence](#min_tracking_confidence)
@ -260,6 +266,7 @@ with mp_hands.Hands(
# For webcam input: # For webcam input:
cap = cv2.VideoCapture(0) cap = cv2.VideoCapture(0)
with mp_hands.Hands( with mp_hands.Hands(
model_complexity=0,
min_detection_confidence=0.5, min_detection_confidence=0.5,
min_tracking_confidence=0.5) as hands: min_tracking_confidence=0.5) as hands:
while cap.isOpened(): while cap.isOpened():
@ -302,6 +309,7 @@ and a [fun application], and the following usage example.
Supported configuration options: Supported configuration options:
* [maxNumHands](#max_num_hands) * [maxNumHands](#max_num_hands)
* [modelComplexity](#model_complexity)
* [minDetectionConfidence](#min_detection_confidence) * [minDetectionConfidence](#min_detection_confidence)
* [minTrackingConfidence](#min_tracking_confidence) * [minTrackingConfidence](#min_tracking_confidence)
@ -351,6 +359,7 @@ const hands = new Hands({locateFile: (file) => {
}}); }});
hands.setOptions({ hands.setOptions({
maxNumHands: 2, maxNumHands: 2,
modelComplexity: 1,
minDetectionConfidence: 0.5, minDetectionConfidence: 0.5,
minTrackingConfidence: 0.5 minTrackingConfidence: 0.5
}); });

View File

@ -58,10 +58,12 @@ one over the other.
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite), [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite),
[TF.js model](https://tfhub.dev/mediapipe/handdetector/1) [TF.js model](https://tfhub.dev/mediapipe/handdetector/1)
* Hand landmark model: * Hand landmark model:
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite), [TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_lite.tflite),
[TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite),
[TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite), [TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite),
[TF.js model](https://tfhub.dev/mediapipe/handskeleton/1) [TF.js model](https://tfhub.dev/mediapipe/handskeleton/1)
* [Model card](https://mediapipe.page.link/handmc), [Model card (sparse)](https://mediapipe.page.link/handmc-sparse) * [Model card](https://mediapipe.page.link/handmc),
[Model card (sparse)](https://mediapipe.page.link/handmc-sparse)
### [Pose](https://google.github.io/mediapipe/solutions/pose) ### [Pose](https://google.github.io/mediapipe/solutions/pose)

View File

@ -125,7 +125,7 @@ hip midpoints.
:----------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------: |
*Fig 3. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* | *Fig 3. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* |
### Pose Landmark Model (BlazePose GHUM 3D) ### Pose Landmark Model (BlazePose [GHUM](https://github.com/google-research/google-research/tree/master/ghum) 3D)
The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
(see figure below). (see figure below).
@ -486,6 +486,7 @@ on how to build MediaPipe examples.
[BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204) [BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204)
([presentation](https://youtu.be/YPpUOTRn5tA)) ([presentation](https://youtu.be/YPpUOTRn5tA))
* [Models and model cards](./models.md#pose) * [Models and model cards](./models.md#pose)
* [GHUM & GHUML: Generative 3D Human Shape and Articulated Pose Models](https://github.com/google-research/google-research/tree/master/ghum)
* [Web demo](https://code.mediapipe.dev/codepen/pose) * [Web demo](https://code.mediapipe.dev/codepen/pose)
* [Python Colab](https://mediapipe.page.link/pose_py_colab) * [Python Colab](https://mediapipe.page.link/pose_py_colab)

View File

@ -531,9 +531,13 @@ cc_test(
":split_vector_calculator", ":split_vector_calculator",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -47,4 +47,8 @@ typedef BeginLoopCalculator<std::vector<std::vector<Matrix>>>
BeginLoopMatrixVectorCalculator; BeginLoopMatrixVectorCalculator;
REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator); REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
// A calculator to process std::vector<uint64_t>.
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -14,7 +14,11 @@
#include <memory> #include <memory>
#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "mediapipe/calculators/core/split_vector_calculator.h" #include "mediapipe/calculators/core/split_vector_calculator.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -301,4 +305,99 @@ TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) {
} }
} // namespace } // namespace
class PassThroughAndTsBoundUpdateNode : public mediapipe::api2::Node {
public:
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
static constexpr mediapipe::api2::Output<int> kOutValue{"VALUE"};
static constexpr mediapipe::api2::Output<int> kOutTsBoundUpdate{
"TS_BOUND_UPDATE"};
MEDIAPIPE_NODE_CONTRACT(kInValue, kOutValue, kOutTsBoundUpdate);
absl::Status Process(CalculatorContext* cc) override {
kOutValue(cc).Send(kInValue(cc));
kOutTsBoundUpdate(cc).SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream());
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(PassThroughAndTsBoundUpdateNode);
class ToOptionalNode : public mediapipe::api2::Node {
public:
static constexpr mediapipe::api2::Input<int> kTick{"TICK"};
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
static constexpr mediapipe::api2::Output<absl::optional<int>> kOutValue{
"OUTPUT"};
MEDIAPIPE_NODE_CONTRACT(kTick, kInValue, kOutValue);
absl::Status Process(CalculatorContext* cc) override {
if (kInValue(cc).IsEmpty()) {
kOutValue(cc).Send(absl::nullopt);
} else {
kOutValue(cc).Send({kInValue(cc).Get()});
}
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(ToOptionalNode);
namespace {
TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_stream: "select"
node {
calculator: "PassThroughAndTsBoundUpdateNode"
input_stream: "VALUE:select"
output_stream: "VALUE:select_ps"
output_stream: "TS_BOUND_UPDATE:ts_bound_update"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:select_ps"
input_stream: "INPUT:1:ts_bound_update"
input_stream: "SELECT:select"
output_stream: "OUTPUT:select_or_ts_bound_update"
}
node {
calculator: "ToOptionalNode"
input_stream: "TICK:select"
input_stream: "VALUE:select_or_ts_bound_update"
output_stream: "OUTPUT:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
auto send_value_fn = [&](int value, Timestamp ts) -> absl::Status {
MP_RETURN_IF_ERROR(
graph.AddPacketToInputStream("select", MakePacket<int>(value).At(ts)));
return graph.WaitUntilIdle();
};
MP_ASSERT_OK(send_value_fn(0, Timestamp(1)));
ASSERT_EQ(output_packets.size(), 1);
EXPECT_EQ(output_packets[0].Get<absl::optional<int>>(), 0);
MP_ASSERT_OK(send_value_fn(1, Timestamp(2)));
ASSERT_EQ(output_packets.size(), 2);
EXPECT_EQ(output_packets[1].Get<absl::optional<int>>(), absl::nullopt);
MP_ASSERT_OK(send_value_fn(0, Timestamp(3)));
ASSERT_EQ(output_packets.size(), 3);
EXPECT_EQ(output_packets[2].Get<absl::optional<int>>(), 0);
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -34,7 +34,6 @@ option java_outer_classname = "InferenceCalculatorProto";
// } // }
// } // }
// } // }
//
message InferenceCalculatorOptions { message InferenceCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional InferenceCalculatorOptions ext = 336783863; optional InferenceCalculatorOptions ext = 336783863;
@ -69,8 +68,30 @@ message InferenceCalculatorOptions {
// Load pre-compiled serialized binary cache to accelerate init process. // Load pre-compiled serialized binary cache to accelerate init process.
// Only available for OpenCL delegate on Android. // Only available for OpenCL delegate on Android.
// Kernel caching will only be enabled if this path is set. // Kernel caching will only be enabled if this path is set.
//
// NOTE: binary cache usage may be skipped if valid serialized model,
// specified by "serialized_model_dir", exists.
//
// TODO: update to cached_kernel_dir
optional string cached_kernel_path = 2; optional string cached_kernel_path = 2;
// A dir to load from and save to a pre-compiled serialized model used to
// accelerate init process.
//
// NOTE: available for OpenCL delegate on Android only when
// "use_advanced_gpu_api" is set to true and "model_token" is set
// properly.
//
// NOTE: serialized model takes precedence over binary cache
// specified by "cached_kernel_path", which still can be used if
// serialized model is invalid or missing.
optional string serialized_model_dir = 7;
// Unique token identifying the model. Used in conjunction with
// "serialized_model_dir". It is the caller's responsibility to ensure
// there is no clash of the tokens.
optional string model_token = 8;
// Encapsulated compilation/runtime tradeoffs. // Encapsulated compilation/runtime tradeoffs.
enum InferenceUsage { enum InferenceUsage {
UNSPECIFIED = 0; UNSPECIFIED = 0;

View File

@ -20,6 +20,7 @@
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h" #include "mediapipe/util/tflite/config.h"
#if MEDIAPIPE_TFLITE_GL_INFERENCE #if MEDIAPIPE_TFLITE_GL_INFERENCE
@ -49,8 +50,8 @@ class InferenceCalculatorGlImpl
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
private: private:
absl::Status ReadKernelsFromFile(); absl::Status ReadGpuCaches();
absl::Status WriteKernelsToFile(); absl::Status SaveGpuCaches();
absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
@ -82,6 +83,8 @@ class InferenceCalculatorGlImpl
bool use_kernel_caching_ = false; bool use_kernel_caching_ = false;
std::string cached_kernel_filename_; std::string cached_kernel_filename_;
bool use_serialized_model_ = false;
std::string serialized_model_path_;
}; };
absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
@ -114,6 +117,9 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
tflite_gpu_runner_usage_ = delegate.gpu().usage(); tflite_gpu_runner_usage_ = delegate.gpu().usage();
use_kernel_caching_ = use_kernel_caching_ =
use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path(); use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
use_serialized_model_ = use_advanced_gpu_api_ &&
delegate.gpu().has_serialized_model_dir() &&
delegate.gpu().has_model_token();
use_gpu_delegate_ = !use_advanced_gpu_api_; use_gpu_delegate_ = !use_advanced_gpu_api_;
if (use_kernel_caching_) { if (use_kernel_caching_) {
@ -123,6 +129,12 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
".ker"; ".ker";
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
} }
if (use_serialized_model_) {
#ifdef MEDIAPIPE_ANDROID
serialized_model_path_ = mediapipe::file::JoinPath(
delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
#endif // MEDIAPIPE_ANDROID
}
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
// for everything. // for everything.
@ -210,7 +222,7 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() { absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() {
#ifdef MEDIAPIPE_ANDROID #ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_) { if (use_kernel_caching_) {
// Save kernel file. // Save kernel file.
@ -220,12 +232,22 @@ absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
} }
if (use_serialized_model_) {
// Save serialized model file.
ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
tflite_gpu_runner_->GetSerializedModel());
absl::string_view serialized_model(
reinterpret_cast<char*>(serialized_model_vec.data()),
serialized_model_vec.size());
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(serialized_model_path_, serialized_model));
}
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(WriteKernelsToFile()); MP_RETURN_IF_ERROR(SaveGpuCaches());
if (use_gpu_delegate_) { if (use_gpu_delegate_) {
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
gpu_buffers_in_.clear(); gpu_buffers_in_.clear();
@ -239,17 +261,24 @@ absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::ReadKernelsFromFile() { absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
#ifdef MEDIAPIPE_ANDROID #ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_) { if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
// Load pre-compiled kernel file. // Load pre-compiled kernel file.
if (mediapipe::File::Exists(cached_kernel_filename_)) { std::string cache_str;
std::string cache_str; MP_RETURN_IF_ERROR(
MP_RETURN_IF_ERROR( mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str)); std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end()); tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec)); }
} if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
// Load serialized model file.
std::string serialized_model_str;
MP_RETURN_IF_ERROR(
file::GetContents(serialized_model_path_, &serialized_model_str));
std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
serialized_model_str.end());
tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
} }
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
return absl::OkStatus(); return absl::OkStatus();
@ -313,7 +342,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
tflite_gpu_runner_->GetOutputShapes()[i].c}; tflite_gpu_runner_->GetOutputShapes()[i].c};
} }
MP_RETURN_IF_ERROR(ReadKernelsFromFile()); MP_RETURN_IF_ERROR(ReadGpuCaches());
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());

View File

@ -24,20 +24,20 @@ message SsdAnchorsCalculatorOptions {
optional SsdAnchorsCalculatorOptions ext = 247258239; optional SsdAnchorsCalculatorOptions ext = 247258239;
} }
// Size of input images. // Size of input images.
required int32 input_size_width = 1; optional int32 input_size_width = 1; // required
required int32 input_size_height = 2; optional int32 input_size_height = 2; // required
// Min and max scales for generating anchor boxes on feature maps. // Min and max scales for generating anchor boxes on feature maps.
required float min_scale = 3; optional float min_scale = 3; // required
required float max_scale = 4; optional float max_scale = 4; // required
// The offset for the center of anchors. The value is in the scale of stride. // The offset for the center of anchors. The value is in the scale of stride.
// E.g. 0.5 meaning 0.5 * |current_stride| in pixels. // E.g. 0.5 meaning 0.5 * |current_stride| in pixels.
required float anchor_offset_x = 5 [default = 0.5]; optional float anchor_offset_x = 5 [default = 0.5]; // required
required float anchor_offset_y = 6 [default = 0.5]; optional float anchor_offset_y = 6 [default = 0.5]; // required
// Number of output feature maps to generate the anchors on. // Number of output feature maps to generate the anchors on.
required int32 num_layers = 7; optional int32 num_layers = 7; // required
// Sizes of output feature maps to create anchors. Either feature_map size or // Sizes of output feature maps to create anchors. Either feature_map size or
// stride should be provided. // stride should be provided.
repeated int32 feature_map_width = 8; repeated int32 feature_map_width = 8;

View File

@ -26,12 +26,12 @@ message TfLiteTensorsToDetectionsCalculatorOptions {
} }
// The number of output classes predicted by the detection model. // The number of output classes predicted by the detection model.
required int32 num_classes = 1; optional int32 num_classes = 1; // required
// The number of output boxes predicted by the detection model. // The number of output boxes predicted by the detection model.
required int32 num_boxes = 2; optional int32 num_boxes = 2; // required
// The number of output values per boxes predicted by the detection model. The // The number of output values per boxes predicted by the detection model. The
// values contain bounding boxes, keypoints, etc. // values contain bounding boxes, keypoints, etc.
required int32 num_coords = 3; optional int32 num_coords = 3; // required
// The offset of keypoint coordinates in the location tensor. // The offset of keypoint coordinates in the location tensor.
optional int32 keypoint_coord_offset = 9; optional int32 keypoint_coord_offset = 9;

View File

@ -31,7 +31,7 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
} }
// Number of landmarks from the output of the model. // Number of landmarks from the output of the model.
required int32 num_landmarks = 1; optional int32 num_landmarks = 1; // required
// Size of the input image for the model. These options are used only when // Size of the input image for the model. These options are used only when
// normalized landmarks are needed. Z coordinate is scaled as X assuming // normalized landmarks are needed. Z coordinate is scaled as X assuming

View File

@ -24,9 +24,9 @@ message TfLiteTensorsToSegmentationCalculatorOptions {
} }
// Dimensions of input segmentation tensor to process. // Dimensions of input segmentation tensor to process.
required int32 tensor_width = 1; optional int32 tensor_width = 1; // required
required int32 tensor_height = 2; optional int32 tensor_height = 2; // required
required int32 tensor_channels = 3; optional int32 tensor_channels = 3; // required
// How much to use previous mask when computing current one; range [0-1]. // How much to use previous mask when computing current one; range [0-1].
// This is a tradeoff between responsiveness (0.0) and accuracy (1.0). // This is a tradeoff between responsiveness (0.0) and accuracy (1.0).

View File

@ -98,6 +98,43 @@ public class MainActivity extends AppCompatActivity {
} }
} }
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
/** Sets up the UI components for the static image demo. */ /** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() { private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap. // The Intent to access gallery and read images as bitmap.
@ -111,37 +148,16 @@ public class MainActivity extends AppCompatActivity {
Bitmap bitmap = null; Bitmap bitmap = null;
try { try {
bitmap = bitmap =
MediaStore.Images.Media.getBitmap( downscaleBitmap(
this.getContentResolver(), resultIntent.getData()); MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) { } catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e); Log.e(TAG, "Bitmap reading error:" + e);
} }
try { try {
InputStream imageData = InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData()); this.getContentResolver().openInputStream(resultIntent.getData());
int orientation = bitmap = rotateBitmap(bitmap, imageData);
new ExifInterface(imageData)
.getAttributeInt(
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
bitmap =
Bitmap.createBitmap(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
}
} catch (IOException e) { } catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e); Log.e(TAG, "Bitmap rotation error:" + e);
} }

View File

@ -7,7 +7,7 @@ android {
buildToolsVersion "30.0.3" buildToolsVersion "30.0.3"
defaultConfig { defaultConfig {
applicationId "com.google.mediapipe.apps.hands" applicationId "com.google.mediapipe.apps.facemesh"
minSdkVersion 21 minSdkVersion 21
targetSdkVersion 30 targetSdkVersion 30
versionCode 1 versionCode 1

View File

@ -15,7 +15,7 @@
<application <application
android:allowBackup="true" android:allowBackup="true"
android:icon="@mipmap/ic_launcher" android:icon="@mipmap/ic_launcher"
android:label="MediaPipe FaceMesh" android:label="MediaPipe Face Mesh"
android:roundIcon="@mipmap/ic_launcher_round" android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true" android:supportsRtl="true"
android:theme="@style/AppTheme"> android:theme="@style/AppTheme">

View File

@ -99,6 +99,43 @@ public class MainActivity extends AppCompatActivity {
} }
} }
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
/** Sets up the UI components for the static image demo. */ /** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() { private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap. // The Intent to access gallery and read images as bitmap.
@ -112,37 +149,16 @@ public class MainActivity extends AppCompatActivity {
Bitmap bitmap = null; Bitmap bitmap = null;
try { try {
bitmap = bitmap =
MediaStore.Images.Media.getBitmap( downscaleBitmap(
this.getContentResolver(), resultIntent.getData()); MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) { } catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e); Log.e(TAG, "Bitmap reading error:" + e);
} }
try { try {
InputStream imageData = InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData()); this.getContentResolver().openInputStream(resultIntent.getData());
int orientation = bitmap = rotateBitmap(bitmap, imageData);
new ExifInterface(imageData)
.getAttributeInt(
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
bitmap =
Bitmap.createBitmap(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
}
} catch (IOException e) { } catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e); Log.e(TAG, "Bitmap rotation error:" + e);
} }

View File

@ -100,6 +100,43 @@ public class MainActivity extends AppCompatActivity {
} }
} }
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
/** Sets up the UI components for the static image demo. */ /** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() { private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap. // The Intent to access gallery and read images as bitmap.
@ -113,37 +150,16 @@ public class MainActivity extends AppCompatActivity {
Bitmap bitmap = null; Bitmap bitmap = null;
try { try {
bitmap = bitmap =
MediaStore.Images.Media.getBitmap( downscaleBitmap(
this.getContentResolver(), resultIntent.getData()); MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) { } catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e); Log.e(TAG, "Bitmap reading error:" + e);
} }
try { try {
InputStream imageData = InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData()); this.getContentResolver().openInputStream(resultIntent.getData());
int orientation = bitmap = rotateBitmap(bitmap, imageData);
new ExifInterface(imageData)
.getAttributeInt(
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.postRotate(90);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.postRotate(180);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.postRotate(270);
break;
default:
matrix.postRotate(0);
}
bitmap =
Bitmap.createBitmap(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
}
} catch (IOException e) { } catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e); Log.e(TAG, "Bitmap rotation error:" + e);
} }

View File

@ -38,7 +38,7 @@ android_binary(
assets = [ assets = [
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb", "//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
"//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite", "//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/palm_detection:palm_detection.tflite", "//mediapipe/modules/palm_detection:palm_detection.tflite",
], ],
assets_dir = "", assets_dir = "",

View File

@ -39,7 +39,7 @@ android_binary(
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb", "//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
"//mediapipe/modules/face_detection:face_detection_short_range.tflite", "//mediapipe/modules/face_detection:face_detection_short_range.tflite",
"//mediapipe/modules/face_landmark:face_landmark.tflite", "//mediapipe/modules/face_landmark:face_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite", "//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite", "//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite",

View File

@ -62,7 +62,7 @@ objc_library(
copts = ["-std=c++17"], copts = ["-std=c++17"],
data = [ data = [
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb", "//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite", "//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/palm_detection:palm_detection.tflite", "//mediapipe/modules/palm_detection:palm_detection.tflite",
], ],

View File

@ -57,7 +57,7 @@ objc_library(
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb", "//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
"//mediapipe/modules/face_detection:face_detection_short_range.tflite", "//mediapipe/modules/face_detection:face_detection_short_range.tflite",
"//mediapipe/modules/face_landmark:face_landmark.tflite", "//mediapipe/modules/face_landmark:face_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite", "//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite", "//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite", "//mediapipe/modules/pose_detection:pose_detection.tflite",

View File

@ -427,7 +427,8 @@ absl::Status CalculatorGraph::Initialize(
const std::map<std::string, Packet>& side_packets) { const std::map<std::string, Packet>& side_packets) {
auto validated_graph = absl::make_unique<ValidatedGraphConfig>(); auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
MP_RETURN_IF_ERROR(validated_graph->Initialize( MP_RETURN_IF_ERROR(validated_graph->Initialize(
input_config, /*graph_registry=*/nullptr, &service_manager_)); input_config, /*graph_registry=*/nullptr, /*graph_options=*/nullptr,
&service_manager_));
return Initialize(std::move(validated_graph), side_packets); return Initialize(std::move(validated_graph), side_packets);
} }

View File

@ -122,7 +122,7 @@ class CalculatorRunner {
const StreamContentsSet& Outputs() const { return *outputs_; } const StreamContentsSet& Outputs() const { return *outputs_; }
// Returns the access to the output side packets. // Returns the access to the output side packets.
const PacketSet& OutputSidePackets() { return *output_side_packets_.get(); } const PacketSet& OutputSidePackets() { return *output_side_packets_; }
// Returns a graph counter. // Returns a graph counter.
mediapipe::Counter* GetCounter(const std::string& name); mediapipe::Counter* GetCounter(const std::string& name);

View File

@ -77,13 +77,6 @@ bool Image::ConvertToGpu() const {
#else #else
// GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture) // GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture)
auto buffer = mediapipe::GlTextureBuffer::Create(*image_frame_); auto buffer = mediapipe::GlTextureBuffer::Create(*image_frame_);
glBindTexture(GL_TEXTURE_2D, buffer->name());
// See GlCalculatorHelperImpl::SetStandardTextureParams
glTexParameteri(buffer->target(), GL_TEXTURE_MIN_FILTER, GL_LINEAR);
glTexParameteri(buffer->target(), GL_TEXTURE_MAG_FILTER, GL_LINEAR);
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
glBindTexture(GL_TEXTURE_2D, 0);
glFlush(); glFlush();
gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer)); gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer));
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER

View File

@ -244,6 +244,8 @@ cc_test(
srcs = ["mux_input_stream_handler_test.cc"], srcs = ["mux_input_stream_handler_test.cc"],
deps = [ deps = [
":mux_input_stream_handler", ":mux_input_stream_handler",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:make_pair_calculator",
"//mediapipe/calculators/core:mux_calculator", "//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/core:round_robin_demux_calculator", "//mediapipe/calculators/core:round_robin_demux_calculator",

View File

@ -75,13 +75,30 @@ class MuxInputStreamHandler : public InputStreamHandler {
int control_value = control_packet.Get<int>(); int control_value = control_packet.Get<int>();
CHECK_LE(0, control_value); CHECK_LE(0, control_value);
CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1); CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1);
const auto& data_stream = input_stream_managers_.Get( const auto& data_stream = input_stream_managers_.Get(
input_stream_managers_.BeginId() + control_value); input_stream_managers_.BeginId() + control_value);
// Data stream may contain some outdated packets which failed to be popped
// out during "FillInputSet". (This handler doesn't sync input streams,
// hence "FillInputSet" can be triggerred before every input stream is
// filled with packets corresponding to the same timestamp.)
data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty); Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
if (empty) { if (empty) {
CHECK_LE(stream_timestamp, *min_stream_timestamp); if (stream_timestamp <= *min_stream_timestamp) {
return NodeReadiness::kNotReady; // "data_stream" didn't receive a packet corresponding to the current
// "control_stream" packet yet.
return NodeReadiness::kNotReady;
}
// "data_stream" timestamp bound update detected.
return NodeReadiness::kReadyForProcess;
}
if (stream_timestamp > *min_stream_timestamp) {
// The earliest packet "data_stream" holds corresponds to a control packet
// yet to arrive, which means there won't be a "data_stream" packet
// corresponding to the current "control_stream" packet, which should be
// indicated as timestamp boun update.
return NodeReadiness::kReadyForProcess;
} }
CHECK_EQ(stream_timestamp, *min_stream_timestamp); CHECK_EQ(stream_timestamp, *min_stream_timestamp);
return NodeReadiness::kReadyForProcess; return NodeReadiness::kReadyForProcess;

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
@ -19,9 +20,10 @@
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
using ::testing::ElementsAre;
// A regression test for b/31620439. MuxInputStreamHandler's accesses to the // A regression test for b/31620439. MuxInputStreamHandler's accesses to the
// control and data streams should be atomic so that it has a consistent view // control and data streams should be atomic so that it has a consistent view
// of the two streams. None of the CHECKs in the GetNodeReadiness() method of // of the two streams. None of the CHECKs in the GetNodeReadiness() method of
@ -87,5 +89,561 @@ TEST(MuxInputStreamHandlerTest, AtomicAccessToControlAndDataStreams) {
MP_ASSERT_OK(graph.WaitUntilDone()); MP_ASSERT_OK(graph.WaitUntilDone());
} }
MATCHER_P2(IntPacket, value, ts, "") {
return arg.template Get<int>() == value && arg.Timestamp() == ts;
}
struct GateAndMuxGraphInput {
int input0;
int input1;
int input2;
int select;
bool allow0;
bool allow1;
bool allow2;
Timestamp at;
};
constexpr char kGateAndMuxGraph[] = R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "input2"
input_stream: "select"
input_stream: "allow0"
input_stream: "allow1"
input_stream: "allow2"
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow0"
input_stream: "input0"
output_stream: "output0"
}
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow1"
input_stream: "input1"
output_stream: "output1"
}
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow2"
input_stream: "input2"
output_stream: "output2"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:output0"
input_stream: "INPUT:1:output1"
input_stream: "INPUT:2:output2"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
})pb";
absl::Status SendInput(GateAndMuxGraphInput in, CalculatorGraph& graph) {
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"input0", MakePacket<int>(in.input0).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"input1", MakePacket<int>(in.input1).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"input2", MakePacket<int>(in.input2).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"select", MakePacket<int>(in.select).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"allow0", MakePacket<bool>(in.allow0).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"allow1", MakePacket<bool>(in.allow1).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"allow2", MakePacket<bool>(in.allow2).At(in.at)));
return graph.WaitUntilIdle();
}
TEST(MuxInputStreamHandlerTest, BasicMuxing) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = true,
.allow1 = false,
.allow2 = false,
.at = Timestamp(1)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = false,
.allow1 = true,
.allow2 = false,
.at = Timestamp(2)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = false,
.allow1 = false,
.allow2 = true,
.at = Timestamp(3)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(3))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingNonEmptyInputs) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(1)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(2)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(3)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(3))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingAllTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(1)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(2)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(3)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingSlectedTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = false,
.allow1 = true,
.allow2 = true,
.at = Timestamp(1)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = true,
.allow1 = false,
.allow2 = true,
.at = Timestamp(2)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = false,
.at = Timestamp(3)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingSometimesTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(1)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = false,
.allow1 = true,
.allow2 = false,
.at = Timestamp(2)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = false,
.at = Timestamp(3)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(4)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(4))));
MP_ASSERT_OK(SendInput({.input0 = 700,
.input1 = 600,
.input2 = 500,
.select = 0,
.allow0 = true,
.allow1 = false,
.allow2 = false,
.at = Timestamp(5)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(4)),
IntPacket(700, Timestamp(5))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
MATCHER_P(EmptyPacket, ts, "") {
return arg.IsEmpty() && arg.Timestamp() == ts;
}
MATCHER_P2(Pair, m1, m2, "") {
const auto& p = arg.template Get<std::pair<Packet, Packet>>();
return testing::Matches(m1)(p.first) && testing::Matches(m2)(p.second);
}
TEST(MuxInputStreamHandlerTest,
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacket) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:input0"
input_stream: "INPUT:1:input1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
node {
calculator: "MakePairCalculator"
input_stream: "select"
input_stream: "output"
output_stream: "pair"
}
)pb");
std::vector<Packet> pair_packets;
tool::AddVectorSink("pair", &config, &pair_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_TRUE(pair_packets.empty());
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
EmptyPacket(Timestamp(1)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(800).At(Timestamp(4))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
EmptyPacket(Timestamp(1)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(3))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(4))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest,
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacketPacketsAtOnce) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:input0"
input_stream: "INPUT:1:input1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
node {
calculator: "MakePairCalculator"
input_stream: "select"
input_stream: "output"
output_stream: "pair"
}
)pb");
std::vector<Packet> pair_packets;
tool::AddVectorSink("pair", &config, &pair_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(800).At(Timestamp(4))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(3))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(4))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest,
TimestampBoundUpdateTriggersTimestampBoundUpdate) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
input_stream: "allow0"
input_stream: "allow1"
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow0"
input_stream: "input0"
output_stream: "output0"
}
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow1"
input_stream: "input1"
output_stream: "output1"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:output0"
input_stream: "INPUT:1:output1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
node {
calculator: "MakePairCalculator"
input_stream: "select"
input_stream: "output"
output_stream: "pair"
}
)pb");
std::vector<Packet> pair_packets;
tool::AddVectorSink("pair", &config, &pair_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"allow0", MakePacket<bool>(false).At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
EmptyPacket(Timestamp(1)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(900).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"allow0", MakePacket<bool>(true).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(900, Timestamp(2)))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -56,7 +56,9 @@ class SubgraphContext {
return options_map_.Get<T>(); return options_map_.Get<T>();
} }
const CalculatorGraphConfig::Node& OriginalNode() { return original_node_; } const CalculatorGraphConfig::Node& OriginalNode() const {
return original_node_;
}
template <typename T> template <typename T>
ServiceBinding<T> Service(const GraphService<T>& service) const { ServiceBinding<T> Service(const GraphService<T>& service) const {

View File

@ -724,6 +724,7 @@ cc_test(
srcs = ["subgraph_expansion_test.cc"], srcs = ["subgraph_expansion_test.cc"],
deps = [ deps = [
":node_chain_subgraph_cc_proto", ":node_chain_subgraph_cc_proto",
":node_chain_subgraph_options_lib",
":subgraph_expansion", ":subgraph_expansion",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",

View File

@ -23,6 +23,8 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/type_map.h" #include "mediapipe/framework/type_map.h"
#define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging()
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
@ -47,13 +49,13 @@ absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in,
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
if (IsLengthDelimited(wire_type)) { if (IsLengthDelimited(wire_type)) {
uint32 length; uint32 length;
RET_CHECK(in->ReadVarint32(&length)); RET_CHECK_NO_LOG(in->ReadVarint32(&length));
RET_CHECK(in->ReadString(result, length)); RET_CHECK_NO_LOG(in->ReadString(result, length));
} else { } else {
std::string field_data; std::string field_data;
StringOutputStream sos(&field_data); StringOutputStream sos(&field_data);
CodedOutputStream cos(&sos); CodedOutputStream cos(&sos);
RET_CHECK(WireFormatLite::SkipField(in, tag, &cos)); RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, &cos));
// Skip the tag written by SkipField. // Skip the tag written by SkipField.
int tag_size = CodedOutputStream::VarintSize32(tag); int tag_size = CodedOutputStream::VarintSize32(tag);
cos.Trim(); cos.Trim();
@ -67,13 +69,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
CodedInputStream* in, CodedInputStream* in,
std::vector<std::string>* field_values) { std::vector<std::string>* field_values) {
uint32 data_size; uint32 data_size;
RET_CHECK(in->ReadVarint32(&data_size)); RET_CHECK_NO_LOG(in->ReadVarint32(&data_size));
// fake_tag encodes the wire-type for calls to WireFormatLite::SkipField. // fake_tag encodes the wire-type for calls to WireFormatLite::SkipField.
uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type); uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type);
while (data_size > 0) { while (data_size > 0) {
std::string number; std::string number;
MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number)); MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number));
RET_CHECK_LE(number.size(), data_size); RET_CHECK_NO_LOG(number.size() <= data_size);
field_values->push_back(number); field_values->push_back(number);
data_size -= number.size(); data_size -= number.size();
} }
@ -98,7 +100,7 @@ absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
field_values->push_back(value); field_values->push_back(value);
} }
} else { } else {
RET_CHECK(WireFormatLite::SkipField(in, tag, out)); RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, out));
} }
} }
return absl::OkStatus(); return absl::OkStatus();
@ -157,12 +159,12 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
MP_RETURN_IF_ERROR(access.SetMessage(*message)); MP_RETURN_IF_ERROR(access.SetMessage(*message));
std::vector<std::string>& v = *access.mutable_field_values(); std::vector<std::string>& v = *access.mutable_field_values();
if (!proto_path.empty()) { if (!proto_path.empty()) {
RET_CHECK(index >= 0 && index < v.size()); RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length, MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
field_type, field_values)); field_type, field_values));
} else { } else {
RET_CHECK(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
RET_CHECK(index + length >= 0 && index + length <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
v.erase(v.begin() + index, v.begin() + index + length); v.erase(v.begin() + index, v.begin() + index + length);
v.insert(v.begin() + index, field_values.begin(), field_values.end()); v.insert(v.begin() + index, field_values.begin(), field_values.end());
} }
@ -184,12 +186,12 @@ absl::Status ProtoUtilLite::GetFieldRange(
MP_RETURN_IF_ERROR(access.SetMessage(message)); MP_RETURN_IF_ERROR(access.SetMessage(message));
std::vector<std::string>& v = *access.mutable_field_values(); std::vector<std::string>& v = *access.mutable_field_values();
if (!proto_path.empty()) { if (!proto_path.empty()) {
RET_CHECK(index >= 0 && index < v.size()); RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
GetFieldRange(v[index], proto_path, length, field_type, field_values)); GetFieldRange(v[index], proto_path, length, field_type, field_values));
} else { } else {
RET_CHECK(index >= 0 && index <= v.size()); RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
RET_CHECK(index + length >= 0 && index + length <= v.size()); RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
field_values->insert(field_values->begin(), v.begin() + index, field_values->insert(field_values->begin(), v.begin() + index,
v.begin() + index + length); v.begin() + index + length);
} }

View File

@ -274,12 +274,14 @@ absl::Status ConnectSubgraphStreams(
absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
const GraphRegistry* graph_registry, const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) { const GraphServiceManager* service_manager) {
graph_registry = graph_registry =
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
RET_CHECK(config); RET_CHECK(config);
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions( MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(
CalculatorGraphConfig::Node(), config)); graph_options ? *graph_options : CalculatorGraphConfig::Node(), config));
auto* nodes = config->mutable_node(); auto* nodes = config->mutable_node();
while (1) { while (1) {
auto subgraph_nodes_start = std::stable_partition( auto subgraph_nodes_start = std::stable_partition(

View File

@ -72,6 +72,7 @@ absl::Status ConnectSubgraphStreams(
absl::Status ExpandSubgraphs( absl::Status ExpandSubgraphs(
CalculatorGraphConfig* config, CalculatorGraphConfig* config,
const GraphRegistry* graph_registry = nullptr, const GraphRegistry* graph_registry = nullptr,
const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr); const GraphServiceManager* service_manager = nullptr);
// Creates a graph wrapping the provided node and exposing all of its // Creates a graph wrapping the provided node and exposing all of its

View File

@ -560,9 +560,111 @@ TEST(SubgraphExpansionTest, GraphServicesUsage) {
MP_ASSERT_OK(service_manager.SetServiceObject( MP_ASSERT_OK(service_manager.SetServiceObject(
kStringTestService, std::make_shared<std::string>("ExpectedNode"))); kStringTestService, std::make_shared<std::string>("ExpectedNode")));
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr, MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr,
/*graph_options=*/nullptr,
&service_manager)); &service_manager));
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph)); EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
} }
// Shows SubgraphOptions consumed by GraphRegistry::CreateByName.
TEST(SubgraphExpansionTest, SubgraphOptionsUsage) {
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
GraphRegistry graph_registry;
// CalculatorGraph::Initialize passes the SubgraphOptions into:
// (1) GraphRegistry::CreateByName("NodeChainSubgraph", options)
// (2) tool::ExpandSubgraphs(&config, options)
auto graph_options =
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
options {
[mediapipe.NodeChainSubgraphOptions.ext] {
node_type: "DoubleIntCalculator"
chain_length: 3
}
})pb");
SubgraphContext context(&graph_options, /*service_manager=*/nullptr);
// "NodeChainSubgraph" consumes graph_options only in CreateByName.
auto subgraph_status =
graph_registry.CreateByName("", "NodeChainSubgraph", &context);
MP_ASSERT_OK(subgraph_status);
auto subgraph = std::move(subgraph_status).value();
MP_ASSERT_OK(
tool::ExpandSubgraphs(&subgraph, &graph_registry, &graph_options));
CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "DoubleIntCalculator"
input_stream: "stream_0"
output_stream: "stream_1"
}
node {
calculator: "DoubleIntCalculator"
input_stream: "stream_1"
output_stream: "stream_2"
}
node {
calculator: "DoubleIntCalculator"
input_stream: "stream_2"
output_stream: "stream_3"
}
input_stream: "INPUT:stream_0"
output_stream: "OUTPUT:stream_3"
)pb");
EXPECT_THAT(subgraph, mediapipe::EqualsProto(expected_graph));
}
// Shows SubgraphOptions consumed by tool::ExpandSubgraphs.
TEST(SubgraphExpansionTest, SimpleSubgraphOptionsUsage) {
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
GraphRegistry graph_registry;
auto moon_options =
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
options {
[mediapipe.NodeChainSubgraphOptions.ext] {
node_type: "DoubleIntCalculator"
chain_length: 3
}
})pb");
auto moon_subgraph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
type: "MoonSubgraph"
graph_options: {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
}
node: {
calculator: "MoonCalculator"
node_options: {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
}
option_value: "chain_length:options/chain_length"
}
)pb");
// The moon_options are copied into the graph_options of moon_subgraph.
MP_ASSERT_OK(
tool::ExpandSubgraphs(&moon_subgraph, &graph_registry, &moon_options));
// The field chain_length is copied from moon_options into MoonCalculator.
CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "MoonCalculator"
node_options {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {
chain_length: 3
}
}
option_value: "chain_length:options/chain_length"
}
type: "MoonSubgraph"
graph_options {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
}
)pb");
EXPECT_THAT(moon_subgraph, mediapipe::EqualsProto(expected_graph));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -150,12 +150,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
const int packet_count = 10; const int packet_count = 10;
// Send int value packets at {10K, 20K, 30K, ..., 100K}. // Send int value packets at {10K, 20K, 30K, ..., 100K}.
for (uint64 t = 1; t <= packet_count; ++t) { for (uint64 t = 1; t <= packet_count; ++t) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
if (send_bounds) { if (send_bounds) {
MP_EXPECT_OK(graph.AddPacketToInputStream( MP_EXPECT_OK(graph.AddPacketToInputStream(
"enable", MakePacket<bool>(true).At(Timestamp(t * 10000)))); "enable", MakePacket<bool>(true).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle());
} }
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.WaitUntilIdle());
// The inputs are sent to the input stream "foo", they should pass through. // The inputs are sent to the input stream "foo", they should pass through.
EXPECT_EQ(out_foo.size(), t); EXPECT_EQ(out_foo.size(), t);
@ -175,12 +176,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
// Send int value packets at {110K, 120K, ..., 200K}. // Send int value packets at {110K, 120K, ..., 200K}.
for (uint64 t = 11; t <= packet_count * 2; ++t) { for (uint64 t = 11; t <= packet_count * 2; ++t) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
if (send_bounds) { if (send_bounds) {
MP_EXPECT_OK(graph.AddPacketToInputStream( MP_EXPECT_OK(graph.AddPacketToInputStream(
"enable", MakePacket<bool>(false).At(Timestamp(t * 10000)))); "enable", MakePacket<bool>(false).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle());
} }
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.WaitUntilIdle());
// The inputs are sent to the input stream "foo", they should pass through. // The inputs are sent to the input stream "foo", they should pass through.
EXPECT_EQ(out_foo.size(), t); EXPECT_EQ(out_foo.size(), t);

View File

@ -143,11 +143,12 @@ absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) {
absl::Status PerformBasicTransforms( absl::Status PerformBasicTransforms(
const CalculatorGraphConfig& input_graph_config, const CalculatorGraphConfig& input_graph_config,
const GraphRegistry* graph_registry, const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager, const GraphServiceManager* service_manager,
CalculatorGraphConfig* output_graph_config) { CalculatorGraphConfig* output_graph_config) {
*output_graph_config = input_graph_config; *output_graph_config = input_graph_config;
MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry, MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
service_manager)); graph_options, service_manager));
MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config)); MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));
@ -347,6 +348,7 @@ absl::Status NodeTypeInfo::Initialize(
absl::Status ValidatedGraphConfig::Initialize( absl::Status ValidatedGraphConfig::Initialize(
const CalculatorGraphConfig& input_config, const CalculatorGraphConfig& input_config,
const GraphRegistry* graph_registry, const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) { const GraphServiceManager* service_manager) {
RET_CHECK(!initialized_) RET_CHECK(!initialized_)
<< "ValidatedGraphConfig can be initialized only once."; << "ValidatedGraphConfig can be initialized only once.";
@ -356,8 +358,8 @@ absl::Status ValidatedGraphConfig::Initialize(
<< input_config.DebugString(); << input_config.DebugString();
#endif #endif
MP_RETURN_IF_ERROR(PerformBasicTransforms(input_config, graph_registry, MP_RETURN_IF_ERROR(PerformBasicTransforms(
service_manager, &config_)); input_config, graph_registry, graph_options, service_manager, &config_));
// Initialize the basic node information. // Initialize the basic node information.
MP_RETURN_IF_ERROR(InitializeGeneratorInfo()); MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
@ -431,22 +433,24 @@ absl::Status ValidatedGraphConfig::Initialize(
} }
absl::Status ValidatedGraphConfig::Initialize( absl::Status ValidatedGraphConfig::Initialize(
const std::string& graph_type, const Subgraph::SubgraphOptions* options, const std::string& graph_type, const GraphRegistry* graph_registry,
const GraphRegistry* graph_registry, const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) { const GraphServiceManager* service_manager) {
graph_registry = graph_registry =
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
SubgraphContext subgraph_context(options, service_manager); SubgraphContext subgraph_context(graph_options, service_manager);
auto status_or_config = auto status_or_config =
graph_registry->CreateByName("", graph_type, &subgraph_context); graph_registry->CreateByName("", graph_type, &subgraph_context);
MP_RETURN_IF_ERROR(status_or_config.status()); MP_RETURN_IF_ERROR(status_or_config.status());
return Initialize(status_or_config.value(), graph_registry, service_manager); return Initialize(status_or_config.value(), graph_registry, graph_options,
service_manager);
} }
absl::Status ValidatedGraphConfig::Initialize( absl::Status ValidatedGraphConfig::Initialize(
const std::vector<CalculatorGraphConfig>& input_configs, const std::vector<CalculatorGraphConfig>& input_configs,
const std::vector<CalculatorGraphTemplate>& input_templates, const std::vector<CalculatorGraphTemplate>& input_templates,
const std::string& graph_type, const Subgraph::SubgraphOptions* arguments, const std::string& graph_type,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) { const GraphServiceManager* service_manager) {
GraphRegistry graph_registry; GraphRegistry graph_registry;
for (auto& config : input_configs) { for (auto& config : input_configs) {
@ -455,7 +459,8 @@ absl::Status ValidatedGraphConfig::Initialize(
for (auto& templ : input_templates) { for (auto& templ : input_templates) {
graph_registry.Register(templ.config().type(), templ); graph_registry.Register(templ.config().type(), templ);
} }
return Initialize(graph_type, arguments, &graph_registry, service_manager); return Initialize(graph_type, &graph_registry, graph_options,
service_manager);
} }
absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() { absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {

View File

@ -195,18 +195,21 @@ class ValidatedGraphConfig {
// Initializes the ValidatedGraphConfig. This function must be called // Initializes the ValidatedGraphConfig. This function must be called
// before any other functions. Subgraphs are specified through the // before any other functions. Subgraphs are specified through the
// global graph registry or an optional local graph registry. // global graph registry or an optional local graph registry.
absl::Status Initialize(const CalculatorGraphConfig& input_config, absl::Status Initialize(
const GraphRegistry* graph_registry = nullptr, const CalculatorGraphConfig& input_config,
const GraphServiceManager* service_manager = nullptr); const GraphRegistry* graph_registry = nullptr,
const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr);
// Initializes the ValidatedGraphConfig from registered graph and subgraph // Initializes the ValidatedGraphConfig from registered graph and subgraph
// configs. Subgraphs are retrieved from the specified graph registry or from // configs. Subgraphs are retrieved from the specified graph registry or from
// the global graph registry. A subgraph can be instantiated directly by // the global graph registry. A subgraph can be instantiated directly by
// specifying its type in |graph_type|. // specifying its type in |graph_type|.
absl::Status Initialize(const std::string& graph_type, absl::Status Initialize(
const Subgraph::SubgraphOptions* options = nullptr, const std::string& graph_type,
const GraphRegistry* graph_registry = nullptr, const GraphRegistry* graph_registry = nullptr,
const GraphServiceManager* service_manager = nullptr); const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr);
// Initializes the ValidatedGraphConfig from the specified graph and subgraph // Initializes the ValidatedGraphConfig from the specified graph and subgraph
// configs. Template graph and subgraph configs can be specified through // configs. Template graph and subgraph configs can be specified through
@ -218,7 +221,7 @@ class ValidatedGraphConfig {
const std::vector<CalculatorGraphConfig>& input_configs, const std::vector<CalculatorGraphConfig>& input_configs,
const std::vector<CalculatorGraphTemplate>& input_templates, const std::vector<CalculatorGraphTemplate>& input_templates,
const std::string& graph_type = "", const std::string& graph_type = "",
const Subgraph::SubgraphOptions* arguments = nullptr, const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr); const GraphServiceManager* service_manager = nullptr);
// Returns true if the ValidatedGraphConfig has been initialized. // Returns true if the ValidatedGraphConfig has been initialized.

View File

@ -155,6 +155,7 @@ TEST(ValidatedGraphConfigTest, InitializeSubgraphWithServiceCalculatorB) {
kStringTestService, std::make_shared<std::string>(calculator_name))); kStringTestService, std::make_shared<std::string>(calculator_name)));
MP_EXPECT_OK(config.Initialize(graph, MP_EXPECT_OK(config.Initialize(graph,
/*graph_registry=*/nullptr, /*graph_registry=*/nullptr,
/*subgraph_options=*/nullptr,
/*service_manager=*/&service_manager)); /*service_manager=*/&service_manager));
ASSERT_TRUE(config.Initialized()); ASSERT_TRUE(config.Initialized());
EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfigExpandedFromGraph( EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfigExpandedFromGraph(

View File

@ -196,7 +196,9 @@ cc_library(
deps = [ deps = [
":gl_base", ":gl_base",
":gl_context", ":gl_context",
":gl_texture_view",
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_storage",
# TODO: remove this dependency. Some other teams' tests # TODO: remove this dependency. Some other teams' tests
# depend on having an indirect image_frame dependency, need to be # depend on having an indirect image_frame dependency, need to be
# fixed first. # fixed first.
@ -205,6 +207,17 @@ cc_library(
], ],
) )
cc_library(
name = "gl_texture_view",
srcs = ["gl_texture_view.cc"],
hdrs = ["gl_texture_view.h"],
visibility = ["//visibility:public"],
deps = [
":gl_base",
":gl_context",
],
)
cc_library( cc_library(
name = "gpu_buffer", name = "gpu_buffer",
srcs = ["gpu_buffer.cc"], srcs = ["gpu_buffer.cc"],
@ -214,12 +227,15 @@ cc_library(
":gl_base", ":gl_base",
":gl_context", ":gl_context",
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_storage",
":gl_texture_view",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
":gl_texture_buffer", ":gl_texture_buffer",
], ],
"//mediapipe:ios": [ "//mediapipe:ios": [
":gpu_buffer_storage_cv_pixel_buffer",
"//mediapipe/objc:util", "//mediapipe/objc:util",
"//mediapipe/objc:CFHolder", "//mediapipe/objc:CFHolder",
], ],
@ -244,6 +260,35 @@ cc_library(
], ],
) )
cc_library(
name = "gpu_buffer_storage",
hdrs = ["gpu_buffer_storage.h"],
visibility = ["//visibility:public"],
deps = [
":gl_base",
":gpu_buffer_format",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "gpu_buffer_storage_cv_pixel_buffer",
srcs = ["gpu_buffer_storage_cv_pixel_buffer.cc"],
hdrs = ["gpu_buffer_storage_cv_pixel_buffer.h"],
visibility = ["//visibility:public"],
deps = [
":gl_base",
":gl_context",
":gl_texture_view",
":gpu_buffer_storage",
"//mediapipe/objc:CFHolder",
"//mediapipe/objc:util",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "gpu_origin_proto", name = "gpu_origin_proto",
srcs = ["gpu_origin.proto"], srcs = ["gpu_origin.proto"],

View File

@ -109,12 +109,10 @@ GlTexture GlCalculatorHelper::CreateSourceTexture(
return impl_->CreateSourceTexture(image_frame); return impl_->CreateSourceTexture(image_frame);
} }
#ifdef __APPLE__
GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer, GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer,
int plane) { int plane) {
return impl_->CreateSourceTexture(pixel_buffer, plane); return impl_->CreateSourceTexture(pixel_buffer, plane);
} }
#endif
void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer,
int* width, int* height) { int* width, int* height) {

View File

@ -29,10 +29,6 @@
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/graph_support.h" #include "mediapipe/gpu/graph_support.h"
#ifdef __APPLE__
#include <CoreVideo/CoreVideo.h>
#endif // __APPLE__
namespace mediapipe { namespace mediapipe {
class GlCalculatorHelperImpl; class GlCalculatorHelperImpl;
@ -111,24 +107,35 @@ class GlCalculatorHelper {
// where it is supported (iOS, for now) they take advantage of memory sharing // where it is supported (iOS, for now) they take advantage of memory sharing
// between the CPU and GPU, avoiding memory copies. // between the CPU and GPU, avoiding memory copies.
// Creates a texture representing an input frame, and manages sync token. // Gives access to an input frame as an OpenGL texture for reading (sampling).
//
// IMPORTANT: the returned GlTexture should be treated as a short-term view
// into the frame (typically for the duration of a Process call). Do not store
// it as a member in your calculator. If you need to keep a frame around,
// store the GpuBuffer instead, and call CreateSourceTexture again on each
// Process call.
//
// TODO: rename this; the use of "Create" makes this sound more expensive than
// it is.
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer); GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer);
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
GlTexture CreateSourceTexture(const mediapipe::Image& image); GlTexture CreateSourceTexture(const mediapipe::Image& image);
#ifdef __APPLE__ // Gives read access to a plane of a planar buffer.
// Creates a texture from a plane of a planar buffer.
// The plane index is zero-based. The number of planes depends on the // The plane index is zero-based. The number of planes depends on the
// internal format of the buffer. // internal format of the buffer.
// Note: multi-plane support is not available on all platforms.
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer, int plane); GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer, int plane);
#endif
// Convenience function for converting an ImageFrame to GpuBuffer and then
// accessing it as a texture.
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
// Extracts GpuBuffer dimensions without creating a texture. // Extracts GpuBuffer dimensions without creating a texture.
ABSL_DEPRECATED("Use width and height methods on GpuBuffer instead") ABSL_DEPRECATED("Use width and height methods on GpuBuffer instead")
void GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width, void GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width,
int* height); int* height);
// Creates a texture representing an output frame, and manages sync token. // Gives access to an OpenGL texture for writing (rendering) a new frame.
// TODO: This should either return errors or a status. // TODO: This should either return errors or a status.
GlTexture CreateDestinationTexture( GlTexture CreateDestinationTexture(
int output_width, int output_height, int output_width, int output_height,

View File

@ -62,10 +62,7 @@ class GlCalculatorHelperImpl {
private: private:
// Makes a GpuBuffer accessible as a texture in the GL context. // Makes a GpuBuffer accessible as a texture in the GL context.
GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, int plane, GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view);
bool for_reading);
void AttachGlTexture(GlTexture& texture, const GpuBuffer& gpu_buffer,
int plane, bool for_reading);
// Create the framebuffer for rendering. // Create the framebuffer for rendering.
void CreateFramebuffer(); void CreateFramebuffer();

View File

@ -91,9 +91,7 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) {
} }
GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
int plane, bool for_reading) { GlTextureView view) {
GlTextureView view = gpu_buffer.GetGlTextureView(plane, for_reading);
if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { if (gpu_buffer.format() != GpuBufferFormat::kUnknown) {
// TODO: do the params need to be reset here?? // TODO: do the params need to be reset here??
glBindTexture(view.target(), view.name()); glBindTexture(view.target(), view.name());
@ -109,19 +107,18 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
GlTexture GlCalculatorHelperImpl::CreateSourceTexture( GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
const GpuBuffer& gpu_buffer) { const GpuBuffer& gpu_buffer) {
return MapGpuBuffer(gpu_buffer, 0, true); return CreateSourceTexture(gpu_buffer, 0);
} }
GlTexture GlCalculatorHelperImpl::CreateSourceTexture( GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
const GpuBuffer& gpu_buffer, int plane) { const GpuBuffer& gpu_buffer, int plane) {
return MapGpuBuffer(gpu_buffer, plane, true); return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(plane));
} }
GlTexture GlCalculatorHelperImpl::CreateSourceTexture( GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
const ImageFrame& image_frame) { const ImageFrame& image_frame) {
GlTexture texture = auto gpu_buffer = GpuBuffer::CopyingImageFrame(image_frame);
MapGpuBuffer(GpuBuffer::CopyingImageFrame(image_frame), 0, true); return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(0));
return texture;
} }
template <> template <>
@ -150,11 +147,9 @@ GlTexture GlCalculatorHelperImpl::CreateDestinationTexture(
CreateFramebuffer(); CreateFramebuffer();
} }
GpuBuffer buffer = GpuBuffer gpu_buffer =
gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format); gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format);
GlTexture texture = MapGpuBuffer(buffer, 0, false); return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureWriteView(0));
return texture;
} }
} // namespace mediapipe } // namespace mediapipe

View File

@ -14,6 +14,9 @@
#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_buffer.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gl_texture_view.h"
namespace mediapipe { namespace mediapipe {
std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Wrap( std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Wrap(
@ -122,6 +125,15 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
if (alignment != 4 && data) glPixelStorei(GL_UNPACK_ALIGNMENT, 4); if (alignment != 4 && data) glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
// TODO: does this need to set the texture params? We set them again when the
// texture is actually acccessed via GlTexture[View]. Or should they always be
// set on creation?
if (format_ != GpuBufferFormat::kUnknown) {
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format_, /*plane=*/0, context->GetGlVersion());
context->SetStandardTextureParams(target_, info.gl_internal_format);
}
glBindTexture(target_, 0); glBindTexture(target_, 0);
// Use the deletion callback to delete the texture on the context // Use the deletion callback to delete the texture on the context
@ -167,7 +179,7 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
producer_context_ = producer_sync_->GetContext(); producer_context_ = producer_sync_->GetContext();
} }
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) { void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
absl::MutexLock lock(&consumer_sync_mutex_); absl::MutexLock lock(&consumer_sync_mutex_);
consumer_multi_sync_->Add(std::move(cons_token)); consumer_multi_sync_->Add(std::move(cons_token));
} }
@ -181,7 +193,7 @@ GlTextureBuffer::~GlTextureBuffer() {
} }
} }
void GlTextureBuffer::WaitUntilComplete() { void GlTextureBuffer::WaitUntilComplete() const {
// Buffers created by the application (using the constructor that wraps an // Buffers created by the application (using the constructor that wraps an
// existing texture) have no sync token and are assumed to be already // existing texture) have no sync token and are assumed to be already
// complete. // complete.
@ -190,7 +202,7 @@ void GlTextureBuffer::WaitUntilComplete() {
} }
} }
void GlTextureBuffer::WaitOnGpu() { void GlTextureBuffer::WaitOnGpu() const {
// Buffers created by the application (using the constructor that wraps an // Buffers created by the application (using the constructor that wraps an
// existing texture) have no sync token and are assumed to be already // existing texture) have no sync token and are assumed to be already
// complete. // complete.
@ -212,4 +224,127 @@ void GlTextureBuffer::WaitForConsumersOnGpu() {
// precisely, on only one GL context. // precisely, on only one GL context.
} }
GlTextureView GlTextureBuffer::GetGlTextureReadView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
CHECK_EQ(plane, 0);
// Insert wait call to sync with the producer.
WaitOnGpu();
GlTextureView::DetachFn detach = [this](mediapipe::GlTextureView& texture) {
// Inform the GlTextureBuffer that we have finished accessing its
// contents, and create a consumer sync point.
DidRead(texture.gl_context()->CreateSyncToken());
};
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
std::move(gpu_buffer), plane, std::move(detach),
nullptr);
}
GlTextureView GlTextureBuffer::GetGlTextureWriteView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
CHECK_EQ(plane, 0);
// Insert wait call to sync with the producer.
WaitOnGpu();
Reuse(); // TODO: the producer wait should probably be part of Reuse in the
// case when there are no consumers.
GlTextureView::DoneWritingFn done_writing =
[this](const mediapipe::GlTextureView& texture) {
ViewDoneWriting(texture);
};
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
std::move(gpu_buffer), plane, nullptr,
std::move(done_writing));
}
void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) {
// Inform the GlTextureBuffer that we have produced new content, and create
// a producer sync point.
Updated(view.gl_context()->CreateSyncToken());
#ifdef __ANDROID__
// On (some?) Android devices, the texture may need to be explicitly
// detached from the current framebuffer.
// TODO: is this necessary even with the unbind in BindFramebuffer?
// It is not clear if this affected other contexts too, but let's keep it
// while in doubt.
GLint type = GL_NONE;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
&type);
if (type == GL_TEXTURE) {
GLint color_attachment = 0;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment);
if (color_attachment == name()) {
glBindFramebuffer(GL_FRAMEBUFFER, 0);
}
}
// Some Android drivers log a GL_INVALID_ENUM error after the first
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
// even though it should be ok to ask for the type and get back GL_NONE.
// Let's just ignore any pending errors here.
GLenum error;
while ((error = glGetError()) != GL_NO_ERROR) {
}
#endif // __ANDROID__
}
static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
void* output, size_t size) {
// TODO: check buffer size? We could use glReadnPixels where available
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
// won't overflow the buffer with glReadPixels, we'd also need to check or
// reset several glPixelStore parameters (e.g. what if someone had the
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
CHECK(view.gl_context());
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format, view.plane(), view.gl_context()->GetGlVersion());
GLint current_fbo;
glGetIntegerv(GL_FRAMEBUFFER_BINDING, &current_fbo);
CHECK_NE(current_fbo, 0);
GLint color_attachment_name;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment_name);
if (color_attachment_name != view.name()) {
// Save the viewport. Note that we assume that the color attachment is a
// GL_TEXTURE_2D texture.
GLint viewport[4];
glGetIntegerv(GL_VIEWPORT, viewport);
// Set the data from GLTextureView object.
glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
view.name(), 0);
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
// Restore from the saved viewport and color attachment name.
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
color_attachment_name, 0);
} else {
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
}
}
std::unique_ptr<ImageFrame> GlTextureBuffer::AsImageFrame() const {
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
auto output = absl::make_unique<ImageFrame>(
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
auto view = GetGlTextureReadView(nullptr, 0);
ReadTexture(view, format(), output->MutablePixelData(),
output->PixelDataSize());
return output;
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -25,13 +25,14 @@
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_context.h"
#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_buffer_storage.h"
namespace mediapipe { namespace mediapipe {
class GlCalculatorHelperImpl; class GlCalculatorHelperImpl;
// Implements a GPU memory buffer as an OpenGL texture. For internal use. // Implements a GPU memory buffer as an OpenGL texture. For internal use.
class GlTextureBuffer { class GlTextureBuffer : public mediapipe::internal::GpuBufferStorage {
public: public:
// This is called when the texture buffer is deleted. It is passed a sync // This is called when the texture buffer is deleted. It is passed a sync
// token created at that time on the GlContext. If the GlTextureBuffer has // token created at that time on the GlContext. If the GlTextureBuffer has
@ -85,6 +86,13 @@ class GlTextureBuffer {
int height() const { return height_; } int height() const { return height_; }
GpuBufferFormat format() const { return format_; } GpuBufferFormat format() const { return format_; }
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) const override;
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) override;
void ViewDoneWriting(const GlTextureView& view) override;
std::unique_ptr<ImageFrame> AsImageFrame() const override;
// If this texture is going to be used outside of the context that produced // If this texture is going to be used outside of the context that produced
// it, this method should be called to ensure that its updated contents are // it, this method should be called to ensure that its updated contents are
// available. When this method returns, all changed made before the call to // available. When this method returns, all changed made before the call to
@ -94,13 +102,13 @@ class GlTextureBuffer {
// NOTE: This blocks the current CPU thread and makes the changes visible // NOTE: This blocks the current CPU thread and makes the changes visible
// to the CPU. If you want to access the data via OpenGL, use WaitOnGpu // to the CPU. If you want to access the data via OpenGL, use WaitOnGpu
// instead. // instead.
void WaitUntilComplete(); void WaitUntilComplete() const;
// Call this method to synchronize the current GL context with the texture's // Call this method to synchronize the current GL context with the texture's
// producer. This will not block the current CPU thread, but will ensure that // producer. This will not block the current CPU thread, but will ensure that
// subsequent GL commands see the texture in its complete status, with all // subsequent GL commands see the texture in its complete status, with all
// rendering done on the GPU by the generating context. // rendering done on the GPU by the generating context.
void WaitOnGpu(); void WaitOnGpu() const;
// Informs the buffer that its contents are going to be overwritten. // Informs the buffer that its contents are going to be overwritten.
// This invalidates the current sync token. // This invalidates the current sync token.
@ -114,7 +122,7 @@ class GlTextureBuffer {
void Updated(std::shared_ptr<GlSyncPoint> prod_token); void Updated(std::shared_ptr<GlSyncPoint> prod_token);
// Informs the buffer that a consumer has finished reading from it. // Informs the buffer that a consumer has finished reading from it.
void DidRead(std::shared_ptr<GlSyncPoint> cons_token); void DidRead(std::shared_ptr<GlSyncPoint> cons_token) const;
// Waits for all pending consumers to finish accessing the current content // Waits for all pending consumers to finish accessing the current content
// of the texture. This (preferably the OnGpu version) should be called // of the texture. This (preferably the OnGpu version) should be called
@ -143,10 +151,11 @@ class GlTextureBuffer {
const GLenum target_ = GL_TEXTURE_2D; const GLenum target_ = GL_TEXTURE_2D;
// Token tracking changes to this texture. Used by WaitUntilComplete. // Token tracking changes to this texture. Used by WaitUntilComplete.
std::shared_ptr<GlSyncPoint> producer_sync_; std::shared_ptr<GlSyncPoint> producer_sync_;
absl::Mutex consumer_sync_mutex_; mutable absl::Mutex consumer_sync_mutex_;
// Tokens tracking the point when consumers finished using this texture. // Tokens tracking the point when consumers finished using this texture.
std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_ ABSL_GUARDED_BY( mutable std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_
consumer_sync_mutex_) = absl::make_unique<GlMultiSyncPoint>(); ABSL_GUARDED_BY(consumer_sync_mutex_) =
absl::make_unique<GlMultiSyncPoint>();
DeletionCallback deletion_callback_; DeletionCallback deletion_callback_;
std::shared_ptr<GlContext> producer_context_; std::shared_ptr<GlContext> producer_context_;
}; };

View File

@ -0,0 +1,16 @@
#include "mediapipe/gpu/gl_texture_view.h"
namespace mediapipe {
void GlTextureView::Release() {
if (detach_) detach_(*this);
detach_ = nullptr;
gl_context_ = nullptr;
gpu_buffer_ = nullptr;
plane_ = 0;
name_ = 0;
width_ = 0;
height_ = 0;
}
} // namespace mediapipe

View File

@ -0,0 +1,86 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
#define MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
#include <functional>
#include <memory>
#include <utility>
#include "mediapipe/gpu/gl_base.h"
namespace mediapipe {
class GlContext;
class GlTextureViewManager;
class GpuBuffer;
class GlTextureView {
public:
GlTextureView() {}
~GlTextureView() { Release(); }
// TODO: make this class move-only.
GlContext* gl_context() const { return gl_context_; }
int width() const { return width_; }
int height() const { return height_; }
GLenum target() const { return target_; }
GLuint name() const { return name_; }
const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; }
int plane() const { return plane_; }
using DetachFn = std::function<void(GlTextureView&)>;
using DoneWritingFn = std::function<void(const GlTextureView&)>;
private:
friend class GpuBuffer;
friend class GlTextureBuffer;
friend class GpuBufferStorageCvPixelBuffer;
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
DetachFn detach, DoneWritingFn done_writing)
: gl_context_(context),
target_(target),
name_(name),
width_(width),
height_(height),
gpu_buffer_(std::move(gpu_buffer)),
plane_(plane),
detach_(std::move(detach)),
done_writing_(std::move(done_writing)) {}
// TODO: remove this friend declaration.
friend class GlTexture;
void Release();
// TODO: make this non-const.
void DoneWriting() const {
if (done_writing_) done_writing_(*this);
}
GlContext* gl_context_ = nullptr;
GLenum target_ = GL_TEXTURE_2D;
GLuint name_ = 0;
// Note: when scale is not 1, we still give the nominal size of the image.
int width_ = 0;
int height_ = 0;
std::shared_ptr<GpuBuffer> gpu_buffer_; // using shared_ptr temporarily
int plane_ = 0;
DetachFn detach_;
DoneWritingFn done_writing_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_

View File

@ -8,62 +8,7 @@
namespace mediapipe { namespace mediapipe {
void GlTextureView::Release() {
if (detach_) detach_(*this);
detach_ = nullptr;
gl_context_ = nullptr;
gpu_buffer_ = nullptr;
plane_ = 0;
name_ = 0;
width_ = 0;
height_ = 0;
}
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#if TARGET_OS_OSX
typedef CVOpenGLTextureRef CVTextureType;
#else
typedef CVOpenGLESTextureRef CVTextureType;
#endif // TARGET_OS_OSX
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
CVReturn err;
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
#if TARGET_OS_OSX
CVTextureType cv_texture_temp;
err = CVOpenGLTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(),
GetCVPixelBufferRef(), NULL, &cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane,
[cv_texture](
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
#else
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format(), plane, gl_context->GetGlVersion());
CVTextureType cv_texture_temp;
err = CVOpenGLESTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(),
GetCVPixelBufferRef(), NULL, GL_TEXTURE_2D, info.gl_internal_format,
width() / info.downscale, height() / info.downscale, info.gl_format,
info.gl_type, plane, &cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
CVOpenGLESTextureGetName(*cv_texture), width(), height(), *this, plane,
[cv_texture](
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
#endif // TARGET_OS_OSX
}
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) { GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame);
@ -72,187 +17,11 @@ GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
CHECK_OK(maybe_buffer.status()); CHECK_OK(maybe_buffer.status());
return GpuBuffer(std::move(maybe_buffer).value()); return GpuBuffer(std::move(maybe_buffer).value());
} }
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
CHECK(GetCVPixelBufferRef());
return CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef());
}
void GlTextureView::DoneWriting() const {
CHECK(gpu_buffer_);
#if TARGET_IPHONE_SIMULATOR
CVPixelBufferRef pixel_buffer = gpu_buffer_.GetCVPixelBufferRef();
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferLockBaseAddress failed: " << err;
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
uint8_t* pixel_ptr =
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
if (pixel_format == kCVPixelFormatType_32BGRA) {
// TODO: restore previous framebuffer? Move this to helper so we
// can use BindFramebuffer?
glViewport(0, 0, width(), height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, target(),
name(), 0);
size_t contiguous_bytes_per_row = width() * 4;
if (bytes_per_row == contiguous_bytes_per_row) {
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
pixel_ptr);
} else {
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
height());
uint8_t* temp_ptr = contiguous_buffer.data();
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
temp_ptr);
for (int i = 0; i < height(); ++i) {
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
temp_ptr += contiguous_bytes_per_row;
pixel_ptr += bytes_per_row;
}
}
} else {
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
}
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
#endif
}
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
const GlTextureBufferSharedPtr& texture_buffer =
GetGlTextureBufferSharedPtr();
// Insert wait call to sync with the producer.
texture_buffer->WaitOnGpu();
CHECK_EQ(plane, 0);
GlTextureView::DetachFn detach;
if (for_reading) {
detach = [](mediapipe::GlTextureView& texture) {
// Inform the GlTextureBuffer that we have finished accessing its
// contents, and create a consumer sync point.
texture.gpu_buffer().GetGlTextureBufferSharedPtr()->DidRead(
texture.gl_context()->CreateSyncToken());
};
}
return GlTextureView(gl_context.get(), texture_buffer->target(),
texture_buffer->name(), width(), height(), *this, plane,
std::move(detach));
}
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) { GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
auto gl_context = GlContext::GetCurrent(); return GpuBuffer(GlTextureBuffer::Create(image_frame));
CHECK(gl_context);
auto buffer = GlTextureBuffer::Create(image_frame);
// TODO: does this need to set the texture params? We set them again when the
// texture is actually acccessed via GlTexture[View]. Or should they always be
// set on creation?
if (buffer->format() != GpuBufferFormat::kUnknown) {
glBindTexture(GL_TEXTURE_2D, buffer->name());
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
buffer->format(), /*plane=*/0, gl_context->GetGlVersion());
gl_context->SetStandardTextureParams(buffer->target(),
info.gl_internal_format);
glBindTexture(GL_TEXTURE_2D, 0);
}
return GpuBuffer(std::move(buffer));
}
static void ReadTexture(const GlTextureView& view, void* output, size_t size) {
// TODO: check buffer size? We could use glReadnPixels where available
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
// won't overflow the buffer with glReadPixels, we'd also need to check or
// reset several glPixelStore parameters (e.g. what if someone had the
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
CHECK(view.gl_context());
GlTextureInfo info =
GlTextureInfoForGpuBufferFormat(view.gpu_buffer().format(), view.plane(),
view.gl_context()->GetGlVersion());
GLint current_fbo;
glGetIntegerv(GL_FRAMEBUFFER_BINDING, &current_fbo);
CHECK_NE(current_fbo, 0);
GLint color_attachment_name;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment_name);
if (color_attachment_name != view.name()) {
// Save the viewport. Note that we assume that the color attachment is a
// GL_TEXTURE_2D texture.
GLint viewport[4];
glGetIntegerv(GL_VIEWPORT, viewport);
// Set the data from GLTextureView object.
glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
view.name(), 0);
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
// Restore from the saved viewport and color attachment name.
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
color_attachment_name, 0);
} else {
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
}
}
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
auto output = absl::make_unique<ImageFrame>(
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
auto view = GetGlTextureView(0, true);
ReadTexture(view, output->MutablePixelData(), output->PixelDataSize());
return output;
}
void GlTextureView::DoneWriting() const {
CHECK(gpu_buffer_);
// Inform the GlTextureBuffer that we have produced new content, and create
// a producer sync point.
gpu_buffer_.GetGlTextureBufferSharedPtr()->Updated(
gl_context()->CreateSyncToken());
#ifdef __ANDROID__
// On (some?) Android devices, the texture may need to be explicitly
// detached from the current framebuffer.
// TODO: is this necessary even with the unbind in BindFramebuffer?
// It is not clear if this affected other contexts too, but let's keep it
// while in doubt.
GLint type = GL_NONE;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
&type);
if (type == GL_TEXTURE) {
GLint color_attachment = 0;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment);
if (color_attachment == name()) {
glBindFramebuffer(GL_FRAMEBUFFER, 0);
}
}
// Some Android drivers log a GL_INVALID_ENUM error after the first
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
// even though it should be ok to ask for the type and get back GL_NONE.
// Let's just ignore any pending errors here.
GLenum error;
while ((error = glGetError()) != GL_NO_ERROR) {
}
#endif // __ANDROID__
} }
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER

View File

@ -19,7 +19,9 @@
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_texture_view.h"
#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_buffer_storage.h"
#if defined(__APPLE__) #if defined(__APPLE__)
#include <CoreVideo/CoreVideo.h> #include <CoreVideo/CoreVideo.h>
@ -27,6 +29,10 @@
#include "mediapipe/objc/CFHolder.h" #include "mediapipe/objc/CFHolder.h"
#endif // defined(__APPLE__) #endif // defined(__APPLE__)
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#include "mediapipe/gpu/gl_texture_buffer.h" #include "mediapipe/gpu/gl_texture_buffer.h"
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
@ -34,7 +40,6 @@
namespace mediapipe { namespace mediapipe {
class GlContext; class GlContext;
class GlTextureView;
// This class wraps a platform-specific buffer of GPU data. // This class wraps a platform-specific buffer of GPU data.
// An instance of GpuBuffer acts as an opaque reference to the underlying // An instance of GpuBuffer acts as an opaque reference to the underlying
@ -71,9 +76,9 @@ class GpuBuffer {
} }
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
int width() const; int width() const { return current_storage().width(); }
int height() const; int height() const { return current_storage().height(); }
GpuBufferFormat format() const; GpuBufferFormat format() const { return current_storage().format(); }
// Converts to true iff valid. // Converts to true iff valid.
explicit operator bool() const { return operator!=(nullptr); } explicit operator bool() const { return operator!=(nullptr); }
@ -88,8 +93,15 @@ class GpuBuffer {
// Allow assignment from nullptr. // Allow assignment from nullptr.
GpuBuffer& operator=(std::nullptr_t other); GpuBuffer& operator=(std::nullptr_t other);
// TODO: split into read and write, remove const from write. GlTextureView GetGlTextureReadView(int plane) const {
GlTextureView GetGlTextureView(int plane, bool for_reading) const; return current_storage().GetGlTextureReadView(
std::make_shared<GpuBuffer>(*this), plane);
}
GlTextureView GetGlTextureWriteView(int plane) {
return current_storage().GetGlTextureWriteView(
std::make_shared<GpuBuffer>(*this), plane);
}
// Make a GpuBuffer copying the data from an ImageFrame. // Make a GpuBuffer copying the data from an ImageFrame.
static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame); static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame);
@ -99,114 +111,84 @@ class GpuBuffer {
// In order to work correctly across platforms, callers should always treat // In order to work correctly across platforms, callers should always treat
// the returned ImageFrame as if it shares memory with the GpuBuffer, i.e. // the returned ImageFrame as if it shares memory with the GpuBuffer, i.e.
// treat it as immutable if the GpuBuffer must not be modified. // treat it as immutable if the GpuBuffer must not be modified.
std::unique_ptr<ImageFrame> AsImageFrame() const; std::unique_ptr<ImageFrame> AsImageFrame() const {
return current_storage().AsImageFrame();
}
private: private:
class PlaceholderGpuBufferStorage
: public mediapipe::internal::GpuBufferStorage {
public:
int width() const override { return 0; }
int height() const override { return 0; }
virtual GpuBufferFormat format() const override {
return GpuBufferFormat::kUnknown;
}
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) const override {
return {};
}
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) override {
return {};
}
void ViewDoneWriting(const GlTextureView& view) override{};
std::unique_ptr<ImageFrame> AsImageFrame() const override {
return nullptr;
}
};
mediapipe::internal::GpuBufferStorage& no_storage() const {
static PlaceholderGpuBufferStorage placeholder;
return placeholder;
}
const mediapipe::internal::GpuBufferStorage& current_storage() const {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
CFHolder<CVPixelBufferRef> pixel_buffer_; if (pixel_buffer_ != nullptr) return pixel_buffer_;
#else
if (texture_buffer_) return *texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return no_storage();
}
mediapipe::internal::GpuBufferStorage& current_storage() {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
if (pixel_buffer_ != nullptr) return pixel_buffer_;
#else
if (texture_buffer_) return *texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return no_storage();
}
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
GpuBufferStorageCvPixelBuffer pixel_buffer_;
#else #else
GlTextureBufferSharedPtr texture_buffer_; GlTextureBufferSharedPtr texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
}; };
class GlTextureView { inline bool GpuBuffer::operator==(std::nullptr_t other) const {
public: return &current_storage() == &no_storage();
GlTextureView() {} }
~GlTextureView() { Release(); }
// TODO: make this class move-only.
GlContext* gl_context() const { return gl_context_; }
int width() const { return width_; }
int height() const { return height_; }
GLenum target() const { return target_; }
GLuint name() const { return name_; }
const GpuBuffer& gpu_buffer() const { return gpu_buffer_; }
int plane() const { return plane_; }
private:
friend class GpuBuffer;
using DetachFn = std::function<void(GlTextureView&)>;
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
int height, GpuBuffer gpu_buffer, int plane, DetachFn detach)
: gl_context_(context),
target_(target),
name_(name),
width_(width),
height_(height),
gpu_buffer_(std::move(gpu_buffer)),
plane_(plane),
detach_(std::move(detach)) {}
// TODO: remove this friend declaration.
friend class GlTexture;
void Release();
// TODO: make this non-const.
void DoneWriting() const;
GlContext* gl_context_ = nullptr;
GLenum target_ = GL_TEXTURE_2D;
GLuint name_ = 0;
// Note: when scale is not 1, we still give the nominal size of the image.
int width_ = 0;
int height_ = 0;
GpuBuffer gpu_buffer_;
int plane_ = 0;
DetachFn detach_;
};
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
inline int GpuBuffer::width() const {
return static_cast<int>(CVPixelBufferGetWidth(*pixel_buffer_));
}
inline int GpuBuffer::height() const {
return static_cast<int>(CVPixelBufferGetHeight(*pixel_buffer_));
}
inline GpuBufferFormat GpuBuffer::format() const {
return GpuBufferFormatForCVPixelFormat(
CVPixelBufferGetPixelFormatType(*pixel_buffer_));
}
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
return pixel_buffer_ == other;
}
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
return pixel_buffer_ == other.pixel_buffer_; return pixel_buffer_ == other.pixel_buffer_;
}
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
pixel_buffer_.reset(other);
return *this;
}
#else #else
inline int GpuBuffer::width() const { return texture_buffer_->width(); }
inline int GpuBuffer::height() const { return texture_buffer_->height(); }
inline GpuBufferFormat GpuBuffer::format() const {
return texture_buffer_->format();
}
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
return texture_buffer_ == other;
}
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
return texture_buffer_ == other.texture_buffer_; return texture_buffer_ == other.texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
} }
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) { inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
pixel_buffer_.reset(other);
#else
texture_buffer_ = other; texture_buffer_ = other;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return *this; return *this;
} }
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_GPU_GPU_BUFFER_H_ #endif // MEDIAPIPE_GPU_GPU_BUFFER_H_

View File

@ -0,0 +1,41 @@
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
namespace mediapipe {
class GlTextureView;
class GpuBuffer;
} // namespace mediapipe
namespace mediapipe {
namespace internal {
using mediapipe::GlTextureView;
using mediapipe::GpuBuffer;
using mediapipe::GpuBufferFormat;
class GlTextureViewManager {
public:
virtual ~GlTextureViewManager() = default;
virtual GlTextureView GetGlTextureReadView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const = 0;
virtual GlTextureView GetGlTextureWriteView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) = 0;
virtual void ViewDoneWriting(const GlTextureView& view) = 0;
};
class GpuBufferStorage : public GlTextureViewManager {
public:
virtual ~GpuBufferStorage() = default;
virtual int width() const = 0;
virtual int height() const = 0;
virtual GpuBufferFormat format() const = 0;
virtual std::unique_ptr<ImageFrame> AsImageFrame() const = 0;
};
} // namespace internal
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_

View File

@ -0,0 +1,116 @@
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/objc/util.h"
namespace mediapipe {
#if TARGET_OS_OSX
typedef CVOpenGLTextureRef CVTextureType;
#else
typedef CVOpenGLESTextureRef CVTextureType;
#endif // TARGET_OS_OSX
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureReadView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
CVReturn err;
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
#if TARGET_OS_OSX
CVTextureType cv_texture_temp;
err = CVOpenGLTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
&cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane,
[cv_texture](
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
#else
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format(), plane, gl_context->GetGlVersion());
CVTextureType cv_texture_temp;
err = CVOpenGLESTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
GL_TEXTURE_2D, info.gl_internal_format, width() / info.downscale,
height() / info.downscale, info.gl_format, info.gl_type, plane,
&cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
CVOpenGLESTextureGetName(*cv_texture), width(), height(),
std::move(gpu_buffer), plane,
[cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ },
// TODO: make GetGlTextureView for write view non-const, remove cast
// Note: we have to copy *this here because this storage is currently
// stored in GpuBuffer by value, and so the this pointer becomes invalid
// if the GpuBuffer is moved/copied. TODO: fix this.
[me = *this](const mediapipe::GlTextureView& view) {
const_cast<GpuBufferStorageCvPixelBuffer*>(&me)->ViewDoneWriting(view);
});
#endif // TARGET_OS_OSX
}
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureWriteView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
// For this storage there is currently no difference between read and write
// views, so we delegate to the read method.
return GetGlTextureReadView(std::move(gpu_buffer), plane);
}
void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) {
#if TARGET_IPHONE_SIMULATOR
CVPixelBufferRef pixel_buffer = **this;
CHECK(pixel_buffer);
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferLockBaseAddress failed: " << err;
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
uint8_t* pixel_ptr =
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
if (pixel_format == kCVPixelFormatType_32BGRA) {
// TODO: restore previous framebuffer? Move this to helper so we
// can use BindFramebuffer?
glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
view.name(), 0);
size_t contiguous_bytes_per_row = view.width() * 4;
if (bytes_per_row == contiguous_bytes_per_row) {
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
pixel_ptr);
} else {
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
view.height());
uint8_t* temp_ptr = contiguous_buffer.data();
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
temp_ptr);
for (int i = 0; i < view.height(); ++i) {
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
temp_ptr += contiguous_bytes_per_row;
pixel_ptr += bytes_per_row;
}
}
} else {
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
}
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
#endif
}
std::unique_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::AsImageFrame()
const {
return CreateImageFrameForCVPixelBuffer(**this);
}
} // namespace mediapipe

View File

@ -0,0 +1,41 @@
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
#include <CoreVideo/CoreVideo.h>
#include "mediapipe/gpu/gl_texture_view.h"
#include "mediapipe/gpu/gpu_buffer_storage.h"
#include "mediapipe/objc/CFHolder.h"
namespace mediapipe {
class GlContext;
class GpuBufferStorageCvPixelBuffer
: public mediapipe::internal::GpuBufferStorage,
public CFHolder<CVPixelBufferRef> {
public:
using CFHolder<CVPixelBufferRef>::CFHolder;
GpuBufferStorageCvPixelBuffer(const CFHolder<CVPixelBufferRef>& other)
: CFHolder(other) {}
GpuBufferStorageCvPixelBuffer(CFHolder<CVPixelBufferRef>&& other)
: CFHolder(std::move(other)) {}
int width() const { return static_cast<int>(CVPixelBufferGetWidth(**this)); }
int height() const {
return static_cast<int>(CVPixelBufferGetHeight(**this));
}
virtual GpuBufferFormat format() const {
return GpuBufferFormatForCVPixelFormat(
CVPixelBufferGetPixelFormatType(**this));
}
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) const override;
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) override;
std::unique_ptr<ImageFrame> AsImageFrame() const override;
void ViewDoneWriting(const GlTextureView& view) override;
};
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_

View File

@ -64,6 +64,13 @@ public class AppTextureFrame implements TextureFrame {
return timestamp; return timestamp;
} }
/** Returns true if a call to waitUntilReleased() would block waiting for release. */
public boolean isNotYetReleased() {
synchronized (this) {
return inUse && releaseSyncToken == null;
}
}
/** /**
* Waits until the consumer is done with the texture. * Waits until the consumer is done with the texture.
* *

View File

@ -26,7 +26,8 @@ android_library(
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_gpu_image.binarypb", "//mediapipe/modules/hand_landmark:hand_landmark_tracking_gpu_image.binarypb",
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_image.binarypb", "//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_image.binarypb",
"//mediapipe/modules/hand_landmark:handedness.txt", "//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite", "//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/palm_detection:palm_detection.tflite", "//mediapipe/modules/palm_detection:palm_detection.tflite",
], ],
assets_dir = "", assets_dir = "",

View File

@ -78,6 +78,7 @@ public class Hands extends ImageSolutionBase {
Connection.create(HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP), Connection.create(HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP),
Connection.create(HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP)); Connection.create(HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP));
private static final String MODEL_COMPLEXITY = "model_complexity";
private static final String NUM_HANDS = "num_hands"; private static final String NUM_HANDS = "num_hands";
private static final String USE_PREV_LANDMARKS = "use_prev_landmarks"; private static final String USE_PREV_LANDMARKS = "use_prev_landmarks";
private static final String GPU_GRAPH_NAME = "hand_landmark_tracking_gpu_image.binarypb"; private static final String GPU_GRAPH_NAME = "hand_landmark_tracking_gpu_image.binarypb";
@ -131,6 +132,7 @@ public class Hands extends ImageSolutionBase {
initialize(context, solutionInfo, outputHandler); initialize(context, solutionInfo, outputHandler);
Map<String, Packet> inputSidePackets = new HashMap<>(); Map<String, Packet> inputSidePackets = new HashMap<>();
inputSidePackets.put(NUM_HANDS, packetCreator.createInt32(options.maxNumHands())); inputSidePackets.put(NUM_HANDS, packetCreator.createInt32(options.maxNumHands()));
inputSidePackets.put(MODEL_COMPLEXITY, packetCreator.createInt32(options.modelComplexity()));
inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode())); inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode()));
start(inputSidePackets); start(inputSidePackets);
} }

View File

@ -26,6 +26,10 @@ import com.google.auto.value.AutoValue;
* <p>maxNumHands: Maximum number of hands to detect. See details in * <p>maxNumHands: Maximum number of hands to detect. See details in
* https://solutions.mediapipe.dev/hands#max_num_hands. * https://solutions.mediapipe.dev/hands#max_num_hands.
* *
* <p>modelComplexity: Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
* inference latency generally go up with the model complexity. See details in
* https://solutions.mediapipe.dev/hands#model_complexity.
*
* <p>minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for hand detection to be * <p>minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for hand detection to be
* considered successful. See details in * considered successful. See details in
* https://solutions.mediapipe.dev/hands#min_detection_confidence. * https://solutions.mediapipe.dev/hands#min_detection_confidence.
@ -43,6 +47,8 @@ public abstract class HandsOptions {
public abstract int maxNumHands(); public abstract int maxNumHands();
public abstract int modelComplexity();
public abstract float minDetectionConfidence(); public abstract float minDetectionConfidence();
public abstract float minTrackingConfidence(); public abstract float minTrackingConfidence();
@ -59,6 +65,7 @@ public abstract class HandsOptions {
public Builder withDefaultValues() { public Builder withDefaultValues() {
return setStaticImageMode(false) return setStaticImageMode(false)
.setMaxNumHands(2) .setMaxNumHands(2)
.setModelComplexity(1)
.setMinDetectionConfidence(0.5f) .setMinDetectionConfidence(0.5f)
.setMinTrackingConfidence(0.5f) .setMinTrackingConfidence(0.5f)
.setRunOnGpu(true); .setRunOnGpu(true);
@ -68,6 +75,8 @@ public abstract class HandsOptions {
public abstract Builder setMaxNumHands(int value); public abstract Builder setMaxNumHands(int value);
public abstract Builder setModelComplexity(int value);
public abstract Builder setMinDetectionConfidence(float value); public abstract Builder setMinDetectionConfidence(float value);
public abstract Builder setMinTrackingConfidence(float value); public abstract Builder setMinTrackingConfidence(float value);

View File

@ -22,16 +22,30 @@ licenses(["notice"])
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
exports_files([ exports_files([
"hand_landmark.tflite", "hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"hand_landmark_sparse.tflite", "hand_landmark_sparse.tflite",
"handedness.txt", "handedness.txt",
]) ])
mediapipe_simple_subgraph(
name = "hand_landmark_model_loader",
graph = "hand_landmark_model_loader.pbtxt",
register_as = "HandLandmarkModelLoader",
deps = [
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/tflite:tflite_model_calculator",
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/framework/tool:switch_container",
],
)
mediapipe_simple_subgraph( mediapipe_simple_subgraph(
name = "hand_landmark_cpu", name = "hand_landmark_cpu",
graph = "hand_landmark_cpu.pbtxt", graph = "hand_landmark_cpu.pbtxt",
register_as = "HandLandmarkCpu", register_as = "HandLandmarkCpu",
deps = [ deps = [
":hand_landmark_model_loader",
"//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator",
@ -50,6 +64,7 @@ mediapipe_simple_subgraph(
graph = "hand_landmark_gpu.pbtxt", graph = "hand_landmark_gpu.pbtxt",
register_as = "HandLandmarkGpu", register_as = "HandLandmarkGpu",
deps = [ deps = [
":hand_landmark_model_loader",
"//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator",

View File

@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
# (NormalizedRect) # (NormalizedRect)
input_stream: "ROI:hand_rect" input_stream: "ROI:hand_rect"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# functions as set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList) # 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
# NOTE: if a hand is not present within the given ROI, for this particular # NOTE: if a hand is not present within the given ROI, for this particular
# timestamp there will not be an output packet in the LANDMARKS stream. However, # timestamp there will not be an output packet in the LANDMARKS stream. However,
@ -40,16 +45,23 @@ node {
} }
} }
# Loads the hand landmark TF Lite model.
node {
calculator: "HandLandmarkModelLoader"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
output_side_packet: "MODEL:model"
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and # vector of tensors representing, for instance, detection boxes/keypoints and
# scores. # scores.
node { node {
calculator: "InferenceCalculator" calculator: "InferenceCalculator"
input_side_packet: "MODEL:model"
input_stream: "TENSORS:input_tensor" input_stream: "TENSORS:input_tensor"
output_stream: "TENSORS:output_tensors" output_stream: "TENSORS:output_tensors"
options: { options: {
[mediapipe.InferenceCalculatorOptions.ext] { [mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
delegate { delegate {
xnnpack {} xnnpack {}
} }

Binary file not shown.

View File

@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
# (NormalizedRect) # (NormalizedRect)
input_stream: "ROI:hand_rect" input_stream: "ROI:hand_rect"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# functions as set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList) # 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
# NOTE: if a hand is not present within the given ROI, for this particular # NOTE: if a hand is not present within the given ROI, for this particular
# timestamp there will not be an output packet in the LANDMARKS stream. However, # timestamp there will not be an output packet in the LANDMARKS stream. However,
@ -41,18 +46,21 @@ node {
} }
} }
# Loads the hand landmark TF Lite model.
node {
calculator: "HandLandmarkModelLoader"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
output_side_packet: "MODEL:model"
}
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and # vector of tensors representing, for instance, detection boxes/keypoints and
# scores. # scores.
node { node {
calculator: "InferenceCalculator" calculator: "InferenceCalculator"
input_side_packet: "MODEL:model"
input_stream: "TENSORS:input_tensor" input_stream: "TENSORS:input_tensor"
output_stream: "TENSORS:output_tensors" output_stream: "TENSORS:output_tensors"
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
}
}
} }
# Splits a vector of tensors to multiple vectors according to the ranges # Splits a vector of tensors to multiple vectors according to the ranges

Binary file not shown.

View File

@ -0,0 +1,63 @@
# MediaPipe graph to load a selected hand landmark TF Lite model.
type: "HandLandmarkModelLoader"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# functions as set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# TF Lite model represented as a FlatBuffer.
# (std::unique_ptr<tflite::FlatBufferModel, std::function<void(tflite::FlatBufferModel*)>>)
output_side_packet: "MODEL:model"
# Determines path to the desired pose landmark model file.
node {
calculator: "SwitchContainer"
input_side_packet: "SELECT:model_complexity"
output_side_packet: "PACKET:model_path"
options: {
[mediapipe.SwitchContainerOptions.ext] {
select: 1
contained_node: {
calculator: "ConstantSidePacketCalculator"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet {
string_value: "mediapipe/modules/hand_landmark/hand_landmark_lite.tflite"
}
}
}
}
contained_node: {
calculator: "ConstantSidePacketCalculator"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet {
string_value: "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
}
}
}
}
}
}
}
# Loads the file in the specified path into a blob.
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
options: {
[mediapipe.LocalFileContentsCalculatorOptions.ext]: {
text_mode: false
}
}
}
# Converts the input blob into a TF Lite model.
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_BLOB:model_blob"
output_side_packet: "MODEL:model"
}

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int) # Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands" input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# functions as set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize # Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool) # landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -177,6 +182,7 @@ node {
# Detect hand landmarks for the specific hand rect. # Detect hand landmarks for the specific hand rect.
node { node {
calculator: "HandLandmarkCpu" calculator: "HandLandmarkCpu"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_stream: "IMAGE:image_for_landmarks" input_stream: "IMAGE:image_for_landmarks"
input_stream: "ROI:single_hand_rect" input_stream: "ROI:single_hand_rect"
output_stream: "LANDMARKS:single_hand_landmarks" output_stream: "LANDMARKS:single_hand_landmarks"

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int) # Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands" input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# functions as set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize # Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool) # landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -85,6 +90,7 @@ node {
calculator: "HandLandmarkTrackingCpu" calculator: "HandLandmarkTrackingCpu"
input_stream: "IMAGE:image_frame" input_stream: "IMAGE:image_frame"
input_side_packet: "NUM_HANDS:num_hands" input_side_packet: "NUM_HANDS:num_hands"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
output_stream: "LANDMARKS:multi_hand_landmarks" output_stream: "LANDMARKS:multi_hand_landmarks"
output_stream: "HANDEDNESS:multi_handedness" output_stream: "HANDEDNESS:multi_handedness"

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int) # Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands" input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# functions as set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize # Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool) # landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -178,6 +183,7 @@ node {
# Detect hand landmarks for the specific hand rect. # Detect hand landmarks for the specific hand rect.
node { node {
calculator: "HandLandmarkGpu" calculator: "HandLandmarkGpu"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_stream: "IMAGE:image_for_landmarks" input_stream: "IMAGE:image_for_landmarks"
input_stream: "ROI:single_hand_rect" input_stream: "ROI:single_hand_rect"
output_stream: "LANDMARKS:single_hand_landmarks" output_stream: "LANDMARKS:single_hand_landmarks"

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int) # Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands" input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# functions as set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize # Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool) # landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -85,6 +90,7 @@ node {
calculator: "HandLandmarkTrackingGpu" calculator: "HandLandmarkTrackingGpu"
input_stream: "IMAGE:gpu_buffer" input_stream: "IMAGE:gpu_buffer"
input_side_packet: "NUM_HANDS:num_hands" input_side_packet: "NUM_HANDS:num_hands"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks" input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
output_stream: "LANDMARKS:multi_hand_landmarks" output_stream: "LANDMARKS:multi_hand_landmarks"
output_stream: "HANDEDNESS:multi_handedness" output_stream: "HANDEDNESS:multi_handedness"

View File

@ -7,8 +7,8 @@
# - "face_landmark.tflite" is available at # - "face_landmark.tflite" is available at
# "mediapipe/modules/face_landmark/face_landmark.tflite" # "mediapipe/modules/face_landmark/face_landmark.tflite"
# #
# - "hand_landmark.tflite" is available at # - "hand_landmark_full.tflite" is available at
# "mediapipe/modules/hand_landmark/hand_landmark.tflite" # "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
# #
# - "hand_recrop.tflite" is available at # - "hand_recrop.tflite" is available at
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite" # "mediapipe/modules/holistic_landmark/hand_recrop.tflite"

View File

@ -7,8 +7,8 @@
# - "face_landmark.tflite" is available at # - "face_landmark.tflite" is available at
# "mediapipe/modules/face_landmark/face_landmark.tflite" # "mediapipe/modules/face_landmark/face_landmark.tflite"
# #
# - "hand_landmark.tflite" is available at # - "hand_landmark_full.tflite" is available at
# "mediapipe/modules/hand_landmark/hand_landmark.tflite" # "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
# #
# - "hand_recrop.tflite" is available at # - "hand_recrop.tflite" is available at
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite" # "mediapipe/modules/holistic_landmark/hand_recrop.tflite"

View File

@ -237,7 +237,7 @@ absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) {
} }
bool FilterDetectionCalculator::IsValidLabel(const std::string& label) { bool FilterDetectionCalculator::IsValidLabel(const std::string& label) {
bool match = !limit_labels_ || ContainsKey(allowed_labels_, label); bool match = !limit_labels_ || allowed_labels_.contains(label);
if (!match) { if (!match) {
// If no exact match is found, check for regular expression // If no exact match is found, check for regular expression
// comparions in the allowed_labels. // comparions in the allowed_labels.

View File

@ -21,6 +21,9 @@ cc_library(
features = ["-parse_headers"], features = ["-parse_headers"],
linkopts = [ linkopts = [
"-framework Accelerate", "-framework Accelerate",
"-framework CoreFoundation",
"-framework CoreGraphics",
"-framework CoreVideo",
], ],
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [ deps = [

View File

@ -17,6 +17,7 @@
#import <Accelerate/Accelerate.h> #import <Accelerate/Accelerate.h>
#import <CoreFoundation/CoreFoundation.h> #import <CoreFoundation/CoreFoundation.h>
#import <CoreGraphics/CoreGraphics.h>
#import <CoreVideo/CoreVideo.h> #import <CoreVideo/CoreVideo.h>
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"

View File

@ -89,6 +89,7 @@ class Hands(SolutionBase):
def __init__(self, def __init__(self,
static_image_mode=False, static_image_mode=False,
max_num_hands=2, max_num_hands=2,
model_complexity=1,
min_detection_confidence=0.5, min_detection_confidence=0.5,
min_tracking_confidence=0.5): min_tracking_confidence=0.5):
"""Initializes a MediaPipe Hand object. """Initializes a MediaPipe Hand object.
@ -99,6 +100,10 @@ class Hands(SolutionBase):
https://solutions.mediapipe.dev/hands#static_image_mode. https://solutions.mediapipe.dev/hands#static_image_mode.
max_num_hands: Maximum number of hands to detect. See details in max_num_hands: Maximum number of hands to detect. See details in
https://solutions.mediapipe.dev/hands#max_num_hands. https://solutions.mediapipe.dev/hands#max_num_hands.
model_complexity: Complexity of the hand landmark model: 0 or 1.
Landmark accuracy as well as inference latency generally go up with the
model complexity. See details in
https://solutions.mediapipe.dev/hands#model_complexity.
min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for hand min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for hand
detection to be considered successful. See details in detection to be considered successful. See details in
https://solutions.mediapipe.dev/hands#min_detection_confidence. https://solutions.mediapipe.dev/hands#min_detection_confidence.
@ -109,6 +114,7 @@ class Hands(SolutionBase):
super().__init__( super().__init__(
binary_graph_path=_BINARYPB_FILE_PATH, binary_graph_path=_BINARYPB_FILE_PATH,
side_inputs={ side_inputs={
'model_complexity': model_complexity,
'num_hands': max_num_hands, 'num_hands': max_num_hands,
'use_prev_landmarks': not static_image_mode, 'use_prev_landmarks': not static_image_mode,
}, },

View File

@ -32,7 +32,8 @@ from mediapipe.python.solutions import hands as mp_hands
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
DIFF_THRESHOLD = 20 # pixels LITE_MODEL_DIFF_THRESHOLD = 25 # pixels
FULL_MODEL_DIFF_THRESHOLD = 20 # pixels
EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286], EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
[289, 237], [322, 203], [219, 216], [289, 237], [322, 203], [219, 216],
[238, 138], [249, 90], [253, 51], [238, 138], [249, 90], [253, 51],
@ -40,7 +41,7 @@ EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
[185, 19], [138, 208], [131, 127], [185, 19], [138, 208], [131, 127],
[124, 77], [117, 36], [106, 222], [124, 77], [117, 36], [106, 222],
[92, 159], [79, 124], [68, 93]], [92, 159], [79, 124], [68, 93]],
[[580, 36], [504, 50], [459, 94], [[580, 34], [504, 50], [459, 94],
[429, 146], [397, 182], [507, 167], [429, 146], [397, 182], [507, 167],
[479, 245], [469, 292], [464, 330], [479, 245], [469, 292], [464, 330],
[545, 180], [534, 265], [533, 319], [545, 180], [534, 265], [533, 319],
@ -75,14 +76,18 @@ class HandsTest(parameterized.TestCase):
self.assertIsNone(results.multi_hand_landmarks) self.assertIsNone(results.multi_hand_landmarks)
self.assertIsNone(results.multi_handedness) self.assertIsNone(results.multi_handedness)
@parameterized.named_parameters(('static_image_mode', True, 1), @parameterized.named_parameters(
('video_mode', False, 5)) ('static_image_mode_with_lite_model', True, 0, 5),
def test_multi_hands(self, static_image_mode, num_frames): ('video_mode_with_lite_model', False, 0, 10),
('static_image_mode_with_full_model', True, 1, 5),
('video_mode_with_full_model', False, 1, 10))
def test_multi_hands(self, static_image_mode, model_complexity, num_frames):
image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg') image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg')
image = cv2.imread(image_path) image = cv2.imread(image_path)
with mp_hands.Hands( with mp_hands.Hands(
static_image_mode=static_image_mode, static_image_mode=static_image_mode,
max_num_hands=2, max_num_hands=2,
model_complexity=model_complexity,
min_detection_confidence=0.5) as hands: min_detection_confidence=0.5) as hands:
for idx in range(num_frames): for idx in range(num_frames):
results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
@ -104,7 +109,8 @@ class HandsTest(parameterized.TestCase):
prediction_error = np.abs( prediction_error = np.abs(
np.asarray(multi_hand_coordinates) - np.asarray(multi_hand_coordinates) -
np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION)) np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION))
npt.assert_array_less(prediction_error, DIFF_THRESHOLD) diff_threshold = LITE_MODEL_DIFF_THRESHOLD if model_complexity == 0 else FULL_MODEL_DIFF_THRESHOLD
npt.assert_array_less(prediction_error, diff_threshold)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "absl/status/status.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -81,6 +82,32 @@ ObjectDef GetSSBOObjectDef(int channels) {
return gpu_object_def; return gpu_object_def;
} }
#ifdef __ANDROID__
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
cl::InferenceOptions result{};
result.priority1 = options.priority1;
result.priority2 = options.priority2;
result.priority3 = options.priority3;
result.usage = options.usage;
return result;
}
absl::Status VerifyShapes(const std::vector<TensorObjectDef>& actual,
const std::vector<BHWC>& expected) {
RET_CHECK_EQ(actual.size(), expected.size());
const int size = actual.size();
for (int i = 0; i < size; ++i) {
const auto& dims = actual[i].dimensions;
const BHWC& bhwc = expected[i];
RET_CHECK(dims.b == bhwc.b && dims.h == bhwc.h && dims.w == bhwc.w &&
dims.c == bhwc.c);
}
return absl::OkStatus();
}
#endif // __ANDROID__
} // namespace } // namespace
absl::Status TFLiteGPURunner::InitializeWithModel( absl::Status TFLiteGPURunner::InitializeWithModel(
@ -139,16 +166,16 @@ absl::Status TFLiteGPURunner::Build() {
// try to build OpenCL first. If something goes wrong, fall back to OpenGL. // try to build OpenCL first. If something goes wrong, fall back to OpenGL.
absl::Status status = InitializeOpenCL(&builder); absl::Status status = InitializeOpenCL(&builder);
if (status.ok()) { if (status.ok()) {
LOG(INFO) << "OpenCL backend is used."; VLOG(2) << "OpenCL backend is used.";
} else { } else {
LOG(ERROR) << "Falling back to OpenGL: " << status.message(); VLOG(2) << "Falling back to OpenGL: " << status.message();
MP_RETURN_IF_ERROR(InitializeOpenGL(&builder)); MP_RETURN_IF_ERROR(InitializeOpenGL(&builder));
} }
} }
// Both graphs are not needed anymore. Make sure they are deleted. // GL graph not needed anymore, CL graph maybe needed for serialized model
// calculation.
graph_gl_.reset(nullptr); graph_gl_.reset(nullptr);
graph_cl_.reset(nullptr);
// 2. Describe output/input objects for created builder. // 2. Describe output/input objects for created builder.
for (int flow_index = 0; flow_index < input_shapes_.size(); ++flow_index) { for (int flow_index = 0; flow_index < input_shapes_.size(); ++flow_index) {
@ -204,18 +231,57 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
env_options.serialized_binary_cache = serialized_binary_cache_; env_options.serialized_binary_cache = serialized_binary_cache_;
} }
cl::InferenceEnvironmentProperties properties; cl::InferenceEnvironmentProperties properties;
cl::InferenceOptions cl_options;
cl_options.priority1 = options_.priority1;
cl_options.priority2 = options_.priority2;
cl_options.priority3 = options_.priority3;
cl_options.usage = options_.usage;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties)); cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
// Try to initialize from serialized model first.
if (!serialized_model_.empty()) {
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
if (init_status.ok()) {
serialized_model_used_ = true;
return absl::OkStatus();
}
VLOG(2) << "Failed to init from serialized model: [" << init_status
<< "]. Trying to init from scratch.";
}
// Initialize from scratch.
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
GraphFloat32 graph_cl;
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder( MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
cl_options, std::move(*graph_cl_), builder)); cl_options, std::move(graph_cl), builder));
#endif
#endif // __ANDROID__
return absl::OkStatus(); return absl::OkStatus();
} }
#ifdef __ANDROID__
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
std::unique_ptr<InferenceBuilder>* builder) {
MP_RETURN_IF_ERROR(
cl_environment_->NewInferenceBuilder(serialized_model_, builder));
MP_RETURN_IF_ERROR(VerifyShapes(builder->get()->inputs(), input_shapes_));
return VerifyShapes(builder->get()->outputs(), output_shapes_);
}
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
RET_CHECK(runner_) << "Runner is in invalid state.";
if (serialized_model_used_) {
return serialized_model_;
}
RET_CHECK(graph_cl_) << "CL graph is not initialized.";
GraphFloat32 graph_cl;
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
std::vector<uint8_t> serialized_model;
MP_RETURN_IF_ERROR(cl_environment_->BuildSerializedModel(
cl_options, std::move(graph_cl), &serialized_model));
return serialized_model;
}
#endif // __ANDROID__
} // namespace gpu } // namespace gpu
} // namespace tflite } // namespace tflite

View File

@ -19,6 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -29,7 +30,7 @@
#ifdef __ANDROID__ #ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h" #include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif #endif // __ANDROID__
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
@ -90,11 +91,22 @@ class TFLiteGPURunner {
std::vector<uint8_t> GetSerializedBinaryCache() { std::vector<uint8_t> GetSerializedBinaryCache() {
return cl_environment_->GetSerializedBinaryCache(); return cl_environment_->GetSerializedBinaryCache();
} }
#endif
void SetSerializedModel(std::vector<uint8_t>&& serialized_model) {
serialized_model_ = std::move(serialized_model);
serialized_model_used_ = false;
}
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
#endif // __ANDROID__
private: private:
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder); absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder); absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
#ifdef __ANDROID__
absl::Status InitializeOpenCLFromSerializedModel(
std::unique_ptr<InferenceBuilder>* builder);
#endif // __ANDROID__
InferenceOptions options_; InferenceOptions options_;
std::unique_ptr<gl::InferenceEnvironment> gl_environment_; std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
@ -103,9 +115,12 @@ class TFLiteGPURunner {
std::unique_ptr<cl::InferenceEnvironment> cl_environment_; std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
std::vector<uint8_t> serialized_binary_cache_; std::vector<uint8_t> serialized_binary_cache_;
#endif std::vector<uint8_t> serialized_model_;
bool serialized_model_used_ = false;
#endif // __ANDROID__
// graph_ is maintained temporarily and becomes invalid after runner_ is ready // graph_gl_ is maintained temporarily and becomes invalid after runner_ is
// ready
std::unique_ptr<GraphFloat32> graph_gl_; std::unique_ptr<GraphFloat32> graph_gl_;
std::unique_ptr<GraphFloat32> graph_cl_; std::unique_ptr<GraphFloat32> graph_cl_;
std::unique_ptr<InferenceRunner> runner_; std::unique_ptr<InferenceRunner> runner_;