Project import generated by Copybara.

GitOrigin-RevId: bbbbcb4f5174dea33525729ede47c770069157cd
MediaPipe Team 2021-10-18 12:39:29 -07:00 committed by chuoling
parent 33d683c671
commit 1faeaae7e5
75 changed files with 1944 additions and 560 deletions

View File

@ -120,7 +120,7 @@ just 86.22%.
### Hand Landmark Model
After the palm detection over the whole image our subsequent hand landmark
[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite)
[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite)
performs precise keypoint localization of 21 3D hand-knuckle coordinates inside
the detected hand regions via regression, that is direct coordinate prediction.
The model learns a consistent internal hand pose representation and is robust
@ -163,6 +163,11 @@ unrelated, images. Defaults to `false`.
Maximum number of hands to detect. Defaults to `2`.
#### model_complexity
Complexity of the hand landmark model: `0` or `1`. Landmark accuracy as well as
inference latency generally go up with the model complexity. Defaults to `1`.
#### min_detection_confidence
Minimum confidence value (`[0.0, 1.0]`) from the hand detection model for the
@ -212,6 +217,7 @@ Supported configuration options:
* [static_image_mode](#static_image_mode)
* [max_num_hands](#max_num_hands)
* [model_complexity](#model_complexity)
* [min_detection_confidence](#min_detection_confidence)
* [min_tracking_confidence](#min_tracking_confidence)
@ -260,6 +266,7 @@ with mp_hands.Hands(
# For webcam input:
cap = cv2.VideoCapture(0)
with mp_hands.Hands(
model_complexity=0,
min_detection_confidence=0.5,
min_tracking_confidence=0.5) as hands:
while cap.isOpened():
@ -302,6 +309,7 @@ and a [fun application], and the following usage example.
Supported configuration options:
* [maxNumHands](#max_num_hands)
* [modelComplexity](#model_complexity)
* [minDetectionConfidence](#min_detection_confidence)
* [minTrackingConfidence](#min_tracking_confidence)
@ -351,6 +359,7 @@ const hands = new Hands({locateFile: (file) => {
}});
hands.setOptions({
maxNumHands: 2,
modelComplexity: 1,
minDetectionConfidence: 0.5,
minTrackingConfidence: 0.5
});

View File

@ -58,10 +58,12 @@ one over the other.
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite),
[TF.js model](https://tfhub.dev/mediapipe/handdetector/1)
* Hand landmark model:
[TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite),
[TFLite model (lite)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_lite.tflite),
[TFLite model (full)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_full.tflite),
[TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite),
[TF.js model](https://tfhub.dev/mediapipe/handskeleton/1)
* [Model card](https://mediapipe.page.link/handmc), [Model card (sparse)](https://mediapipe.page.link/handmc-sparse)
* [Model card](https://mediapipe.page.link/handmc),
[Model card (sparse)](https://mediapipe.page.link/handmc-sparse)
### [Pose](https://google.github.io/mediapipe/solutions/pose)

View File

@ -125,7 +125,7 @@ hip midpoints.
:----------------------------------------------------------------------------------------------------: |
*Fig 3. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* |
### Pose Landmark Model (BlazePose GHUM 3D)
### Pose Landmark Model (BlazePose [GHUM](https://github.com/google-research/google-research/tree/master/ghum) 3D)
The landmark model in MediaPipe Pose predicts the location of 33 pose landmarks
(see figure below).
@ -486,6 +486,7 @@ on how to build MediaPipe examples.
[BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204)
([presentation](https://youtu.be/YPpUOTRn5tA))
* [Models and model cards](./models.md#pose)
* [GHUM & GHUML: Generative 3D Human Shape and Articulated Pose Models](https://github.com/google-research/google-research/tree/master/ghum)
* [Web demo](https://code.mediapipe.dev/codepen/pose)
* [Python Colab](https://mediapipe.page.link/pose_py_colab)

View File

@ -531,9 +531,13 @@ cc_test(
":split_vector_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -47,4 +47,8 @@ typedef BeginLoopCalculator<std::vector<std::vector<Matrix>>>
BeginLoopMatrixVectorCalculator;
REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator);
// A calculator to process std::vector<uint64_t>.
typedef BeginLoopCalculator<std::vector<uint64_t>> BeginLoopUint64tCalculator;
REGISTER_CALCULATOR(BeginLoopUint64tCalculator);
} // namespace mediapipe
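With this registration, a graph can iterate over a `std::vector<uint64_t>` packet element by element. A minimal sketch (stream names illustrative), assuming the standard BeginLoop/EndLoop stream tags `ITERABLE`, `ITEM`, and `BATCH_END`:

CalculatorGraphConfig config =
    mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
      input_stream: "ids"
      node {
        calculator: "BeginLoopUint64tCalculator"
        input_stream: "ITERABLE:ids"  # one std::vector<uint64_t> per tick
        output_stream: "ITEM:id"      # one uint64_t packet per element
        output_stream: "BATCH_END:ids_end"
      }
    )pb");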

View File

@ -14,7 +14,11 @@
#include <memory>
#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "mediapipe/calculators/core/split_vector_calculator.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
@ -301,4 +305,99 @@ TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) {
}
} // namespace
class PassThroughAndTsBoundUpdateNode : public mediapipe::api2::Node {
public:
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
static constexpr mediapipe::api2::Output<int> kOutValue{"VALUE"};
static constexpr mediapipe::api2::Output<int> kOutTsBoundUpdate{
"TS_BOUND_UPDATE"};
MEDIAPIPE_NODE_CONTRACT(kInValue, kOutValue, kOutTsBoundUpdate);
absl::Status Process(CalculatorContext* cc) override {
kOutValue(cc).Send(kInValue(cc));
kOutTsBoundUpdate(cc).SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream());
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(PassThroughAndTsBoundUpdateNode);
class ToOptionalNode : public mediapipe::api2::Node {
public:
static constexpr mediapipe::api2::Input<int> kTick{"TICK"};
static constexpr mediapipe::api2::Input<int> kInValue{"VALUE"};
static constexpr mediapipe::api2::Output<absl::optional<int>> kOutValue{
"OUTPUT"};
MEDIAPIPE_NODE_CONTRACT(kTick, kInValue, kOutValue);
absl::Status Process(CalculatorContext* cc) override {
if (kInValue(cc).IsEmpty()) {
kOutValue(cc).Send(absl::nullopt);
} else {
kOutValue(cc).Send({kInValue(cc).Get()});
}
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(ToOptionalNode);
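// Taken together, these helpers let the test below observe
// timestamp-bound-only updates: PassThroughAndTsBoundUpdateNode emits a
// bound update alongside every packet it forwards, and ToOptionalNode
// materializes "no packet at this tick" as absl::nullopt so the absence
// can be asserted on.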
namespace {
TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_stream: "select"
node {
calculator: "PassThroughAndTsBoundUpdateNode"
input_stream: "VALUE:select"
output_stream: "VALUE:select_ps"
output_stream: "TS_BOUND_UPDATE:ts_bound_update"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:select_ps"
input_stream: "INPUT:1:ts_bound_update"
input_stream: "SELECT:select"
output_stream: "OUTPUT:select_or_ts_bound_update"
}
node {
calculator: "ToOptionalNode"
input_stream: "TICK:select"
input_stream: "VALUE:select_or_ts_bound_update"
output_stream: "OUTPUT:output"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
auto send_value_fn = [&](int value, Timestamp ts) -> absl::Status {
MP_RETURN_IF_ERROR(
graph.AddPacketToInputStream("select", MakePacket<int>(value).At(ts)));
return graph.WaitUntilIdle();
};
MP_ASSERT_OK(send_value_fn(0, Timestamp(1)));
ASSERT_EQ(output_packets.size(), 1);
EXPECT_EQ(output_packets[0].Get<absl::optional<int>>(), 0);
MP_ASSERT_OK(send_value_fn(1, Timestamp(2)));
ASSERT_EQ(output_packets.size(), 2);
EXPECT_EQ(output_packets[1].Get<absl::optional<int>>(), absl::nullopt);
MP_ASSERT_OK(send_value_fn(0, Timestamp(3)));
ASSERT_EQ(output_packets.size(), 3);
EXPECT_EQ(output_packets[2].Get<absl::optional<int>>(), 0);
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

View File

@ -34,7 +34,6 @@ option java_outer_classname = "InferenceCalculatorProto";
// }
// }
// }
//
message InferenceCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional InferenceCalculatorOptions ext = 336783863;
@ -69,8 +68,30 @@ message InferenceCalculatorOptions {
// Load pre-compiled serialized binary cache to accelerate the init process.
// Only available for the OpenCL delegate on Android.
// Kernel caching will only be enabled if this path is set.
//
// NOTE: binary cache usage may be skipped if a valid serialized model,
// specified by "serialized_model_dir", exists.
//
// TODO: update to cached_kernel_dir
optional string cached_kernel_path = 2;
// A directory to load a pre-compiled serialized model from and save it to,
// used to accelerate the init process.
//
// NOTE: available for the OpenCL delegate on Android only when
// "use_advanced_gpu_api" is set to true and "model_token" is set
// properly.
//
// NOTE: the serialized model takes precedence over the binary cache
// specified by "cached_kernel_path", which can still be used if the
// serialized model is invalid or missing.
optional string serialized_model_dir = 7;
// Unique token identifying the model. Used in conjunction with
// "serialized_model_dir". It is the caller's responsibility to ensure
// there is no clash of the tokens.
optional string model_token = 8;
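// For illustration only (a sketch; the directory and token below are
// hypothetical), a node enabling the serialized-model cache could be
// configured as:
//
//   node {
//     calculator: "InferenceCalculator"
//     options {
//       [mediapipe.InferenceCalculatorOptions.ext] {
//         delegate {
//           gpu {
//             use_advanced_gpu_api: true
//             serialized_model_dir: "/data/local/tmp"
//             model_token: "model_v1"
//           }
//         }
//       }
//     }
//   }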
// Encapsulated compilation/runtime tradeoffs.
enum InferenceUsage {
UNSPECIFIED = 0;

View File

@ -20,6 +20,7 @@
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h"
#if MEDIAPIPE_TFLITE_GL_INFERENCE
@ -49,8 +50,8 @@ class InferenceCalculatorGlImpl
absl::Status Close(CalculatorContext* cc) override;
private:
absl::Status ReadKernelsFromFile();
absl::Status WriteKernelsToFile();
absl::Status ReadGpuCaches();
absl::Status SaveGpuCaches();
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
@ -82,6 +83,8 @@ class InferenceCalculatorGlImpl
bool use_kernel_caching_ = false;
std::string cached_kernel_filename_;
bool use_serialized_model_ = false;
std::string serialized_model_path_;
};
absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
@ -114,6 +117,9 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
tflite_gpu_runner_usage_ = delegate.gpu().usage();
use_kernel_caching_ =
use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
use_serialized_model_ = use_advanced_gpu_api_ &&
delegate.gpu().has_serialized_model_dir() &&
delegate.gpu().has_model_token();
use_gpu_delegate_ = !use_advanced_gpu_api_;
if (use_kernel_caching_) {
@ -123,6 +129,12 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
".ker";
#endif // MEDIAPIPE_ANDROID
}
if (use_serialized_model_) {
#ifdef MEDIAPIPE_ANDROID
serialized_model_path_ = mediapipe::file::JoinPath(
delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
#endif // MEDIAPIPE_ANDROID
}
// When use_advanced_gpu_api_ is set, model loading is handled entirely in
// InitTFLiteGPURunner.
@ -210,7 +222,7 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() {
absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_) {
// Save kernel file.
@ -220,12 +232,22 @@ absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() {
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
}
if (use_serialized_model_) {
// Save serialized model file.
ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
tflite_gpu_runner_->GetSerializedModel());
absl::string_view serialized_model(
reinterpret_cast<char*>(serialized_model_vec.data()),
serialized_model_vec.size());
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(serialized_model_path_, serialized_model));
}
#endif // MEDIAPIPE_ANDROID
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(WriteKernelsToFile());
MP_RETURN_IF_ERROR(SaveGpuCaches());
if (use_gpu_delegate_) {
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
gpu_buffers_in_.clear();
@ -239,17 +261,24 @@ absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::ReadKernelsFromFile() {
absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_) {
if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
// Load pre-compiled kernel file.
if (mediapipe::File::Exists(cached_kernel_filename_)) {
std::string cache_str;
MP_RETURN_IF_ERROR(
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
}
if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
// Load serialized model file.
std::string serialized_model_str;
MP_RETURN_IF_ERROR(
file::GetContents(serialized_model_path_, &serialized_model_str));
std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
serialized_model_str.end());
tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
}
#endif // MEDIAPIPE_ANDROID
return absl::OkStatus();
@ -313,7 +342,7 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
tflite_gpu_runner_->GetOutputShapes()[i].c};
}
MP_RETURN_IF_ERROR(ReadKernelsFromFile());
MP_RETURN_IF_ERROR(ReadGpuCaches());
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());

View File

@ -24,20 +24,20 @@ message SsdAnchorsCalculatorOptions {
optional SsdAnchorsCalculatorOptions ext = 247258239;
}
// Size of input images.
required int32 input_size_width = 1;
required int32 input_size_height = 2;
optional int32 input_size_width = 1; // required
optional int32 input_size_height = 2; // required
// Min and max scales for generating anchor boxes on feature maps.
required float min_scale = 3;
required float max_scale = 4;
optional float min_scale = 3; // required
optional float max_scale = 4; // required
// The offset for the center of anchors. The value is in the scale of stride.
// E.g. 0.5 meaning 0.5 * |current_stride| in pixels.
required float anchor_offset_x = 5 [default = 0.5];
required float anchor_offset_y = 6 [default = 0.5];
optional float anchor_offset_x = 5 [default = 0.5]; // required
optional float anchor_offset_y = 6 [default = 0.5]; // required
// Number of output feature maps to generate the anchors on.
required int32 num_layers = 7;
optional int32 num_layers = 7; // required
// Sizes of output feature maps to create anchors. Either feature_map size or
// stride should be provided.
repeated int32 feature_map_width = 8;
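These fields move from `required` to `optional` with a trailing `// required` comment: the contract is preserved by convention while proto2's hard enforcement is dropped. A hedged sketch of the presence check a consuming calculator could perform (the `has_` accessors follow standard proto2 codegen):

RET_CHECK(options.has_input_size_width() && options.has_input_size_height());
RET_CHECK(options.has_min_scale() && options.has_max_scale());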

View File

@ -26,12 +26,12 @@ message TfLiteTensorsToDetectionsCalculatorOptions {
}
// The number of output classes predicted by the detection model.
required int32 num_classes = 1;
optional int32 num_classes = 1; // required
// The number of output boxes predicted by the detection model.
required int32 num_boxes = 2;
optional int32 num_boxes = 2; // required
// The number of output values per box predicted by the detection model. The
// values contain bounding boxes, keypoints, etc.
required int32 num_coords = 3;
optional int32 num_coords = 3; // required
// The offset of keypoint coordinates in the location tensor.
optional int32 keypoint_coord_offset = 9;

View File

@ -31,7 +31,7 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
}
// Number of landmarks from the output of the model.
required int32 num_landmarks = 1;
optional int32 num_landmarks = 1; // required
// Size of the input image for the model. These options are used only when
// normalized landmarks are needed. Z coordinate is scaled as X assuming

View File

@ -24,9 +24,9 @@ message TfLiteTensorsToSegmentationCalculatorOptions {
}
// Dimensions of input segmentation tensor to process.
required int32 tensor_width = 1;
required int32 tensor_height = 2;
required int32 tensor_channels = 3;
optional int32 tensor_width = 1; // required
optional int32 tensor_height = 2; // required
optional int32 tensor_channels = 3; // required
// How much to use previous mask when computing current one; range [0-1].
// This is a tradeoff between responsiveness (0.0) and accuracy (1.0).
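One plausible reading of this ratio is a per-pixel temporal blend (a sketch; the exact formula lives in the calculator):

new_mask = ratio * previous_mask + (1.0f - ratio) * raw_mask;  // ratio in [0, 1]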

View File

@ -98,32 +98,25 @@ public class MainActivity extends AppCompatActivity {
}
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData());
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
@ -138,10 +131,33 @@ public class MainActivity extends AppCompatActivity {
default:
matrix.postRotate(0);
}
bitmap =
Bitmap.createBitmap(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
downscaleBitmap(
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
bitmap = rotateBitmap(bitmap, imageData);
} catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e);
}
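For intuition on `downscaleBitmap`: a 4000x3000 gallery photo has aspect ratio 4/3 ≈ 1.33; in a 1080x1920 view (ratio 0.5625), the view's ratio is the smaller one, so the else branch keeps width = 1080 and sets height = (int) (1080 / 1.33) = 810, bounding the image to the view while preserving its proportions. The same refactoring is applied to the other demo activities below.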

View File

@ -7,7 +7,7 @@ android {
buildToolsVersion "30.0.3"
defaultConfig {
applicationId "com.google.mediapipe.apps.hands"
applicationId "com.google.mediapipe.apps.facemesh"
minSdkVersion 21
targetSdkVersion 30
versionCode 1

View File

@ -99,32 +99,25 @@ public class MainActivity extends AppCompatActivity {
}
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData());
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
@ -139,10 +132,33 @@ public class MainActivity extends AppCompatActivity {
default:
matrix.postRotate(0);
}
bitmap =
Bitmap.createBitmap(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
downscaleBitmap(
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
bitmap = rotateBitmap(bitmap, imageData);
} catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e);
}

View File

@ -100,32 +100,25 @@ public class MainActivity extends AppCompatActivity {
}
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData());
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
private Bitmap downscaleBitmap(Bitmap originalBitmap) {
double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight();
int width = imageView.getWidth();
int height = imageView.getHeight();
if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) {
width = (int) (height * aspectRatio);
} else {
height = (int) (width / aspectRatio);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
return Bitmap.createScaledBitmap(originalBitmap, width, height, false);
}
private Bitmap rotateBitmap(Bitmap inputBitmap, InputStream imageData) throws IOException {
int orientation =
new ExifInterface(imageData)
.getAttributeInt(
ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation != ExifInterface.ORIENTATION_NORMAL) {
.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
if (orientation == ExifInterface.ORIENTATION_NORMAL) {
return inputBitmap;
}
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_ROTATE_90:
@ -140,10 +133,33 @@ public class MainActivity extends AppCompatActivity {
default:
matrix.postRotate(0);
}
bitmap =
Bitmap.createBitmap(
bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
return Bitmap.createBitmap(
inputBitmap, 0, 0, inputBitmap.getWidth(), inputBitmap.getHeight(), matrix, true);
}
/** Sets up the UI components for the static image demo. */
private void setupStaticImageDemoUiComponents() {
// The Intent to access gallery and read images as bitmap.
imageGetter =
registerForActivityResult(
new ActivityResultContracts.StartActivityForResult(),
result -> {
Intent resultIntent = result.getData();
if (resultIntent != null) {
if (result.getResultCode() == RESULT_OK) {
Bitmap bitmap = null;
try {
bitmap =
downscaleBitmap(
MediaStore.Images.Media.getBitmap(
this.getContentResolver(), resultIntent.getData()));
} catch (IOException e) {
Log.e(TAG, "Bitmap reading error:" + e);
}
try {
InputStream imageData =
this.getContentResolver().openInputStream(resultIntent.getData());
bitmap = rotateBitmap(bitmap, imageData);
} catch (IOException e) {
Log.e(TAG, "Bitmap rotation error:" + e);
}

View File

@ -38,7 +38,7 @@ android_binary(
assets = [
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
"//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/palm_detection:palm_detection.tflite",
],
assets_dir = "",

View File

@ -39,7 +39,7 @@ android_binary(
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
"//mediapipe/modules/face_landmark:face_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite",

View File

@ -62,7 +62,7 @@ objc_library(
copts = ["-std=c++17"],
data = [
"//mediapipe/graphs/hand_tracking:hand_tracking_mobile_gpu.binarypb",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/palm_detection:palm_detection.tflite",
],

View File

@ -57,7 +57,7 @@ objc_library(
"//mediapipe/graphs/holistic_tracking:holistic_tracking_gpu.binarypb",
"//mediapipe/modules/face_detection:face_detection_short_range.tflite",
"//mediapipe/modules/face_landmark:face_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/holistic_landmark:hand_recrop.tflite",
"//mediapipe/modules/pose_detection:pose_detection.tflite",

View File

@ -427,7 +427,8 @@ absl::Status CalculatorGraph::Initialize(
const std::map<std::string, Packet>& side_packets) {
auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
MP_RETURN_IF_ERROR(validated_graph->Initialize(
input_config, /*graph_registry=*/nullptr, &service_manager_));
input_config, /*graph_registry=*/nullptr, /*graph_options=*/nullptr,
&service_manager_));
return Initialize(std::move(validated_graph), side_packets);
}

View File

@ -122,7 +122,7 @@ class CalculatorRunner {
const StreamContentsSet& Outputs() const { return *outputs_; }
// Returns the access to the output side packets.
const PacketSet& OutputSidePackets() { return *output_side_packets_.get(); }
const PacketSet& OutputSidePackets() { return *output_side_packets_; }
// Returns a graph counter.
mediapipe::Counter* GetCounter(const std::string& name);

View File

@ -77,13 +77,6 @@ bool Image::ConvertToGpu() const {
#else
// GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture)
auto buffer = mediapipe::GlTextureBuffer::Create(*image_frame_);
glBindTexture(GL_TEXTURE_2D, buffer->name());
// See GlCalculatorHelperImpl::SetStandardTextureParams
glTexParameteri(buffer->target(), GL_TEXTURE_MIN_FILTER, GL_LINEAR);
glTexParameteri(buffer->target(), GL_TEXTURE_MAG_FILTER, GL_LINEAR);
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
glBindTexture(GL_TEXTURE_2D, 0);
glFlush();
gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer));
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER

View File

@ -244,6 +244,8 @@ cc_test(
srcs = ["mux_input_stream_handler_test.cc"],
deps = [
":mux_input_stream_handler",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:make_pair_calculator",
"//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/core:round_robin_demux_calculator",

View File

@ -75,14 +75,31 @@ class MuxInputStreamHandler : public InputStreamHandler {
int control_value = control_packet.Get<int>();
CHECK_LE(0, control_value);
CHECK_LT(control_value, input_stream_managers_.NumEntries() - 1);
const auto& data_stream = input_stream_managers_.Get(
input_stream_managers_.BeginId() + control_value);
// The data stream may contain outdated packets that failed to be popped
// out during "FillInputSet". (This handler doesn't sync input streams,
// hence "FillInputSet" can be triggered before every input stream is
// filled with packets corresponding to the same timestamp.)
data_stream->ErasePacketsEarlierThan(*min_stream_timestamp);
Timestamp stream_timestamp = data_stream->MinTimestampOrBound(&empty);
if (empty) {
CHECK_LE(stream_timestamp, *min_stream_timestamp);
if (stream_timestamp <= *min_stream_timestamp) {
// "data_stream" didn't receive a packet corresponding to the current
// "control_stream" packet yet.
return NodeReadiness::kNotReady;
}
// "data_stream" timestamp bound update detected.
return NodeReadiness::kReadyForProcess;
}
if (stream_timestamp > *min_stream_timestamp) {
// The earliest packet "data_stream" holds corresponds to a control packet
// yet to arrive. That means there won't be a "data_stream" packet
// corresponding to the current "control_stream" packet, which should be
// indicated as a timestamp bound update.
return NodeReadiness::kReadyForProcess;
}
CHECK_EQ(stream_timestamp, *min_stream_timestamp);
return NodeReadiness::kReadyForProcess;
}
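// In summary (restating the branches above): with the control packet at
// timestamp T, the selected data stream yields
//   - kNotReady when it is empty with timestamp bound <= T (wait for data);
//   - kReadyForProcess when it is empty with bound > T (a bound update);
//   - kReadyForProcess when its earliest packet is at T (a real packet) or
//     past T (a bound update, since no packet for T will arrive).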

View File

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
@ -19,9 +20,10 @@
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
using ::testing::ElementsAre;
// A regression test for b/31620439. MuxInputStreamHandler's accesses to the
// control and data streams should be atomic so that it has a consistent view
// of the two streams. None of the CHECKs in the GetNodeReadiness() method of
@ -87,5 +89,561 @@ TEST(MuxInputStreamHandlerTest, AtomicAccessToControlAndDataStreams) {
MP_ASSERT_OK(graph.WaitUntilDone());
}
MATCHER_P2(IntPacket, value, ts, "") {
return arg.template Get<int>() == value && arg.Timestamp() == ts;
}
struct GateAndMuxGraphInput {
int input0;
int input1;
int input2;
int select;
bool allow0;
bool allow1;
bool allow2;
Timestamp at;
};
constexpr char kGateAndMuxGraph[] = R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "input2"
input_stream: "select"
input_stream: "allow0"
input_stream: "allow1"
input_stream: "allow2"
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow0"
input_stream: "input0"
output_stream: "output0"
}
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow1"
input_stream: "input1"
output_stream: "output1"
}
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow2"
input_stream: "input2"
output_stream: "output2"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:output0"
input_stream: "INPUT:1:output1"
input_stream: "INPUT:2:output2"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
})pb";
absl::Status SendInput(GateAndMuxGraphInput in, CalculatorGraph& graph) {
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"input0", MakePacket<int>(in.input0).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"input1", MakePacket<int>(in.input1).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"input2", MakePacket<int>(in.input2).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"select", MakePacket<int>(in.select).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"allow0", MakePacket<bool>(in.allow0).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"allow1", MakePacket<bool>(in.allow1).At(in.at)));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"allow2", MakePacket<bool>(in.allow2).At(in.at)));
return graph.WaitUntilIdle();
}
TEST(MuxInputStreamHandlerTest, BasicMuxing) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = true,
.allow1 = false,
.allow2 = false,
.at = Timestamp(1)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = false,
.allow1 = true,
.allow2 = false,
.at = Timestamp(2)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = false,
.allow1 = false,
.allow2 = true,
.at = Timestamp(3)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(3))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingNonEmptyInputs) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(1)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(2)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(3)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(1000, Timestamp(1)),
IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(3))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingAllTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(1)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(2)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(3)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingSelectedTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = false,
.allow1 = true,
.allow2 = true,
.at = Timestamp(1)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = true,
.allow1 = false,
.allow2 = true,
.at = Timestamp(2)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = false,
.at = Timestamp(3)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest, MuxingSometimesTimestampBoundUpdates) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kGateAndMuxGraph);
std::vector<Packet> output_packets;
tool::AddVectorSink("output", &config, &output_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 0,
.allow0 = false,
.allow1 = false,
.allow2 = false,
.at = Timestamp(1)},
graph));
EXPECT_TRUE(output_packets.empty());
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 1,
.allow0 = false,
.allow1 = true,
.allow2 = false,
.at = Timestamp(2)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = false,
.at = Timestamp(3)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2))));
MP_ASSERT_OK(SendInput({.input0 = 1000,
.input1 = 900,
.input2 = 800,
.select = 2,
.allow0 = true,
.allow1 = true,
.allow2 = true,
.at = Timestamp(4)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(4))));
MP_ASSERT_OK(SendInput({.input0 = 700,
.input1 = 600,
.input2 = 500,
.select = 0,
.allow0 = true,
.allow1 = false,
.allow2 = false,
.at = Timestamp(5)},
graph));
EXPECT_THAT(output_packets, ElementsAre(IntPacket(900, Timestamp(2)),
IntPacket(800, Timestamp(4)),
IntPacket(700, Timestamp(5))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
MATCHER_P(EmptyPacket, ts, "") {
return arg.IsEmpty() && arg.Timestamp() == ts;
}
MATCHER_P2(Pair, m1, m2, "") {
const auto& p = arg.template Get<std::pair<Packet, Packet>>();
return testing::Matches(m1)(p.first) && testing::Matches(m2)(p.second);
}
TEST(MuxInputStreamHandlerTest,
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacket) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:input0"
input_stream: "INPUT:1:input1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
node {
calculator: "MakePairCalculator"
input_stream: "select"
input_stream: "output"
output_stream: "pair"
}
)pb");
std::vector<Packet> pair_packets;
tool::AddVectorSink("pair", &config, &pair_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_TRUE(pair_packets.empty());
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
EmptyPacket(Timestamp(1)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(800).At(Timestamp(4))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
EmptyPacket(Timestamp(1)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(3))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(4))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest,
TimestampBoundUpdateWhenControlPacketEarlierThanDataPacketPacketsAtOnce) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:input0"
input_stream: "INPUT:1:input1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
node {
calculator: "MakePairCalculator"
input_stream: "select"
input_stream: "output"
output_stream: "pair"
}
)pb");
std::vector<Packet> pair_packets;
tool::AddVectorSink("pair", &config, &pair_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(900).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input1", MakePacket<int>(800).At(Timestamp(4))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(3))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(1).At(Timestamp(4))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(1000, Timestamp(2))),
Pair(IntPacket(1, Timestamp(3)), EmptyPacket(Timestamp(3))),
Pair(IntPacket(1, Timestamp(4)), IntPacket(800, Timestamp(4)))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
TEST(MuxInputStreamHandlerTest,
TimestampBoundUpdateTriggersTimestampBoundUpdate) {
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input0"
input_stream: "input1"
input_stream: "select"
input_stream: "allow0"
input_stream: "allow1"
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow0"
input_stream: "input0"
output_stream: "output0"
}
node {
calculator: "GateCalculator"
input_stream: "ALLOW:allow1"
input_stream: "input1"
output_stream: "output1"
}
node {
calculator: "MuxCalculator"
input_stream: "INPUT:0:output0"
input_stream: "INPUT:1:output1"
input_stream: "SELECT:select"
output_stream: "OUTPUT:output"
input_stream_handler { input_stream_handler: "MuxInputStreamHandler" }
}
node {
calculator: "MakePairCalculator"
input_stream: "select"
input_stream: "output"
output_stream: "pair"
}
)pb");
std::vector<Packet> pair_packets;
tool::AddVectorSink("pair", &config, &pair_packets);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(1000).At(Timestamp(1))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"allow0", MakePacket<bool>(false).At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(pair_packets, ElementsAre(Pair(IntPacket(0, Timestamp(1)),
EmptyPacket(Timestamp(1)))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"select", MakePacket<int>(0).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input0", MakePacket<int>(900).At(Timestamp(2))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"allow0", MakePacket<bool>(true).At(Timestamp(2))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
pair_packets,
ElementsAre(
Pair(IntPacket(0, Timestamp(1)), EmptyPacket(Timestamp(1))),
Pair(IntPacket(0, Timestamp(2)), IntPacket(900, Timestamp(2)))));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

View File

@ -56,7 +56,9 @@ class SubgraphContext {
return options_map_.Get<T>();
}
const CalculatorGraphConfig::Node& OriginalNode() { return original_node_; }
const CalculatorGraphConfig::Node& OriginalNode() const {
return original_node_;
}
template <typename T>
ServiceBinding<T> Service(const GraphService<T>& service) const {

View File

@ -724,6 +724,7 @@ cc_test(
srcs = ["subgraph_expansion_test.cc"],
deps = [
":node_chain_subgraph_cc_proto",
":node_chain_subgraph_options_lib",
":subgraph_expansion",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",

View File

@ -23,6 +23,8 @@
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/type_map.h"
#define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging()
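// (RET_CHECK logs on failure by default. The substitutions below use this
// macro because the conditions can fail on routine malformed or partial
// field data while probing protos, where logging would presumably be noise.)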
namespace mediapipe {
namespace tool {
@ -47,13 +49,13 @@ absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in,
WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag);
if (IsLengthDelimited(wire_type)) {
uint32 length;
RET_CHECK(in->ReadVarint32(&length));
RET_CHECK(in->ReadString(result, length));
RET_CHECK_NO_LOG(in->ReadVarint32(&length));
RET_CHECK_NO_LOG(in->ReadString(result, length));
} else {
std::string field_data;
StringOutputStream sos(&field_data);
CodedOutputStream cos(&sos);
RET_CHECK(WireFormatLite::SkipField(in, tag, &cos));
RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, &cos));
// Skip the tag written by SkipField.
int tag_size = CodedOutputStream::VarintSize32(tag);
cos.Trim();
@ -67,13 +69,13 @@ absl::Status ReadPackedValues(WireFormatLite::WireType wire_type,
CodedInputStream* in,
std::vector<std::string>* field_values) {
uint32 data_size;
RET_CHECK(in->ReadVarint32(&data_size));
RET_CHECK_NO_LOG(in->ReadVarint32(&data_size));
// fake_tag encodes the wire-type for calls to WireFormatLite::SkipField.
uint32 fake_tag = WireFormatLite::MakeTag(1, wire_type);
while (data_size > 0) {
std::string number;
MP_RETURN_IF_ERROR(ReadFieldValue(fake_tag, in, &number));
RET_CHECK_LE(number.size(), data_size);
RET_CHECK_NO_LOG(number.size() <= data_size);
field_values->push_back(number);
data_size -= number.size();
}
@ -98,7 +100,7 @@ absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type,
field_values->push_back(value);
}
} else {
RET_CHECK(WireFormatLite::SkipField(in, tag, out));
RET_CHECK_NO_LOG(WireFormatLite::SkipField(in, tag, out));
}
}
return absl::OkStatus();
@ -157,12 +159,12 @@ absl::Status ProtoUtilLite::ReplaceFieldRange(
MP_RETURN_IF_ERROR(access.SetMessage(*message));
std::vector<std::string>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK(index >= 0 && index < v.size());
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(ReplaceFieldRange(&v[index], proto_path, length,
field_type, field_values));
} else {
RET_CHECK(index >= 0 && index <= v.size());
RET_CHECK(index + length >= 0 && index + length <= v.size());
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
v.erase(v.begin() + index, v.begin() + index + length);
v.insert(v.begin() + index, field_values.begin(), field_values.end());
}
@ -184,12 +186,12 @@ absl::Status ProtoUtilLite::GetFieldRange(
MP_RETURN_IF_ERROR(access.SetMessage(message));
std::vector<std::string>& v = *access.mutable_field_values();
if (!proto_path.empty()) {
RET_CHECK(index >= 0 && index < v.size());
RET_CHECK_NO_LOG(index >= 0 && index < v.size());
MP_RETURN_IF_ERROR(
GetFieldRange(v[index], proto_path, length, field_type, field_values));
} else {
RET_CHECK(index >= 0 && index <= v.size());
RET_CHECK(index + length >= 0 && index + length <= v.size());
RET_CHECK_NO_LOG(index >= 0 && index <= v.size());
RET_CHECK_NO_LOG(index + length >= 0 && index + length <= v.size());
field_values->insert(field_values->begin(), v.begin() + index,
v.begin() + index + length);
}

View File

@ -274,12 +274,14 @@ absl::Status ConnectSubgraphStreams(
absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) {
graph_registry =
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
RET_CHECK(config);
MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(
CalculatorGraphConfig::Node(), config));
graph_options ? *graph_options : CalculatorGraphConfig::Node(), config));
auto* nodes = config->mutable_node();
while (1) {
auto subgraph_nodes_start = std::stable_partition(

View File

@ -72,6 +72,7 @@ absl::Status ConnectSubgraphStreams(
absl::Status ExpandSubgraphs(
CalculatorGraphConfig* config,
const GraphRegistry* graph_registry = nullptr,
const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr);
// Creates a graph wrapping the provided node and exposing all of its

View File

@ -560,9 +560,111 @@ TEST(SubgraphExpansionTest, GraphServicesUsage) {
MP_ASSERT_OK(service_manager.SetServiceObject(
kStringTestService, std::make_shared<std::string>("ExpectedNode")));
MP_EXPECT_OK(tool::ExpandSubgraphs(&supergraph, /*graph_registry=*/nullptr,
/*graph_options=*/nullptr,
&service_manager));
EXPECT_THAT(supergraph, mediapipe::EqualsProto(expected_graph));
}
// Shows SubgraphOptions consumed by GraphRegistry::CreateByName.
TEST(SubgraphExpansionTest, SubgraphOptionsUsage) {
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
GraphRegistry graph_registry;
// CalculatorGraph::Initialize passes the SubgraphOptions into:
// (1) GraphRegistry::CreateByName("NodeChainSubgraph", options)
// (2) tool::ExpandSubgraphs(&config, options)
auto graph_options =
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
options {
[mediapipe.NodeChainSubgraphOptions.ext] {
node_type: "DoubleIntCalculator"
chain_length: 3
}
})pb");
SubgraphContext context(&graph_options, /*service_manager=*/nullptr);
// "NodeChainSubgraph" consumes graph_options only in CreateByName.
auto subgraph_status =
graph_registry.CreateByName("", "NodeChainSubgraph", &context);
MP_ASSERT_OK(subgraph_status);
auto subgraph = std::move(subgraph_status).value();
MP_ASSERT_OK(
tool::ExpandSubgraphs(&subgraph, &graph_registry, &graph_options));
CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "DoubleIntCalculator"
input_stream: "stream_0"
output_stream: "stream_1"
}
node {
calculator: "DoubleIntCalculator"
input_stream: "stream_1"
output_stream: "stream_2"
}
node {
calculator: "DoubleIntCalculator"
input_stream: "stream_2"
output_stream: "stream_3"
}
input_stream: "INPUT:stream_0"
output_stream: "OUTPUT:stream_3"
)pb");
EXPECT_THAT(subgraph, mediapipe::EqualsProto(expected_graph));
}
// Shows SubgraphOptions consumed by tool::ExpandSubgraphs.
TEST(SubgraphExpansionTest, SimpleSubgraphOptionsUsage) {
EXPECT_TRUE(SubgraphRegistry::IsRegistered("NodeChainSubgraph"));
GraphRegistry graph_registry;
auto moon_options =
mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
options {
[mediapipe.NodeChainSubgraphOptions.ext] {
node_type: "DoubleIntCalculator"
chain_length: 3
}
})pb");
auto moon_subgraph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
type: "MoonSubgraph"
graph_options: {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
}
node: {
calculator: "MoonCalculator"
node_options: {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
}
option_value: "chain_length:options/chain_length"
}
)pb");
// The moon_options are copied into the graph_options of moon_subgraph.
MP_ASSERT_OK(
tool::ExpandSubgraphs(&moon_subgraph, &graph_registry, &moon_options));
// The field chain_length is copied from moon_options into MoonCalculator.
CalculatorGraphConfig expected_graph =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "MoonCalculator"
node_options {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {
chain_length: 3
}
}
option_value: "chain_length:options/chain_length"
}
type: "MoonSubgraph"
graph_options {
[type.googleapis.com/mediapipe.NodeChainSubgraphOptions] {}
}
)pb");
EXPECT_THAT(moon_subgraph, mediapipe::EqualsProto(expected_graph));
}
} // namespace
} // namespace mediapipe

View File

@ -150,12 +150,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
const int packet_count = 10;
// Send int value packets at {10K, 20K, 30K, ..., 100K}.
for (uint64 t = 1; t <= packet_count; ++t) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
if (send_bounds) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"enable", MakePacket<bool>(true).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// The inputs are sent to the input stream "foo"; they should pass through.
EXPECT_EQ(out_foo.size(), t);
@ -175,12 +176,13 @@ void RunTestContainer(CalculatorGraphConfig supergraph,
// Send int value packets at {110K, 120K, ..., 200K}.
for (uint64 t = 11; t <= packet_count * 2; ++t) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
if (send_bounds) {
MP_EXPECT_OK(graph.AddPacketToInputStream(
"enable", MakePacket<bool>(false).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle());
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"foo", MakePacket<int>(t).At(Timestamp(t * 10000))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// The inputs are sent to the input stream "foo"; they should pass through.
EXPECT_EQ(out_foo.size(), t);

View File

@ -143,11 +143,12 @@ absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) {
absl::Status PerformBasicTransforms(
const CalculatorGraphConfig& input_graph_config,
const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager,
CalculatorGraphConfig* output_graph_config) {
*output_graph_config = input_graph_config;
MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
service_manager));
graph_options, service_manager));
MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));
@ -347,6 +348,7 @@ absl::Status NodeTypeInfo::Initialize(
absl::Status ValidatedGraphConfig::Initialize(
const CalculatorGraphConfig& input_config,
const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) {
RET_CHECK(!initialized_)
<< "ValidatedGraphConfig can be initialized only once.";
@ -356,8 +358,8 @@ absl::Status ValidatedGraphConfig::Initialize(
<< input_config.DebugString();
#endif
MP_RETURN_IF_ERROR(PerformBasicTransforms(input_config, graph_registry,
service_manager, &config_));
MP_RETURN_IF_ERROR(PerformBasicTransforms(
input_config, graph_registry, graph_options, service_manager, &config_));
// Initialize the basic node information.
MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
@ -431,22 +433,24 @@ absl::Status ValidatedGraphConfig::Initialize(
}
absl::Status ValidatedGraphConfig::Initialize(
const std::string& graph_type, const Subgraph::SubgraphOptions* options,
const GraphRegistry* graph_registry,
const std::string& graph_type, const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) {
graph_registry =
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
SubgraphContext subgraph_context(options, service_manager);
SubgraphContext subgraph_context(graph_options, service_manager);
auto status_or_config =
graph_registry->CreateByName("", graph_type, &subgraph_context);
MP_RETURN_IF_ERROR(status_or_config.status());
return Initialize(status_or_config.value(), graph_registry, service_manager);
return Initialize(status_or_config.value(), graph_registry, graph_options,
service_manager);
}
absl::Status ValidatedGraphConfig::Initialize(
const std::vector<CalculatorGraphConfig>& input_configs,
const std::vector<CalculatorGraphTemplate>& input_templates,
const std::string& graph_type, const Subgraph::SubgraphOptions* arguments,
const std::string& graph_type,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) {
GraphRegistry graph_registry;
for (auto& config : input_configs) {
@ -455,7 +459,8 @@ absl::Status ValidatedGraphConfig::Initialize(
for (auto& templ : input_templates) {
graph_registry.Register(templ.config().type(), templ);
}
return Initialize(graph_type, arguments, &graph_registry, service_manager);
return Initialize(graph_type, &graph_registry, graph_options,
service_manager);
}
absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {

View File

@ -195,17 +195,20 @@ class ValidatedGraphConfig {
// Initializes the ValidatedGraphConfig. This function must be called
// before any other functions. Subgraphs are specified through the
// global graph registry or an optional local graph registry.
absl::Status Initialize(const CalculatorGraphConfig& input_config,
absl::Status Initialize(
const CalculatorGraphConfig& input_config,
const GraphRegistry* graph_registry = nullptr,
const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr);
// Initializes the ValidatedGraphConfig from registered graph and subgraph
// configs. Subgraphs are retrieved from the specified graph registry or from
// the global graph registry. A subgraph can be instantiated directly by
// specifying its type in |graph_type|.
absl::Status Initialize(const std::string& graph_type,
const Subgraph::SubgraphOptions* options = nullptr,
absl::Status Initialize(
const std::string& graph_type,
const GraphRegistry* graph_registry = nullptr,
const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr);
// Initializes the ValidatedGraphConfig from the specified graph and subgraph
@ -218,7 +221,7 @@ class ValidatedGraphConfig {
const std::vector<CalculatorGraphConfig>& input_configs,
const std::vector<CalculatorGraphTemplate>& input_templates,
const std::string& graph_type = "",
const Subgraph::SubgraphOptions* arguments = nullptr,
const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr);
// Returns true if the ValidatedGraphConfig has been initialized.
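To make the new parameter order concrete, here is a hedged usage sketch (not part of the patch); it reuses the NodeChainSubgraphOptions and the NodeChainSubgraph type registered by the tests earlier in this commit:

  // A nullptr graph_registry falls back to the global registry.
  ValidatedGraphConfig validated_config;
  auto graph_options =
      mediapipe::ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"pb(
        options {
          [mediapipe.NodeChainSubgraphOptions.ext] {
            node_type: "DoubleIntCalculator"
            chain_length: 3
          }
        })pb");
  MP_RETURN_IF_ERROR(validated_config.Initialize(
      "NodeChainSubgraph", /*graph_registry=*/nullptr, &graph_options));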

View File

@ -155,6 +155,7 @@ TEST(ValidatedGraphConfigTest, InitializeSubgraphWithServiceCalculatorB) {
kStringTestService, std::make_shared<std::string>(calculator_name)));
MP_EXPECT_OK(config.Initialize(graph,
/*graph_registry=*/nullptr,
/*subgraph_options=*/nullptr,
/*service_manager=*/&service_manager));
ASSERT_TRUE(config.Initialized());
EXPECT_THAT(config.Config(), EqualsProto(ExpectedConfigExpandedFromGraph(

View File

@ -196,7 +196,9 @@ cc_library(
deps = [
":gl_base",
":gl_context",
":gl_texture_view",
":gpu_buffer_format",
":gpu_buffer_storage",
# TODO: remove this dependency. Some other teams' tests
# depend on having an indirect image_frame dependency, which needs to
# be fixed first.
@ -205,6 +207,17 @@ cc_library(
],
)
cc_library(
name = "gl_texture_view",
srcs = ["gl_texture_view.cc"],
hdrs = ["gl_texture_view.h"],
visibility = ["//visibility:public"],
deps = [
":gl_base",
":gl_context",
],
)
cc_library(
name = "gpu_buffer",
srcs = ["gpu_buffer.cc"],
@ -214,12 +227,15 @@ cc_library(
":gl_base",
":gl_context",
":gpu_buffer_format",
":gpu_buffer_storage",
":gl_texture_view",
"//mediapipe/framework/formats:image_frame",
] + select({
"//conditions:default": [
":gl_texture_buffer",
],
"//mediapipe:ios": [
":gpu_buffer_storage_cv_pixel_buffer",
"//mediapipe/objc:util",
"//mediapipe/objc:CFHolder",
],
@ -244,6 +260,35 @@ cc_library(
],
)
cc_library(
name = "gpu_buffer_storage",
hdrs = ["gpu_buffer_storage.h"],
visibility = ["//visibility:public"],
deps = [
":gl_base",
":gpu_buffer_format",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "gpu_buffer_storage_cv_pixel_buffer",
srcs = ["gpu_buffer_storage_cv_pixel_buffer.cc"],
hdrs = ["gpu_buffer_storage_cv_pixel_buffer.h"],
visibility = ["//visibility:public"],
deps = [
":gl_base",
":gl_context",
":gl_texture_view",
":gpu_buffer_storage",
"//mediapipe/objc:CFHolder",
"//mediapipe/objc:util",
],
)
mediapipe_proto_library(
name = "gpu_origin_proto",
srcs = ["gpu_origin.proto"],

View File

@ -109,12 +109,10 @@ GlTexture GlCalculatorHelper::CreateSourceTexture(
return impl_->CreateSourceTexture(image_frame);
}
#ifdef __APPLE__
GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer,
int plane) {
return impl_->CreateSourceTexture(pixel_buffer, plane);
}
#endif
void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer,
int* width, int* height) {

View File

@ -29,10 +29,6 @@
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/graph_support.h"
#ifdef __APPLE__
#include <CoreVideo/CoreVideo.h>
#endif // __APPLE__
namespace mediapipe {
class GlCalculatorHelperImpl;
@ -111,24 +107,35 @@ class GlCalculatorHelper {
// where it is supported (iOS, for now) they take advantage of memory sharing
// between the CPU and GPU, avoiding memory copies.
// Creates a texture representing an input frame, and manages sync token.
// Gives access to an input frame as an OpenGL texture for reading (sampling).
//
// IMPORTANT: the returned GlTexture should be treated as a short-term view
// into the frame (typically for the duration of a Process call). Do not store
// it as a member in your calculator. If you need to keep a frame around,
// store the GpuBuffer instead, and call CreateSourceTexture again on each
// Process call.
//
// TODO: rename this; the use of "Create" makes this sound more expensive than
// it is.
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer);
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
GlTexture CreateSourceTexture(const mediapipe::Image& image);
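A hedged sketch (not from the patch) of the pattern the comment above prescribes: store the GpuBuffer, not the GlTexture, and create a fresh short-term view inside each Process call. The calculator skeleton, the "IMAGE" stream tag, and the gpu_helper_/last_frame_ member names are illustrative:

  absl::Status Process(CalculatorContext* cc) override {
    return gpu_helper_.RunInGlContext([this, cc]() -> absl::Status {
      const auto& input = cc->Inputs().Tag("IMAGE").Get<GpuBuffer>();
      last_frame_ = input;  // Keeping the GpuBuffer across frames is fine.
      GlTexture src = gpu_helper_.CreateSourceTexture(input);
      // ... bind src.target()/src.name() and sample or render here ...
      src.Release();  // Do not hold the GlTexture past this Process call.
      return absl::OkStatus();
    });
  }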
#ifdef __APPLE__
// Creates a texture from a plane of a planar buffer.
// Gives read access to a plane of a planar buffer.
// The plane index is zero-based. The number of planes depends on the
// internal format of the buffer.
// Note: multi-plane support is not available on all platforms.
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer, int plane);
#endif
// Convenience function for converting an ImageFrame to GpuBuffer and then
// accessing it as a texture.
GlTexture CreateSourceTexture(const ImageFrame& image_frame);
// Extracts GpuBuffer dimensions without creating a texture.
ABSL_DEPRECATED("Use width and height methods on GpuBuffer instead")
void GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, int* width,
int* height);
// Creates a texture representing an output frame, and manages sync token.
// Gives access to an OpenGL texture for writing (rendering) a new frame.
// TODO: This should either return errors or a status.
GlTexture CreateDestinationTexture(
int output_width, int output_height,

View File

@ -62,10 +62,7 @@ class GlCalculatorHelperImpl {
private:
// Makes a GpuBuffer accessible as a texture in the GL context.
GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, int plane,
bool for_reading);
void AttachGlTexture(GlTexture& texture, const GpuBuffer& gpu_buffer,
int plane, bool for_reading);
GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view);
// Create the framebuffer for rendering.
void CreateFramebuffer();

View File

@ -91,9 +91,7 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) {
}
GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
int plane, bool for_reading) {
GlTextureView view = gpu_buffer.GetGlTextureView(plane, for_reading);
GlTextureView view) {
if (gpu_buffer.format() != GpuBufferFormat::kUnknown) {
// TODO: do the params need to be reset here??
glBindTexture(view.target(), view.name());
@ -109,19 +107,18 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer,
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
const GpuBuffer& gpu_buffer) {
return MapGpuBuffer(gpu_buffer, 0, true);
return CreateSourceTexture(gpu_buffer, 0);
}
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
const GpuBuffer& gpu_buffer, int plane) {
return MapGpuBuffer(gpu_buffer, plane, true);
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(plane));
}
GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
const ImageFrame& image_frame) {
GlTexture texture =
MapGpuBuffer(GpuBuffer::CopyingImageFrame(image_frame), 0, true);
return texture;
auto gpu_buffer = GpuBuffer::CopyingImageFrame(image_frame);
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureReadView(0));
}
template <>
@ -150,11 +147,9 @@ GlTexture GlCalculatorHelperImpl::CreateDestinationTexture(
CreateFramebuffer();
}
GpuBuffer buffer =
GpuBuffer gpu_buffer =
gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format);
GlTexture texture = MapGpuBuffer(buffer, 0, false);
return texture;
return MapGpuBuffer(gpu_buffer, gpu_buffer.GetGlTextureWriteView(0));
}
} // namespace mediapipe

View File

@ -14,6 +14,9 @@
#include "mediapipe/gpu/gl_texture_buffer.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gl_texture_view.h"
namespace mediapipe {
std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Wrap(
@ -122,6 +125,15 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
if (alignment != 4 && data) glPixelStorei(GL_UNPACK_ALIGNMENT, 4);
// TODO: does this need to set the texture params? We set them again when the
// texture is actually accessed via GlTexture[View]. Or should they always be
// set on creation?
if (format_ != GpuBufferFormat::kUnknown) {
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format_, /*plane=*/0, context->GetGlVersion());
context->SetStandardTextureParams(target_, info.gl_internal_format);
}
glBindTexture(target_, 0);
// Use the deletion callback to delete the texture on the context
@ -167,7 +179,7 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
producer_context_ = producer_sync_->GetContext();
}
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) {
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
absl::MutexLock lock(&consumer_sync_mutex_);
consumer_multi_sync_->Add(std::move(cons_token));
}
@ -181,7 +193,7 @@ GlTextureBuffer::~GlTextureBuffer() {
}
}
void GlTextureBuffer::WaitUntilComplete() {
void GlTextureBuffer::WaitUntilComplete() const {
// Buffers created by the application (using the constructor that wraps an
// existing texture) have no sync token and are assumed to be already
// complete.
@ -190,7 +202,7 @@ void GlTextureBuffer::WaitUntilComplete() {
}
}
void GlTextureBuffer::WaitOnGpu() {
void GlTextureBuffer::WaitOnGpu() const {
// Buffers created by the application (using the constructor that wraps an
// existing texture) have no sync token and are assumed to be already
// complete.
@ -212,4 +224,127 @@ void GlTextureBuffer::WaitForConsumersOnGpu() {
// precisely, on only one GL context.
}
GlTextureView GlTextureBuffer::GetGlTextureReadView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
CHECK_EQ(plane, 0);
// Insert wait call to sync with the producer.
WaitOnGpu();
GlTextureView::DetachFn detach = [this](mediapipe::GlTextureView& texture) {
// Inform the GlTextureBuffer that we have finished accessing its
// contents, and create a consumer sync point.
DidRead(texture.gl_context()->CreateSyncToken());
};
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
std::move(gpu_buffer), plane, std::move(detach),
nullptr);
}
GlTextureView GlTextureBuffer::GetGlTextureWriteView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
CHECK_EQ(plane, 0);
// Insert wait call to sync with the producer.
WaitOnGpu();
Reuse(); // TODO: the producer wait should probably be part of Reuse in the
// case when there are no consumers.
GlTextureView::DoneWritingFn done_writing =
[this](const mediapipe::GlTextureView& texture) {
ViewDoneWriting(texture);
};
return GlTextureView(gl_context.get(), target(), name(), width(), height(),
std::move(gpu_buffer), plane, nullptr,
std::move(done_writing));
}
void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) {
// Inform the GlTextureBuffer that we have produced new content, and create
// a producer sync point.
Updated(view.gl_context()->CreateSyncToken());
#ifdef __ANDROID__
// On (some?) Android devices, the texture may need to be explicitly
// detached from the current framebuffer.
// TODO: is this necessary even with the unbind in BindFramebuffer?
// It is not clear whether this affected other contexts too, but let's keep it
// while in doubt.
GLint type = GL_NONE;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
&type);
if (type == GL_TEXTURE) {
GLint color_attachment = 0;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment);
if (color_attachment == name()) {
glBindFramebuffer(GL_FRAMEBUFFER, 0);
}
}
// Some Android drivers log a GL_INVALID_ENUM error after the first
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
// even though it should be ok to ask for the type and get back GL_NONE.
// Let's just ignore any pending errors here.
GLenum error;
while ((error = glGetError()) != GL_NO_ERROR) {
}
#endif // __ANDROID__
}
static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
void* output, size_t size) {
// TODO: check buffer size? We could use glReadnPixels where available
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
// won't overflow the buffer with glReadPixels, we'd also need to check or
// reset several glPixelStore parameters (e.g. what if someone had the
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
CHECK(view.gl_context());
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format, view.plane(), view.gl_context()->GetGlVersion());
GLint current_fbo;
glGetIntegerv(GL_FRAMEBUFFER_BINDING, &current_fbo);
CHECK_NE(current_fbo, 0);
GLint color_attachment_name;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment_name);
if (color_attachment_name != view.name()) {
// Save the viewport. Note that we assume that the color attachment is a
// GL_TEXTURE_2D texture.
GLint viewport[4];
glGetIntegerv(GL_VIEWPORT, viewport);
// Set the data from GLTextureView object.
glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
view.name(), 0);
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
// Restore from the saved viewport and color attachment name.
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
color_attachment_name, 0);
} else {
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
}
}
std::unique_ptr<ImageFrame> GlTextureBuffer::AsImageFrame() const {
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
auto output = absl::make_unique<ImageFrame>(
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
auto view = GetGlTextureReadView(nullptr, 0);
ReadTexture(view, format(), output->MutablePixelData(),
output->PixelDataSize());
return output;
}
} // namespace mediapipe

View File

@ -25,13 +25,14 @@
#include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_buffer_storage.h"
namespace mediapipe {
class GlCalculatorHelperImpl;
// Implements a GPU memory buffer as an OpenGL texture. For internal use.
class GlTextureBuffer {
class GlTextureBuffer : public mediapipe::internal::GpuBufferStorage {
public:
// This is called when the texture buffer is deleted. It is passed a sync
// token created at that time on the GlContext. If the GlTextureBuffer has
@ -85,6 +86,13 @@ class GlTextureBuffer {
int height() const { return height_; }
GpuBufferFormat format() const { return format_; }
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) const override;
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) override;
void ViewDoneWriting(const GlTextureView& view) override;
std::unique_ptr<ImageFrame> AsImageFrame() const override;
// If this texture is going to be used outside of the context that produced
// it, this method should be called to ensure that its updated contents are
// available. When this method returns, all changes made before the call to
@ -94,13 +102,13 @@ class GlTextureBuffer {
// NOTE: This blocks the current CPU thread and makes the changes visible
// to the CPU. If you want to access the data via OpenGL, use WaitOnGpu
// instead.
void WaitUntilComplete();
void WaitUntilComplete() const;
// Call this method to synchronize the current GL context with the texture's
// producer. This will not block the current CPU thread, but will ensure that
// subsequent GL commands see the texture in its complete status, with all
// rendering done on the GPU by the generating context.
void WaitOnGpu();
void WaitOnGpu() const;
// Informs the buffer that its contents are going to be overwritten.
// This invalidates the current sync token.
@ -114,7 +122,7 @@ class GlTextureBuffer {
void Updated(std::shared_ptr<GlSyncPoint> prod_token);
// Informs the buffer that a consumer has finished reading from it.
void DidRead(std::shared_ptr<GlSyncPoint> cons_token);
void DidRead(std::shared_ptr<GlSyncPoint> cons_token) const;
// Waits for all pending consumers to finish accessing the current content
// of the texture. This (preferably the OnGpu version) should be called
@ -143,10 +151,11 @@ class GlTextureBuffer {
const GLenum target_ = GL_TEXTURE_2D;
// Token tracking changes to this texture. Used by WaitUntilComplete.
std::shared_ptr<GlSyncPoint> producer_sync_;
absl::Mutex consumer_sync_mutex_;
mutable absl::Mutex consumer_sync_mutex_;
// Tokens tracking the point when consumers finished using this texture.
std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_ ABSL_GUARDED_BY(
consumer_sync_mutex_) = absl::make_unique<GlMultiSyncPoint>();
mutable std::unique_ptr<GlMultiSyncPoint> consumer_multi_sync_
ABSL_GUARDED_BY(consumer_sync_mutex_) =
absl::make_unique<GlMultiSyncPoint>();
DeletionCallback deletion_callback_;
std::shared_ptr<GlContext> producer_context_;
};

View File

@ -0,0 +1,16 @@
#include "mediapipe/gpu/gl_texture_view.h"
namespace mediapipe {
void GlTextureView::Release() {
if (detach_) detach_(*this);
detach_ = nullptr;
gl_context_ = nullptr;
gpu_buffer_ = nullptr;
plane_ = 0;
name_ = 0;
width_ = 0;
height_ = 0;
}
} // namespace mediapipe

View File

@ -0,0 +1,86 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
#define MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_
#include <functional>
#include <memory>
#include <utility>
#include "mediapipe/gpu/gl_base.h"
namespace mediapipe {
class GlContext;
class GlTextureViewManager;
class GpuBuffer;
class GlTextureView {
public:
GlTextureView() {}
~GlTextureView() { Release(); }
// TODO: make this class move-only.
GlContext* gl_context() const { return gl_context_; }
int width() const { return width_; }
int height() const { return height_; }
GLenum target() const { return target_; }
GLuint name() const { return name_; }
const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; }
int plane() const { return plane_; }
using DetachFn = std::function<void(GlTextureView&)>;
using DoneWritingFn = std::function<void(const GlTextureView&)>;
private:
friend class GpuBuffer;
friend class GlTextureBuffer;
friend class GpuBufferStorageCvPixelBuffer;
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
DetachFn detach, DoneWritingFn done_writing)
: gl_context_(context),
target_(target),
name_(name),
width_(width),
height_(height),
gpu_buffer_(std::move(gpu_buffer)),
plane_(plane),
detach_(std::move(detach)),
done_writing_(std::move(done_writing)) {}
// TODO: remove this friend declaration.
friend class GlTexture;
void Release();
// TODO: make this non-const.
void DoneWriting() const {
if (done_writing_) done_writing_(*this);
}
GlContext* gl_context_ = nullptr;
GLenum target_ = GL_TEXTURE_2D;
GLuint name_ = 0;
// Note: when scale is not 1, we still give the nominal size of the image.
int width_ = 0;
int height_ = 0;
std::shared_ptr<GpuBuffer> gpu_buffer_; // using shared_ptr temporarily
int plane_ = 0;
DetachFn detach_;
DoneWritingFn done_writing_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GL_TEXTURE_VIEW_H_

View File

@ -8,62 +8,7 @@
namespace mediapipe {
void GlTextureView::Release() {
if (detach_) detach_(*this);
detach_ = nullptr;
gl_context_ = nullptr;
gpu_buffer_ = nullptr;
plane_ = 0;
name_ = 0;
width_ = 0;
height_ = 0;
}
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#if TARGET_OS_OSX
typedef CVOpenGLTextureRef CVTextureType;
#else
typedef CVOpenGLESTextureRef CVTextureType;
#endif // TARGET_OS_OSX
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
CVReturn err;
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
#if TARGET_OS_OSX
CVTextureType cv_texture_temp;
err = CVOpenGLTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(),
GetCVPixelBufferRef(), NULL, &cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
CVOpenGLTextureGetName(*cv_texture), width(), height(), *this, plane,
[cv_texture](
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
#else
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format(), plane, gl_context->GetGlVersion());
CVTextureType cv_texture_temp;
err = CVOpenGLESTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(),
GetCVPixelBufferRef(), NULL, GL_TEXTURE_2D, info.gl_internal_format,
width() / info.downscale, height() / info.downscale, info.gl_format,
info.gl_type, plane, &cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
CVOpenGLESTextureGetName(*cv_texture), width(), height(), *this, plane,
[cv_texture](
mediapipe::GlTextureView&) { /* only retains cv_texture */ });
#endif // TARGET_OS_OSX
}
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame);
@ -72,187 +17,11 @@ GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
CHECK_OK(maybe_buffer.status());
return GpuBuffer(std::move(maybe_buffer).value());
}
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
CHECK(GetCVPixelBufferRef());
return CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef());
}
void GlTextureView::DoneWriting() const {
CHECK(gpu_buffer_);
#if TARGET_IPHONE_SIMULATOR
CVPixelBufferRef pixel_buffer = gpu_buffer_.GetCVPixelBufferRef();
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferLockBaseAddress failed: " << err;
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
uint8_t* pixel_ptr =
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
if (pixel_format == kCVPixelFormatType_32BGRA) {
// TODO: restore previous framebuffer? Move this to helper so we
// can use BindFramebuffer?
glViewport(0, 0, width(), height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, target(),
name(), 0);
size_t contiguous_bytes_per_row = width() * 4;
if (bytes_per_row == contiguous_bytes_per_row) {
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
pixel_ptr);
} else {
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
height());
uint8_t* temp_ptr = contiguous_buffer.data();
glReadPixels(0, 0, width(), height(), GL_BGRA, GL_UNSIGNED_BYTE,
temp_ptr);
for (int i = 0; i < height(); ++i) {
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
temp_ptr += contiguous_bytes_per_row;
pixel_ptr += bytes_per_row;
}
}
} else {
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
}
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
#endif
}
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
GlTextureView GpuBuffer::GetGlTextureView(int plane, bool for_reading) const {
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
const GlTextureBufferSharedPtr& texture_buffer =
GetGlTextureBufferSharedPtr();
// Insert wait call to sync with the producer.
texture_buffer->WaitOnGpu();
CHECK_EQ(plane, 0);
GlTextureView::DetachFn detach;
if (for_reading) {
detach = [](mediapipe::GlTextureView& texture) {
// Inform the GlTextureBuffer that we have finished accessing its
// contents, and create a consumer sync point.
texture.gpu_buffer().GetGlTextureBufferSharedPtr()->DidRead(
texture.gl_context()->CreateSyncToken());
};
}
return GlTextureView(gl_context.get(), texture_buffer->target(),
texture_buffer->name(), width(), height(), *this, plane,
std::move(detach));
}
GpuBuffer GpuBuffer::CopyingImageFrame(const ImageFrame& image_frame) {
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
auto buffer = GlTextureBuffer::Create(image_frame);
// TODO: does this need to set the texture params? We set them again when the
// texture is actually accessed via GlTexture[View]. Or should they always be
// set on creation?
if (buffer->format() != GpuBufferFormat::kUnknown) {
glBindTexture(GL_TEXTURE_2D, buffer->name());
GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
buffer->format(), /*plane=*/0, gl_context->GetGlVersion());
gl_context->SetStandardTextureParams(buffer->target(),
info.gl_internal_format);
glBindTexture(GL_TEXTURE_2D, 0);
}
return GpuBuffer(std::move(buffer));
}
static void ReadTexture(const GlTextureView& view, void* output, size_t size) {
// TODO: check buffer size? We could use glReadnPixels where available
// (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
// won't overflow the buffer with glReadPixels, we'd also need to check or
// reset several glPixelStore parameters (e.g. what if someone had the
// ill-advised idea of setting GL_PACK_SKIP_PIXELS?).
CHECK(view.gl_context());
GlTextureInfo info =
GlTextureInfoForGpuBufferFormat(view.gpu_buffer().format(), view.plane(),
view.gl_context()->GetGlVersion());
GLint current_fbo;
glGetIntegerv(GL_FRAMEBUFFER_BINDING, &current_fbo);
CHECK_NE(current_fbo, 0);
GLint color_attachment_name;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment_name);
if (color_attachment_name != view.name()) {
// Save the viewport. Note that we assume that the color attachment is a
// GL_TEXTURE_2D texture.
GLint viewport[4];
glGetIntegerv(GL_VIEWPORT, viewport);
// Set the data from GLTextureView object.
glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
view.name(), 0);
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
// Restore from the saved viewport and color attachment name.
glViewport(viewport[0], viewport[1], viewport[2], viewport[3]);
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D,
color_attachment_name, 0);
} else {
glReadPixels(0, 0, view.width(), view.height(), info.gl_format,
info.gl_type, output);
}
}
std::unique_ptr<ImageFrame> GpuBuffer::AsImageFrame() const {
ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(format());
auto output = absl::make_unique<ImageFrame>(
image_format, width(), height(), ImageFrame::kGlDefaultAlignmentBoundary);
auto view = GetGlTextureView(0, true);
ReadTexture(view, output->MutablePixelData(), output->PixelDataSize());
return output;
}
void GlTextureView::DoneWriting() const {
CHECK(gpu_buffer_);
// Inform the GlTextureBuffer that we have produced new content, and create
// a producer sync point.
gpu_buffer_.GetGlTextureBufferSharedPtr()->Updated(
gl_context()->CreateSyncToken());
#ifdef __ANDROID__
// On (some?) Android devices, the texture may need to be explicitly
// detached from the current framebuffer.
// TODO: is this necessary even with the unbind in BindFramebuffer?
// It is not clear whether this affected other contexts too, but let's keep it
// while in doubt.
GLint type = GL_NONE;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_TYPE,
&type);
if (type == GL_TEXTURE) {
GLint color_attachment = 0;
glGetFramebufferAttachmentParameteriv(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
GL_FRAMEBUFFER_ATTACHMENT_OBJECT_NAME,
&color_attachment);
if (color_attachment == name()) {
glBindFramebuffer(GL_FRAMEBUFFER, 0);
}
}
// Some Android drivers log a GL_INVALID_ENUM error after the first
// glGetFramebufferAttachmentParameteriv call if there is no bound object,
// even though it should be ok to ask for the type and get back GL_NONE.
// Let's just ignore any pending errors here.
GLenum error;
while ((error = glGetError()) != GL_NO_ERROR) {
}
#endif // __ANDROID__
return GpuBuffer(GlTextureBuffer::Create(image_frame));
}
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER

View File

@ -19,7 +19,9 @@
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_texture_view.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_buffer_storage.h"
#if defined(__APPLE__)
#include <CoreVideo/CoreVideo.h>
@ -27,6 +29,10 @@
#include "mediapipe/objc/CFHolder.h"
#endif // defined(__APPLE__)
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
#include "mediapipe/gpu/gl_texture_buffer.h"
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
@ -34,7 +40,6 @@
namespace mediapipe {
class GlContext;
class GlTextureView;
// This class wraps a platform-specific buffer of GPU data.
// An instance of GpuBuffer acts as an opaque reference to the underlying
@ -71,9 +76,9 @@ class GpuBuffer {
}
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
int width() const;
int height() const;
GpuBufferFormat format() const;
int width() const { return current_storage().width(); }
int height() const { return current_storage().height(); }
GpuBufferFormat format() const { return current_storage().format(); }
// Converts to true iff valid.
explicit operator bool() const { return operator!=(nullptr); }
@ -88,8 +93,15 @@ class GpuBuffer {
// Allow assignment from nullptr.
GpuBuffer& operator=(std::nullptr_t other);
// TODO: split into read and write, remove const from write.
GlTextureView GetGlTextureView(int plane, bool for_reading) const;
GlTextureView GetGlTextureReadView(int plane) const {
return current_storage().GetGlTextureReadView(
std::make_shared<GpuBuffer>(*this), plane);
}
GlTextureView GetGlTextureWriteView(int plane) {
return current_storage().GetGlTextureWriteView(
std::make_shared<GpuBuffer>(*this), plane);
}
// Make a GpuBuffer copying the data from an ImageFrame.
static GpuBuffer CopyingImageFrame(const ImageFrame& image_frame);
@ -99,114 +111,84 @@ class GpuBuffer {
// In order to work correctly across platforms, callers should always treat
// the returned ImageFrame as if it shares memory with the GpuBuffer, i.e.
// treat it as immutable if the GpuBuffer must not be modified.
std::unique_ptr<ImageFrame> AsImageFrame() const;
std::unique_ptr<ImageFrame> AsImageFrame() const {
return current_storage().AsImageFrame();
}
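A hedged usage sketch for the split accessors above; a current GL context and a valid buffer are assumed, the InspectBuffer helper is illustrative, and error handling is omitted:

  // Illustrative only: read back a frame through the new view API.
  void InspectBuffer(const GpuBuffer& buffer) {
    GlTextureView read_view = buffer.GetGlTextureReadView(/*plane=*/0);
    glBindTexture(read_view.target(), read_view.name());
    // ... issue GL sampling/read commands here ...
    glBindTexture(read_view.target(), 0);
    // read_view's destructor runs its detach callback, recording a
    // consumer sync point for the producer to wait on.

    // CPU access: treat the result as sharing memory with the GpuBuffer.
    std::unique_ptr<ImageFrame> frame = buffer.AsImageFrame();
  }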
private:
class PlaceholderGpuBufferStorage
: public mediapipe::internal::GpuBufferStorage {
public:
int width() const override { return 0; }
int height() const override { return 0; }
GpuBufferFormat format() const override {
return GpuBufferFormat::kUnknown;
}
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) const override {
return {};
}
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) override {
return {};
}
void ViewDoneWriting(const GlTextureView& view) override {}
std::unique_ptr<ImageFrame> AsImageFrame() const override {
return nullptr;
}
};
mediapipe::internal::GpuBufferStorage& no_storage() const {
static PlaceholderGpuBufferStorage placeholder;
return placeholder;
}
const mediapipe::internal::GpuBufferStorage& current_storage() const {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
CFHolder<CVPixelBufferRef> pixel_buffer_;
if (pixel_buffer_ != nullptr) return pixel_buffer_;
#else
if (texture_buffer_) return *texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return no_storage();
}
mediapipe::internal::GpuBufferStorage& current_storage() {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
if (pixel_buffer_ != nullptr) return pixel_buffer_;
#else
if (texture_buffer_) return *texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return no_storage();
}
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
GpuBufferStorageCvPixelBuffer pixel_buffer_;
#else
GlTextureBufferSharedPtr texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
};
class GlTextureView {
public:
GlTextureView() {}
~GlTextureView() { Release(); }
// TODO: make this class move-only.
GlContext* gl_context() const { return gl_context_; }
int width() const { return width_; }
int height() const { return height_; }
GLenum target() const { return target_; }
GLuint name() const { return name_; }
const GpuBuffer& gpu_buffer() const { return gpu_buffer_; }
int plane() const { return plane_; }
private:
friend class GpuBuffer;
using DetachFn = std::function<void(GlTextureView&)>;
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
int height, GpuBuffer gpu_buffer, int plane, DetachFn detach)
: gl_context_(context),
target_(target),
name_(name),
width_(width),
height_(height),
gpu_buffer_(std::move(gpu_buffer)),
plane_(plane),
detach_(std::move(detach)) {}
// TODO: remove this friend declaration.
friend class GlTexture;
void Release();
// TODO: make this non-const.
void DoneWriting() const;
GlContext* gl_context_ = nullptr;
GLenum target_ = GL_TEXTURE_2D;
GLuint name_ = 0;
// Note: when scale is not 1, we still give the nominal size of the image.
int width_ = 0;
int height_ = 0;
GpuBuffer gpu_buffer_;
int plane_ = 0;
DetachFn detach_;
};
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
return &current_storage() == &no_storage();
}
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
inline int GpuBuffer::width() const {
return static_cast<int>(CVPixelBufferGetWidth(*pixel_buffer_));
}
inline int GpuBuffer::height() const {
return static_cast<int>(CVPixelBufferGetHeight(*pixel_buffer_));
}
inline GpuBufferFormat GpuBuffer::format() const {
return GpuBufferFormatForCVPixelFormat(
CVPixelBufferGetPixelFormatType(*pixel_buffer_));
}
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
return pixel_buffer_ == other;
}
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
return pixel_buffer_ == other.pixel_buffer_;
}
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
pixel_buffer_.reset(other);
return *this;
}
#else
inline int GpuBuffer::width() const { return texture_buffer_->width(); }
inline int GpuBuffer::height() const { return texture_buffer_->height(); }
inline GpuBufferFormat GpuBuffer::format() const {
return texture_buffer_->format();
}
inline bool GpuBuffer::operator==(std::nullptr_t other) const {
return texture_buffer_ == other;
}
inline bool GpuBuffer::operator==(const GpuBuffer& other) const {
return texture_buffer_ == other.texture_buffer_;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
}
inline GpuBuffer& GpuBuffer::operator=(std::nullptr_t other) {
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
pixel_buffer_.reset(other);
#else
texture_buffer_ = other;
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return *this;
}
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GPU_BUFFER_H_

View File

@ -0,0 +1,41 @@
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
namespace mediapipe {
class GlTextureView;
class GpuBuffer;
} // namespace mediapipe
namespace mediapipe {
namespace internal {
using mediapipe::GlTextureView;
using mediapipe::GpuBuffer;
using mediapipe::GpuBufferFormat;
class GlTextureViewManager {
public:
virtual ~GlTextureViewManager() = default;
virtual GlTextureView GetGlTextureReadView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const = 0;
virtual GlTextureView GetGlTextureWriteView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) = 0;
virtual void ViewDoneWriting(const GlTextureView& view) = 0;
};
class GpuBufferStorage : public GlTextureViewManager {
public:
virtual ~GpuBufferStorage() = default;
virtual int width() const = 0;
virtual int height() const = 0;
virtual GpuBufferFormat format() const = 0;
virtual std::unique_ptr<ImageFrame> AsImageFrame() const = 0;
};
} // namespace internal
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_H_
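To make the shape of this interface concrete, here is a hedged sketch (not part of the patch) of a hypothetical storage backend with no GL backing; the class name and behavior are illustrative, mirroring the PlaceholderGpuBufferStorage added to gpu_buffer.h in this commit:

  #include "absl/memory/memory.h"
  #include "mediapipe/gpu/gl_texture_view.h"
  #include "mediapipe/gpu/gpu_buffer_storage.h"

  // Hypothetical CPU-only storage: reports its dimensions and format,
  // returns empty GL views, and materializes a blank ImageFrame on request.
  class FakeCpuStorage : public mediapipe::internal::GpuBufferStorage {
   public:
    FakeCpuStorage(int width, int height, mediapipe::GpuBufferFormat format)
        : width_(width), height_(height), format_(format) {}
    int width() const override { return width_; }
    int height() const override { return height_; }
    mediapipe::GpuBufferFormat format() const override { return format_; }
    mediapipe::GlTextureView GetGlTextureReadView(
        std::shared_ptr<mediapipe::GpuBuffer>, int) const override {
      return {};  // No GL texture to expose.
    }
    mediapipe::GlTextureView GetGlTextureWriteView(
        std::shared_ptr<mediapipe::GpuBuffer>, int) override {
      return {};
    }
    void ViewDoneWriting(const mediapipe::GlTextureView&) override {}
    std::unique_ptr<mediapipe::ImageFrame> AsImageFrame() const override {
      return absl::make_unique<mediapipe::ImageFrame>(
          mediapipe::ImageFormatForGpuBufferFormat(format_), width_, height_);
    }

   private:
    int width_;
    int height_;
    mediapipe::GpuBufferFormat format_;
  };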

View File

@ -0,0 +1,116 @@
#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/objc/util.h"
namespace mediapipe {
#if TARGET_OS_OSX
typedef CVOpenGLTextureRef CVTextureType;
#else
typedef CVOpenGLESTextureRef CVTextureType;
#endif // TARGET_OS_OSX
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureReadView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) const {
CVReturn err;
auto gl_context = GlContext::GetCurrent();
CHECK(gl_context);
#if TARGET_OS_OSX
CVTextureType cv_texture_temp;
err = CVOpenGLTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
&cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
    gl_context.get(), CVOpenGLTextureGetTarget(*cv_texture),
    CVOpenGLTextureGetName(*cv_texture), width(), height(),
    std::move(gpu_buffer), plane,
    [cv_texture](
        mediapipe::GlTextureView&) { /* only retains cv_texture */ },
    nullptr);
#else
const GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
format(), plane, gl_context->GetGlVersion());
CVTextureType cv_texture_temp;
err = CVOpenGLESTextureCacheCreateTextureFromImage(
kCFAllocatorDefault, gl_context->cv_texture_cache(), **this, NULL,
GL_TEXTURE_2D, info.gl_internal_format, width() / info.downscale,
height() / info.downscale, info.gl_format, info.gl_type, plane,
&cv_texture_temp);
CHECK(cv_texture_temp && !err)
<< "CVOpenGLESTextureCacheCreateTextureFromImage failed: " << err;
CFHolder<CVTextureType> cv_texture;
cv_texture.adopt(cv_texture_temp);
return GlTextureView(
gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture),
CVOpenGLESTextureGetName(*cv_texture), width(), height(),
std::move(gpu_buffer), plane,
[cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ },
// TODO: make GetGlTextureView for write view non-const, remove cast
// Note: we have to copy *this here because this storage is currently
// stored in GpuBuffer by value, and so the this pointer becomes invalid
// if the GpuBuffer is moved/copied. TODO: fix this.
[me = *this](const mediapipe::GlTextureView& view) {
const_cast<GpuBufferStorageCvPixelBuffer*>(&me)->ViewDoneWriting(view);
});
#endif // TARGET_OS_OSX
}
GlTextureView GpuBufferStorageCvPixelBuffer::GetGlTextureWriteView(
std::shared_ptr<GpuBuffer> gpu_buffer, int plane) {
// For this storage there is currently no difference between read and write
// views, so we delegate to the read method.
return GetGlTextureReadView(std::move(gpu_buffer), plane);
}
void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) {
#if TARGET_IPHONE_SIMULATOR
CVPixelBufferRef pixel_buffer = **this;
CHECK(pixel_buffer);
CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferLockBaseAddress failed: " << err;
OSType pixel_format = CVPixelBufferGetPixelFormatType(pixel_buffer);
size_t bytes_per_row = CVPixelBufferGetBytesPerRow(pixel_buffer);
uint8_t* pixel_ptr =
static_cast<uint8_t*>(CVPixelBufferGetBaseAddress(pixel_buffer));
if (pixel_format == kCVPixelFormatType_32BGRA) {
// TODO: restore previous framebuffer? Move this to helper so we
// can use BindFramebuffer?
glViewport(0, 0, view.width(), view.height());
glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
view.name(), 0);
size_t contiguous_bytes_per_row = view.width() * 4;
if (bytes_per_row == contiguous_bytes_per_row) {
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
pixel_ptr);
} else {
std::vector<uint8_t> contiguous_buffer(contiguous_bytes_per_row *
view.height());
uint8_t* temp_ptr = contiguous_buffer.data();
glReadPixels(0, 0, view.width(), view.height(), GL_BGRA, GL_UNSIGNED_BYTE,
temp_ptr);
for (int i = 0; i < view.height(); ++i) {
memcpy(pixel_ptr, temp_ptr, contiguous_bytes_per_row);
temp_ptr += contiguous_bytes_per_row;
pixel_ptr += bytes_per_row;
}
}
} else {
LOG(ERROR) << "unsupported pixel format: " << pixel_format;
}
err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
CHECK(err == kCVReturnSuccess)
<< "CVPixelBufferUnlockBaseAddress failed: " << err;
#endif
}
std::unique_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::AsImageFrame()
const {
return CreateImageFrameForCVPixelBuffer(**this);
}
} // namespace mediapipe

View File

@ -0,0 +1,41 @@
#ifndef MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
#define MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_
#include <CoreVideo/CoreVideo.h>
#include "mediapipe/gpu/gl_texture_view.h"
#include "mediapipe/gpu/gpu_buffer_storage.h"
#include "mediapipe/objc/CFHolder.h"
namespace mediapipe {
class GlContext;
class GpuBufferStorageCvPixelBuffer
: public mediapipe::internal::GpuBufferStorage,
public CFHolder<CVPixelBufferRef> {
public:
using CFHolder<CVPixelBufferRef>::CFHolder;
GpuBufferStorageCvPixelBuffer(const CFHolder<CVPixelBufferRef>& other)
: CFHolder(other) {}
GpuBufferStorageCvPixelBuffer(CFHolder<CVPixelBufferRef>&& other)
: CFHolder(std::move(other)) {}
int width() const override {
  return static_cast<int>(CVPixelBufferGetWidth(**this));
}
int height() const override {
  return static_cast<int>(CVPixelBufferGetHeight(**this));
}
GpuBufferFormat format() const override {
  return GpuBufferFormatForCVPixelFormat(
      CVPixelBufferGetPixelFormatType(**this));
}
GlTextureView GetGlTextureReadView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) const override;
GlTextureView GetGlTextureWriteView(std::shared_ptr<GpuBuffer> gpu_buffer,
int plane) override;
std::unique_ptr<ImageFrame> AsImageFrame() const override;
void ViewDoneWriting(const GlTextureView& view) override;
};
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GPU_BUFFER_STORAGE_CV_PIXEL_BUFFER_H_

View File

@ -64,6 +64,13 @@ public class AppTextureFrame implements TextureFrame {
return timestamp;
}
/** Returns true if a call to waitUntilReleased() would block waiting for release. */
public boolean isNotYetReleased() {
synchronized (this) {
return inUse && releaseSyncToken == null;
}
}
/**
* Waits until the consumer is done with the texture.
*

View File

@ -26,7 +26,8 @@ android_library(
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_gpu_image.binarypb",
"//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu_image.binarypb",
"//mediapipe/modules/hand_landmark:handedness.txt",
"//mediapipe/modules/hand_landmark:hand_landmark.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_lite.tflite",
"//mediapipe/modules/hand_landmark:hand_landmark_full.tflite",
"//mediapipe/modules/palm_detection:palm_detection.tflite",
],
assets_dir = "",

View File

@ -78,6 +78,7 @@ public class Hands extends ImageSolutionBase {
Connection.create(HandLandmark.PINKY_PIP, HandLandmark.PINKY_DIP),
Connection.create(HandLandmark.PINKY_DIP, HandLandmark.PINKY_TIP));
private static final String MODEL_COMPLEXITY = "model_complexity";
private static final String NUM_HANDS = "num_hands";
private static final String USE_PREV_LANDMARKS = "use_prev_landmarks";
private static final String GPU_GRAPH_NAME = "hand_landmark_tracking_gpu_image.binarypb";
@ -131,6 +132,7 @@ public class Hands extends ImageSolutionBase {
initialize(context, solutionInfo, outputHandler);
Map<String, Packet> inputSidePackets = new HashMap<>();
inputSidePackets.put(NUM_HANDS, packetCreator.createInt32(options.maxNumHands()));
inputSidePackets.put(MODEL_COMPLEXITY, packetCreator.createInt32(options.modelComplexity()));
inputSidePackets.put(USE_PREV_LANDMARKS, packetCreator.createBool(!options.staticImageMode()));
start(inputSidePackets);
}

View File

@ -26,6 +26,10 @@ import com.google.auto.value.AutoValue;
* <p>maxNumHands: Maximum number of hands to detect. See details in
* https://solutions.mediapipe.dev/hands#max_num_hands.
*
* <p>modelComplexity: Complexity of the hand landmark model: 0 or 1. Landmark accuracy and
* inference latency generally increase with model complexity. See details in
* https://solutions.mediapipe.dev/hands#model_complexity.
*
* <p>minDetectionConfidence: Minimum confidence value ([0.0, 1.0]) for hand detection to be
* considered successful. See details in
* https://solutions.mediapipe.dev/hands#min_detection_confidence.
@ -43,6 +47,8 @@ public abstract class HandsOptions {
public abstract int maxNumHands();
public abstract int modelComplexity();
public abstract float minDetectionConfidence();
public abstract float minTrackingConfidence();
@ -59,6 +65,7 @@ public abstract class HandsOptions {
public Builder withDefaultValues() {
return setStaticImageMode(false)
.setMaxNumHands(2)
.setModelComplexity(1)
.setMinDetectionConfidence(0.5f)
.setMinTrackingConfidence(0.5f)
.setRunOnGpu(true);
@ -68,6 +75,8 @@ public abstract class HandsOptions {
public abstract Builder setMaxNumHands(int value);
public abstract Builder setModelComplexity(int value);
public abstract Builder setMinDetectionConfidence(float value);
public abstract Builder setMinTrackingConfidence(float value);

View File

@ -22,16 +22,30 @@ licenses(["notice"])
package(default_visibility = ["//visibility:public"])
exports_files([
"hand_landmark.tflite",
"hand_landmark_full.tflite",
"hand_landmark_lite.tflite",
"hand_landmark_sparse.tflite",
"handedness.txt",
])
mediapipe_simple_subgraph(
name = "hand_landmark_model_loader",
graph = "hand_landmark_model_loader.pbtxt",
register_as = "HandLandmarkModelLoader",
deps = [
"//mediapipe/calculators/core:constant_side_packet_calculator",
"//mediapipe/calculators/tflite:tflite_model_calculator",
"//mediapipe/calculators/util:local_file_contents_calculator",
"//mediapipe/framework/tool:switch_container",
],
)
mediapipe_simple_subgraph(
name = "hand_landmark_cpu",
graph = "hand_landmark_cpu.pbtxt",
register_as = "HandLandmarkCpu",
deps = [
":hand_landmark_model_loader",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator",
@ -50,6 +64,7 @@ mediapipe_simple_subgraph(
graph = "hand_landmark_gpu.pbtxt",
register_as = "HandLandmarkGpu",
deps = [
":hand_landmark_model_loader",
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/tensor:image_to_tensor_calculator",

View File

@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
# (NormalizedRect)
input_stream: "ROI:hand_rect"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy and
# inference latency generally increase with model complexity. If unspecified,
# behaves as if set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
# NOTE: if a hand is not present within the given ROI, for this particular
# timestamp there will not be an output packet in the LANDMARKS stream. However,
@ -40,16 +45,23 @@ node {
}
}
# Loads the hand landmark TF Lite model.
node {
calculator: "HandLandmarkModelLoader"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
output_side_packet: "MODEL:model"
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "InferenceCalculator"
input_side_packet: "MODEL:model"
input_stream: "TENSORS:input_tensor"
output_stream: "TENSORS:output_tensors"
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
delegate {
xnnpack {}
}

Binary file not shown.

View File

@ -8,6 +8,11 @@ input_stream: "IMAGE:image"
# (NormalizedRect)
input_stream: "ROI:hand_rect"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy and
# inference latency generally increase with model complexity. If unspecified,
# behaves as if set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# 21 hand landmarks within the given ROI. (NormalizedLandmarkList)
# NOTE: if a hand is not present within the given ROI, for this particular
# timestamp there will not be an output packet in the LANDMARKS stream. However,
@ -41,18 +46,21 @@ node {
}
}
# Loads the hand landmark TF Lite model.
node {
calculator: "HandLandmarkModelLoader"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
output_side_packet: "MODEL:model"
}
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "InferenceCalculator"
input_side_packet: "MODEL:model"
input_stream: "TENSORS:input_tensor"
output_stream: "TENSORS:output_tensors"
options: {
[mediapipe.InferenceCalculatorOptions.ext] {
model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite"
}
}
}
# Splits a vector of tensors to multiple vectors according to the ranges

Binary file not shown.

View File

@ -0,0 +1,63 @@
# MediaPipe graph to load a selected hand landmark TF Lite model.
type: "HandLandmarkModelLoader"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy and
# inference latency generally increase with model complexity. If unspecified,
# behaves as if set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# TF Lite model represented as a FlatBuffer.
# (std::unique_ptr<tflite::FlatBufferModel, std::function<void(tflite::FlatBufferModel*)>>)
output_side_packet: "MODEL:model"
# Determines the path to the desired hand landmark model file.
node {
calculator: "SwitchContainer"
input_side_packet: "SELECT:model_complexity"
output_side_packet: "PACKET:model_path"
options: {
[mediapipe.SwitchContainerOptions.ext] {
select: 1
contained_node: {
calculator: "ConstantSidePacketCalculator"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet {
string_value: "mediapipe/modules/hand_landmark/hand_landmark_lite.tflite"
}
}
}
}
contained_node: {
calculator: "ConstantSidePacketCalculator"
options: {
[mediapipe.ConstantSidePacketCalculatorOptions.ext]: {
packet {
string_value: "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
}
}
}
}
}
}
}
# Loads the file in the specified path into a blob.
node {
calculator: "LocalFileContentsCalculator"
input_side_packet: "FILE_PATH:model_path"
output_side_packet: "CONTENTS:model_blob"
options: {
[mediapipe.LocalFileContentsCalculatorOptions.ext]: {
text_mode: false
}
}
}
# Converts the input blob into a TF Lite model.
node {
calculator: "TfLiteModelCalculator"
input_side_packet: "MODEL_BLOB:model_blob"
output_side_packet: "MODEL:model"
}
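A hedged sketch (not part of the patch) of selecting the lite model from C++ by feeding the model_complexity side packet when starting a graph that embeds this subgraph; the config variable is assumed to exist:

  // `config` is a CalculatorGraphConfig exposing the "model_complexity"
  // input side packet (e.g. hand_landmark_tracking_cpu.pbtxt).
  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  std::map<std::string, mediapipe::Packet> side_packets;
  side_packets["model_complexity"] = mediapipe::MakePacket<int>(0);  // lite
  MP_RETURN_IF_ERROR(graph.StartRun(side_packets));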

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy and
# inference latency generally increase with model complexity. If unspecified,
# behaves as if set to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -177,6 +182,7 @@ node {
# Detect hand landmarks for the specific hand rect.
node {
calculator: "HandLandmarkCpu"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_stream: "IMAGE:image_for_landmarks"
input_stream: "ROI:single_hand_rect"
output_stream: "LANDMARKS:single_hand_landmarks"

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# it defaults to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -85,6 +90,7 @@ node {
calculator: "HandLandmarkTrackingCpu"
input_stream: "IMAGE:image_frame"
input_side_packet: "NUM_HANDS:num_hands"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
output_stream: "LANDMARKS:multi_hand_landmarks"
output_stream: "HANDEDNESS:multi_handedness"

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# it defaults to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -178,6 +183,7 @@ node {
# Detect hand landmarks for the specific hand rect.
node {
calculator: "HandLandmarkGpu"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_stream: "IMAGE:image_for_landmarks"
input_stream: "ROI:single_hand_rect"
output_stream: "LANDMARKS:single_hand_landmarks"

View File

@ -14,6 +14,11 @@ input_stream: "IMAGE:image"
# Max number of hands to detect/track. (int)
input_side_packet: "NUM_HANDS:num_hands"
# Complexity of the hand landmark model: 0 or 1. Landmark accuracy as well as
# inference latency generally go up with the model complexity. If unspecified,
# it defaults to 1. (int)
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
# Whether landmarks on the previous image should be used to help localize
# landmarks on the current image. (bool)
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
@ -85,6 +90,7 @@ node {
calculator: "HandLandmarkTrackingGpu"
input_stream: "IMAGE:gpu_buffer"
input_side_packet: "NUM_HANDS:num_hands"
input_side_packet: "MODEL_COMPLEXITY:model_complexity"
input_side_packet: "USE_PREV_LANDMARKS:use_prev_landmarks"
output_stream: "LANDMARKS:multi_hand_landmarks"
output_stream: "HANDEDNESS:multi_handedness"

View File

@ -7,8 +7,8 @@
# - "face_landmark.tflite" is available at
# "mediapipe/modules/face_landmark/face_landmark.tflite"
#
# - "hand_landmark.tflite" is available at
# "mediapipe/modules/hand_landmark/hand_landmark.tflite"
# - "hand_landmark_full.tflite" is available at
# "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
#
# - "hand_recrop.tflite" is available at
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"

View File

@ -7,8 +7,8 @@
# - "face_landmark.tflite" is available at
# "mediapipe/modules/face_landmark/face_landmark.tflite"
#
# - "hand_landmark.tflite" is available at
# "mediapipe/modules/hand_landmark/hand_landmark.tflite"
# - "hand_landmark_full.tflite" is available at
# "mediapipe/modules/hand_landmark/hand_landmark_full.tflite"
#
# - "hand_recrop.tflite" is available at
# "mediapipe/modules/holistic_landmark/hand_recrop.tflite"

View File

@ -237,7 +237,7 @@ absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) {
}
bool FilterDetectionCalculator::IsValidLabel(const std::string& label) {
bool match = !limit_labels_ || ContainsKey(allowed_labels_, label);
bool match = !limit_labels_ || allowed_labels_.contains(label);
if (!match) {
// If no exact match is found, check for regular expression
// comparisons in the allowed_labels.
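
The lookup above moves from the map_util ContainsKey helper to the contains() member that Abseil's hash containers provide; this assumes allowed_labels_ is an absl container, which the member call implies. A self-contained sketch of the idiom:

#include <string>

#include "absl/container/flat_hash_set.h"

int main() {
  absl::flat_hash_set<std::string> allowed_labels = {"person", "cat"};
  // Member lookup, equivalent to the old ContainsKey(allowed_labels, label).
  bool has_person = allowed_labels.contains("person");  // true
  bool has_dog = allowed_labels.contains("dog");        // false
  return has_person && !has_dog ? 0 : 1;
}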

View File

@ -21,6 +21,9 @@ cc_library(
features = ["-parse_headers"],
linkopts = [
"-framework Accelerate",
"-framework CoreFoundation",
"-framework CoreGraphics",
"-framework CoreVideo",
],
visibility = ["//mediapipe/framework:mediapipe_internal"],
deps = [

View File

@ -17,6 +17,7 @@
#import <Accelerate/Accelerate.h>
#import <CoreFoundation/CoreFoundation.h>
#import <CoreGraphics/CoreGraphics.h>
#import <CoreVideo/CoreVideo.h>
#include "mediapipe/framework/formats/image_frame.h"

View File

@ -89,6 +89,7 @@ class Hands(SolutionBase):
def __init__(self,
static_image_mode=False,
max_num_hands=2,
model_complexity=1,
min_detection_confidence=0.5,
min_tracking_confidence=0.5):
"""Initializes a MediaPipe Hand object.
@ -99,6 +100,10 @@ class Hands(SolutionBase):
https://solutions.mediapipe.dev/hands#static_image_mode.
max_num_hands: Maximum number of hands to detect. See details in
https://solutions.mediapipe.dev/hands#max_num_hands.
model_complexity: Complexity of the hand landmark model: 0 or 1.
Landmark accuracy as well as inference latency generally go up with the
model complexity. See details in
https://solutions.mediapipe.dev/hands#model_complexity.
min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for hand
detection to be considered successful. See details in
https://solutions.mediapipe.dev/hands#min_detection_confidence.
@ -109,6 +114,7 @@ class Hands(SolutionBase):
super().__init__(
binary_graph_path=_BINARYPB_FILE_PATH,
side_inputs={
'model_complexity': model_complexity,
'num_hands': max_num_hands,
'use_prev_landmarks': not static_image_mode,
},

View File

@ -32,7 +32,8 @@ from mediapipe.python.solutions import hands as mp_hands
TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'
DIFF_THRESHOLD = 20 # pixels
LITE_MODEL_DIFF_THRESHOLD = 25 # pixels
FULL_MODEL_DIFF_THRESHOLD = 20 # pixels
EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
[289, 237], [322, 203], [219, 216],
[238, 138], [249, 90], [253, 51],
@ -40,7 +41,7 @@ EXPECTED_HAND_COORDINATES_PREDICTION = [[[138, 343], [211, 330], [257, 286],
[185, 19], [138, 208], [131, 127],
[124, 77], [117, 36], [106, 222],
[92, 159], [79, 124], [68, 93]],
[[580, 36], [504, 50], [459, 94],
[[580, 34], [504, 50], [459, 94],
[429, 146], [397, 182], [507, 167],
[479, 245], [469, 292], [464, 330],
[545, 180], [534, 265], [533, 319],
@ -75,14 +76,18 @@ class HandsTest(parameterized.TestCase):
self.assertIsNone(results.multi_hand_landmarks)
self.assertIsNone(results.multi_handedness)
@parameterized.named_parameters(('static_image_mode', True, 1),
('video_mode', False, 5))
def test_multi_hands(self, static_image_mode, num_frames):
@parameterized.named_parameters(
('static_image_mode_with_lite_model', True, 0, 5),
('video_mode_with_lite_model', False, 0, 10),
('static_image_mode_with_full_model', True, 1, 5),
('video_mode_with_full_model', False, 1, 10))
def test_multi_hands(self, static_image_mode, model_complexity, num_frames):
image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg')
image = cv2.imread(image_path)
with mp_hands.Hands(
static_image_mode=static_image_mode,
max_num_hands=2,
model_complexity=model_complexity,
min_detection_confidence=0.5) as hands:
for idx in range(num_frames):
results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
@ -104,7 +109,8 @@ class HandsTest(parameterized.TestCase):
prediction_error = np.abs(
np.asarray(multi_hand_coordinates) -
np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION))
npt.assert_array_less(prediction_error, DIFF_THRESHOLD)
diff_threshold = LITE_MODEL_DIFF_THRESHOLD if model_complexity == 0 else FULL_MODEL_DIFF_THRESHOLD
npt.assert_array_less(prediction_error, diff_threshold)
if __name__ == '__main__':

View File

@ -18,6 +18,7 @@
#include <memory>
#include <utility>
#include "absl/status/status.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
@ -81,6 +82,32 @@ ObjectDef GetSSBOObjectDef(int channels) {
return gpu_object_def;
}
#ifdef __ANDROID__
cl::InferenceOptions GetClInferenceOptions(const InferenceOptions& options) {
cl::InferenceOptions result{};
result.priority1 = options.priority1;
result.priority2 = options.priority2;
result.priority3 = options.priority3;
result.usage = options.usage;
return result;
}
absl::Status VerifyShapes(const std::vector<TensorObjectDef>& actual,
const std::vector<BHWC>& expected) {
RET_CHECK_EQ(actual.size(), expected.size());
const int size = actual.size();
for (int i = 0; i < size; ++i) {
const auto& dims = actual[i].dimensions;
const BHWC& bhwc = expected[i];
RET_CHECK(dims.b == bhwc.b && dims.h == bhwc.h && dims.w == bhwc.w &&
dims.c == bhwc.c);
}
return absl::OkStatus();
}
#endif // __ANDROID__
} // namespace
absl::Status TFLiteGPURunner::InitializeWithModel(
@ -139,16 +166,16 @@ absl::Status TFLiteGPURunner::Build() {
// try to build OpenCL first. If something goes wrong, fall back to OpenGL.
absl::Status status = InitializeOpenCL(&builder);
if (status.ok()) {
LOG(INFO) << "OpenCL backend is used.";
VLOG(2) << "OpenCL backend is used.";
} else {
LOG(ERROR) << "Falling back to OpenGL: " << status.message();
VLOG(2) << "Falling back to OpenGL: " << status.message();
MP_RETURN_IF_ERROR(InitializeOpenGL(&builder));
}
}
// Both graphs are not needed anymore. Make sure they are deleted.
// The GL graph is no longer needed; the CL graph may still be needed for
// serialized model calculation.
graph_gl_.reset(nullptr);
graph_cl_.reset(nullptr);
// 2. Describe output/input objects for created builder.
for (int flow_index = 0; flow_index < input_shapes_.size(); ++flow_index) {
@ -204,18 +231,57 @@ absl::Status TFLiteGPURunner::InitializeOpenCL(
env_options.serialized_binary_cache = serialized_binary_cache_;
}
cl::InferenceEnvironmentProperties properties;
cl::InferenceOptions cl_options;
cl_options.priority1 = options_.priority1;
cl_options.priority2 = options_.priority2;
cl_options.priority3 = options_.priority3;
cl_options.usage = options_.usage;
MP_RETURN_IF_ERROR(
cl::NewInferenceEnvironment(env_options, &cl_environment_, &properties));
// Try to initialize from serialized model first.
if (!serialized_model_.empty()) {
absl::Status init_status = InitializeOpenCLFromSerializedModel(builder);
if (init_status.ok()) {
serialized_model_used_ = true;
return absl::OkStatus();
}
VLOG(2) << "Failed to init from serialized model: [" << init_status
<< "]. Trying to init from scratch.";
}
// Initialize from scratch.
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
GraphFloat32 graph_cl;
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
MP_RETURN_IF_ERROR(cl_environment_->NewInferenceBuilder(
cl_options, std::move(*graph_cl_), builder));
#endif
cl_options, std::move(graph_cl), builder));
#endif // __ANDROID__
return absl::OkStatus();
}
#ifdef __ANDROID__
absl::Status TFLiteGPURunner::InitializeOpenCLFromSerializedModel(
std::unique_ptr<InferenceBuilder>* builder) {
MP_RETURN_IF_ERROR(
cl_environment_->NewInferenceBuilder(serialized_model_, builder));
MP_RETURN_IF_ERROR(VerifyShapes(builder->get()->inputs(), input_shapes_));
return VerifyShapes(builder->get()->outputs(), output_shapes_);
}
absl::StatusOr<std::vector<uint8_t>> TFLiteGPURunner::GetSerializedModel() {
RET_CHECK(runner_) << "Runner is in an invalid state.";
if (serialized_model_used_) {
return serialized_model_;
}
RET_CHECK(graph_cl_) << "CL graph is not initialized.";
GraphFloat32 graph_cl;
MP_RETURN_IF_ERROR(graph_cl_->MakeExactCopy(&graph_cl));
cl::InferenceOptions cl_options = GetClInferenceOptions(options_);
std::vector<uint8_t> serialized_model;
MP_RETURN_IF_ERROR(cl_environment_->BuildSerializedModel(
cl_options, std::move(graph_cl), &serialized_model));
return serialized_model;
}
#endif // __ANDROID__
} // namespace gpu
} // namespace tflite
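
Together with the header changes below, this lets a client serialize the compiled OpenCL model once, persist it, and hand it back before Build() on later runs to skip the expensive from-scratch compilation. A minimal sketch, assuming an Android build (the new API is guarded by __ANDROID__); LoadFromDisk and SaveToDisk are hypothetical persistence helpers, not part of MediaPipe:

#include <cstdint>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h"

// Hypothetical persistence helpers -- not part of MediaPipe.
bool LoadFromDisk(const char* path, std::vector<uint8_t>* out);
void SaveToDisk(const char* path, const std::vector<uint8_t>& bytes);

absl::Status BuildWithSerializedModelCache(
    tflite::gpu::TFLiteGPURunner* runner) {
  std::vector<uint8_t> cached;
  if (LoadFromDisk("hand_landmark.serialized", &cached)) {
    // InitializeOpenCL() tries the serialized model first and falls back to
    // building from scratch if the stored shapes don't match.
    runner->SetSerializedModel(std::move(cached));
  }
  absl::Status status = runner->Build();
  if (!status.ok()) return status;
  // Cheap if the serialized model was actually used; otherwise this
  // serializes the freshly built CL graph for next time.
  auto serialized = runner->GetSerializedModel();
  if (!serialized.ok()) return serialized.status();
  SaveToDisk("hand_landmark.serialized", *serialized);
  return absl::OkStatus();
}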

View File

@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@ -29,7 +30,7 @@
#ifdef __ANDROID__
#include "tensorflow/lite/delegates/gpu/cl/api.h"
#endif
#endif // __ANDROID__
namespace tflite {
namespace gpu {
@ -90,11 +91,22 @@ class TFLiteGPURunner {
std::vector<uint8_t> GetSerializedBinaryCache() {
return cl_environment_->GetSerializedBinaryCache();
}
#endif
void SetSerializedModel(std::vector<uint8_t>&& serialized_model) {
serialized_model_ = std::move(serialized_model);
serialized_model_used_ = false;
}
absl::StatusOr<std::vector<uint8_t>> GetSerializedModel();
#endif // __ANDROID__
private:
absl::Status InitializeOpenGL(std::unique_ptr<InferenceBuilder>* builder);
absl::Status InitializeOpenCL(std::unique_ptr<InferenceBuilder>* builder);
#ifdef __ANDROID__
absl::Status InitializeOpenCLFromSerializedModel(
std::unique_ptr<InferenceBuilder>* builder);
#endif // __ANDROID__
InferenceOptions options_;
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
@ -103,9 +115,12 @@ class TFLiteGPURunner {
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
std::vector<uint8_t> serialized_binary_cache_;
#endif
std::vector<uint8_t> serialized_model_;
bool serialized_model_used_ = false;
#endif // __ANDROID__
// graph_ is maintained temporarily and becomes invalid after runner_ is ready
// graph_gl_ is maintained temporarily and becomes invalid after runner_ is
// ready
std::unique_ptr<GraphFloat32> graph_gl_;
std::unique_ptr<GraphFloat32> graph_cl_;
std::unique_ptr<InferenceRunner> runner_;