From ebba119f151ec1963eac0b2bda3e10f4cfb7624f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 01:22:38 -0800 Subject: [PATCH 001/137] Add Java ImageEmbedder API. PiperOrigin-RevId: 488588010 --- .../proto/image_embedder_graph_options.proto | 3 + .../mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + .../com/google/mediapipe/tasks/vision/BUILD | 29 ++ .../vision/imageembedder/AndroidManifest.xml | 8 + .../vision/imageembedder/ImageEmbedder.java | 448 ++++++++++++++++++ .../imageembedder/ImageEmbedderResult.java | 54 +++ .../vision/imageembedder/AndroidManifest.xml | 24 + .../tasks/vision/imageembedder/BUILD | 19 + .../imageembedder/ImageEmbedderTest.java | 444 +++++++++++++++++ 9 files changed, 1030 insertions(+) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD create mode 100644 mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java diff --git a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto index 4adba5ab7..72b3e7ee3 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/embedder_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imageembedder.proto"; +option java_outer_classname = "ImageEmbedderGraphOptionsProto"; + message ImageEmbedderGraphOptions { extend mediapipe.CalculatorOptions { optional ImageEmbedderGraphOptions ext = 476348187; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 2b648bc43..8b09260bd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -42,6 +42,7 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 4dc4a547e..289e3000d 100644 --- 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -43,6 +43,7 @@ cc_binary( "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], @@ -172,6 +173,34 @@ android_library( ], ) +android_library( + name = "imageembedder", + srcs = [ + "imageembedder/ImageEmbedder.java", + "imageembedder/ImageEmbedderResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "imageembedder/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") mediapipe_tasks_vision_aar( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..ebdb037d6 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java new file mode 100644 index 000000000..0d8ecd5c3 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -0,0 +1,448 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.imageembedder; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Embedding; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.utils.CosineSimilarity; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.proto.ImageEmbedderGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs embedding extraction on images. + * + *

+ * <p>The API expects a TFLite model with optional, but strongly recommended, TFLite Model
+ * Metadata.
+ *
+ * <p>The API supports models with one image input tensor and one or more output tensors. To be more
+ * specific, here are the requirements.
+ *
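+ * <p>A minimal end-to-end sketch (the asset name below is the test model shipped with this
+ * change; {@code context} and {@code mpImage} are assumed to be in scope):
+ *
+ * <pre>{@code
+ * ImageEmbedder embedder =
+ *     ImageEmbedder.createFromFile(context, "mobilenet_v3_small_100_224_embedder.tflite");
+ * ImageEmbedderResult result = embedder.embed(mpImage);
+ * float[] vector = result.embeddingResult().embeddings().get(0).floatEmbedding();
+ * }</pre>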

+ */
+public final class ImageEmbedder extends BaseVisionTaskApi {
+  private static final String TAG = ImageEmbedder.class.getSimpleName();
+  private static final String IMAGE_IN_STREAM_NAME = "image_in";
+  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
+  private static final List<String> INPUT_STREAMS =
+      Collections.unmodifiableList(
+          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
+  private static final List<String> OUTPUT_STREAMS =
+      Collections.unmodifiableList(Arrays.asList("EMBEDDINGS:embeddings_out", "IMAGE:image_out"));
+  private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
+  private static final int IMAGE_OUT_STREAM_INDEX = 1;
+  private static final String TASK_GRAPH_NAME =
+      "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph";
+
+  static {
+    ProtoUtil.registerTypeName(
+        EmbeddingsProto.EmbeddingResult.class,
+        "mediapipe.tasks.components.containers.proto.EmbeddingResult");
+  }
+
+  /**
+   * Creates an {@link ImageEmbedder} instance from a model file and default {@link
+   * ImageEmbedderOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelPath path to the embedding model in the assets.
+   * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+   */
+  public static ImageEmbedder createFromFile(Context context, String modelPath) {
+    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+    return createFromOptions(
+        context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+  }
+
+  /**
+   * Creates an {@link ImageEmbedder} instance from a model file and default {@link
+   * ImageEmbedderOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelFile the embedding model {@link File} instance.
+   * @throws IOException if an I/O error occurs when opening the tflite model file.
+   * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+   */
+  public static ImageEmbedder createFromFile(Context context, File modelFile) throws IOException {
+    try (ParcelFileDescriptor descriptor =
+        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+      BaseOptions baseOptions =
+          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+      return createFromOptions(
+          context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+    }
+  }
+
+  /**
+   * Creates an {@link ImageEmbedder} instance from a model buffer and default {@link
+   * ImageEmbedderOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding
+   *     model.
+   * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+   */
+  public static ImageEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) {
+    BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
+    return createFromOptions(
+        context, ImageEmbedderOptions.builder().setBaseOptions(baseOptions).build());
+  }
+
+  /**
+   * Creates an {@link ImageEmbedder} instance from an {@link ImageEmbedderOptions} instance.
+   *
+   * @param context an Android {@link Context}.
+   * @param options an {@link ImageEmbedderOptions} instance.
+   * @throws MediaPipeException if there is an error during {@link ImageEmbedder} creation.
+   */
+  public static ImageEmbedder createFromOptions(Context context, ImageEmbedderOptions options) {
+    OutputHandler<ImageEmbedderResult, MPImage> handler = new OutputHandler<>();
+    handler.setOutputPacketConverter(
+        new OutputHandler.OutputPacketConverter<ImageEmbedderResult, MPImage>() {
+          @Override
+          public ImageEmbedderResult convertToTaskResult(List<Packet> packets) {
+            try {
+              return ImageEmbedderResult.create(
+                  EmbeddingResult.createFromProto(
+                      PacketGetter.getProto(
+                          packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
+                          EmbeddingsProto.EmbeddingResult.getDefaultInstance())),
+                  BaseVisionTaskApi.generateResultTimestampMs(
+                      options.runningMode(), packets.get(EMBEDDINGS_OUT_STREAM_INDEX)));
+            } catch (IOException e) {
+              throw new MediaPipeException(
+                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
+            }
+          }
+
+          @Override
+          public MPImage convertToTaskInput(List<Packet> packets) {
+            return new BitmapImageBuilder(
+                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
+                .build();
+          }
+        });
+    options.resultListener().ifPresent(handler::setResultListener);
+    options.errorListener().ifPresent(handler::setErrorListener);
+    TaskRunner runner =
+        TaskRunner.create(
+            context,
+            TaskInfo.<ImageEmbedderOptions>builder()
+                .setTaskGraphName(TASK_GRAPH_NAME)
+                .setInputStreams(INPUT_STREAMS)
+                .setOutputStreams(OUTPUT_STREAMS)
+                .setTaskOptions(options)
+                .setEnableFlowLimiting(options.runningMode() == RunningMode.LIVE_STREAM)
+                .build(),
+            handler);
+    return new ImageEmbedder(runner, options.runningMode());
+  }
+
+  /**
+   * Constructor to initialize an {@link ImageEmbedder} from a {@link TaskRunner} and {@link
+   * RunningMode}.
+   *
+   * @param taskRunner a {@link TaskRunner}.
+   * @param runningMode a mediapipe vision task {@link RunningMode}.
+   */
+  private ImageEmbedder(TaskRunner taskRunner, RunningMode runningMode) {
+    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
+  }
+
+  /**
+   * Performs embedding extraction on the provided single image with default image processing
+   * options, i.e. using the whole image as region-of-interest and without any rotation applied.
+   * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}.
+   *

+   * <p>{@link ImageEmbedder} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
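+   *
+   * <p>Embeddings of two images can then be compared with {@link #cosineSimilarity} (a sketch;
+   * {@code embedder}, {@code imageA} and {@code imageB} are assumed inputs):
+   *
+   * <pre>{@code
+   * double similarity = ImageEmbedder.cosineSimilarity(
+   *     embedder.embed(imageA).embeddingResult().embeddings().get(0),
+   *     embedder.embed(imageB).embeddingResult().embeddings().get(0));
+   * }</pre>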
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image) { + return embed(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs embedding extraction on the provided single image. Only use this method when the + * {@link ImageEmbedder} is created with {@link RunningMode.IMAGE}. + * + *

+   * <p>{@link ImageEmbedder} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
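+   *
+   * <p>For example, to embed a sub-region of the image, or an image captured with a rotation (the
+   * normalized coordinates and rotation below are illustrative values taken from the tests in this
+   * change):
+   *
+   * <pre>{@code
+   * ImageProcessingOptions roiOptions =
+   *     ImageProcessingOptions.builder()
+   *         .setRegionOfInterest(new RectF(0.0f, 0.0f, 0.833333f, 1.0f))
+   *         .setRotationDegrees(-90)
+   *         .build();
+   * ImageEmbedderResult result = embedder.embed(mpImage, roiOptions);
+   * }</pre>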
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embed(MPImage image, ImageProcessingOptions imageProcessingOptions) { + return (ImageEmbedderResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs embedding extraction on the provided video frame with default image processing + * options, i.e. using the whole image as region-of-interest and without any rotation applied. + * Only use this method when the {@link ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *

+   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+   * must be monotonically increasing.
+   *
+   * <p>{@link ImageEmbedder} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
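+   *
+   * <p>A sketch of a typical decode loop ({@code frames} and the 30 fps step of 33 ms are
+   * assumptions, not part of the API):
+   *
+   * <pre>{@code
+   * for (int i = 0; i < frames.size(); i++) {
+   *   ImageEmbedderResult result = embedder.embedForVideo(frames.get(i), i * 33L);
+   * }
+   * }</pre>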
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo(MPImage image, long timestampMs) { + return embedForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs embedding extraction on the provided video frame. Only use this method when the {@link + * ImageEmbedder} is created with {@link RunningMode.VIDEO}. + * + *

+   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+   * must be monotonically increasing.
+   *
+   * <p>{@link ImageEmbedder} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public ImageEmbedderResult embedForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + return (ImageEmbedderResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform embedding extraction with default image processing options, + * i.e. using the whole image as region-of-interest and without any rotation applied, and the + * results will be available via the {@link ResultListener} provided in the {@link + * ImageEmbedderOptions}. Only use this method when the {@link ImageEmbedder} is created with + * {@link RunningMode.LIVE_STREAM}. + * + *

+   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+   * sent to the image embedder. The input timestamps must be monotonically increasing.
+   *
+   * <p>{@link ImageEmbedder} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
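+   *
+   * <p>A minimal live-stream configuration, mirroring the tests in this change (the listener body
+   * is left empty and {@code modelPath} is an assumed variable):
+   *
+   * <pre>{@code
+   * ImageEmbedderOptions options =
+   *     ImageEmbedderOptions.builder()
+   *         .setBaseOptions(BaseOptions.builder().setModelAssetPath(modelPath).build())
+   *         .setRunningMode(RunningMode.LIVE_STREAM)
+   *         .setResultListener((result, inputImage) -> {})
+   *         .build();
+   * }</pre>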
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void embedAsync(MPImage image, long timestampMs) { + embedAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform embedding extraction, and the results will be available via + * the {@link ResultListener} provided in the {@link ImageEmbedderOptions}. Only use this method + * when the {@link ImageEmbedder} is created with {@link RunningMode.LIVE_STREAM}. + * + *

+   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+   * sent to the image embedder. The input timestamps must be monotonically increasing.
+   *
+   * <p>{@link ImageEmbedder} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
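+   *
+   * <p>Since the embedder holds native resources, the tests in this change close it with
+   * try-with-resources; a sketch:
+   *
+   * <pre>{@code
+   * try (ImageEmbedder embedder = ImageEmbedder.createFromOptions(context, options)) {
+   *   embedder.embedAsync(mpImage, timestampMs);
+   * }
+   * }</pre>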
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public void embedAsync(
+      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+    sendLiveStreamData(image, imageProcessingOptions, timestampMs);
+  }
+
+  /**
+   * Utility function to compute cosine similarity between two {@link Embedding} objects.
+   *
+   * @throws IllegalArgumentException if the embeddings are of different types (float vs.
+   *     quantized), have different sizes, or have an L2-norm of 0.
+   */
+  public static double cosineSimilarity(Embedding u, Embedding v) {
+    return CosineSimilarity.compute(u, v);
+  }
+
+  /** Options for setting up an {@link ImageEmbedder}. */
+  @AutoValue
+  public abstract static class ImageEmbedderOptions extends TaskOptions {
+
+    /** Builder for {@link ImageEmbedderOptions}. */
+    @AutoValue.Builder
+    public abstract static class Builder {
+      /** Sets the {@link BaseOptions} for the image embedder task. */
+      public abstract Builder setBaseOptions(BaseOptions baseOptions);
+
+      /**
+       * Sets the {@link RunningMode} for the image embedder task. Defaults to the image mode. Image
+       * embedder has three modes:
+       *
    + *
+       * <ul>
+       *   <li>IMAGE: The mode for performing embedding extraction on single image inputs.
+       *   <li>VIDEO: The mode for performing embedding extraction on the decoded frames of a video.
+       *   <li>LIVE_STREAM: The mode for performing embedding extraction on a live stream of
+       *       input data, such as from a camera. In this mode, {@code setResultListener} must be
+       *       called to set up a listener to receive the embedding results asynchronously.
+       * </ul>
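+       *
+       * <p>For example, to create video-mode options (a sketch; {@code baseOptions} is assumed to
+       * be built already):
+       *
+       * <pre>{@code
+       * ImageEmbedderOptions options =
+       *     ImageEmbedderOptions.builder()
+       *         .setBaseOptions(baseOptions)
+       *         .setRunningMode(RunningMode.VIDEO)
+       *         .build();
+       * }</pre>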
+       */
+      public abstract Builder setRunningMode(RunningMode runningMode);
+
+      /**
+       * Sets the optional {@link EmbedderOptions} controlling embedder behavior, such as
+       * L2-normalization and scalar quantization.
+       */
+      public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
+
+      /**
+       * Sets the {@link ResultListener} to receive the embedding results asynchronously when the
+       * image embedder is in the live stream mode.
+       */
+      public abstract Builder setResultListener(
+          ResultListener<ImageEmbedderResult, MPImage> resultListener);
+
+      /** Sets an optional {@link ErrorListener}. */
+      public abstract Builder setErrorListener(ErrorListener errorListener);
+
+      abstract ImageEmbedderOptions autoBuild();
+
+      /**
+       * Validates and builds the {@link ImageEmbedderOptions} instance.
+       *
+       * @throws IllegalArgumentException if the result listener and the running mode are not
+       *     properly configured. The result listener should only be set when the image embedder is
+       *     in the live stream mode.
+       */
+      public final ImageEmbedderOptions build() {
+        ImageEmbedderOptions options = autoBuild();
+        if (options.runningMode() == RunningMode.LIVE_STREAM) {
+          if (!options.resultListener().isPresent()) {
+            throw new IllegalArgumentException(
+                "The image embedder is in the live stream mode, a user-defined result listener"
+                    + " must be provided in the ImageEmbedderOptions.");
+          }
+        } else if (options.resultListener().isPresent()) {
+          throw new IllegalArgumentException(
+              "The image embedder is in the image or video mode, a user-defined result listener"
+                  + " shouldn't be provided in ImageEmbedderOptions.");
+        }
+        return options;
+      }
+    }
+
+    abstract BaseOptions baseOptions();
+
+    abstract RunningMode runningMode();
+
+    abstract Optional<EmbedderOptions> embedderOptions();
+
+    abstract Optional<ResultListener<ImageEmbedderResult, MPImage>> resultListener();
+
+    abstract Optional<ErrorListener> errorListener();
+
+    public static Builder builder() {
+      return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder()
+          .setRunningMode(RunningMode.IMAGE);
+    }
+
+    /** Converts an {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
+    @Override
+    public CalculatorOptions convertToCalculatorOptionsProto() {
+      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+          BaseOptionsProto.BaseOptions.newBuilder();
+      baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
+      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+      ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder =
+          ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder()
+              .setBaseOptions(baseOptionsBuilder);
+      if (embedderOptions().isPresent()) {
+        taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
+      }
+      return CalculatorOptions.newBuilder()
+          .setExtension(
+              ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext,
+              taskOptionsBuilder.build())
+          .build();
+    }
+  }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java
new file mode 100644
index 000000000..ee3f4abc9
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderResult.java
@@ -0,0 +1,54 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.containers.EmbeddingResult; +import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; +import com.google.mediapipe.tasks.core.TaskResult; + +/** Represents the embedding results generated by {@link ImageEmbedder}. */ +@AutoValue +public abstract class ImageEmbedderResult implements TaskResult { + + /** + * Creates an {@link ImageEmbedderResult} instance. + * + * @param embeddingResult the {@link EmbeddingResult} object containing one embedding per embedder + * head. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult create(EmbeddingResult embeddingResult, long timestampMs) { + return new AutoValue_ImageEmbedderResult(embeddingResult, timestampMs); + } + + /** + * Creates an {@link ImageEmbedderResult} instance from a {@link EmbeddingsProto.EmbeddingResult} + * protobuf message. + * + * @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert. + * @param timestampMs a timestamp for this result. + */ + static ImageEmbedderResult createFromProto( + EmbeddingsProto.EmbeddingResult proto, long timestampMs) { + return create(EmbeddingResult.createFromProto(proto), timestampMs); + } + + /** Contains one embedding per embedder head. */ + public abstract EmbeddingResult embeddingResult(); + + @Override + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml new file mode 100644 index 000000000..db303a439 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java new file mode 100644 index 000000000..56249ead9 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -0,0 +1,444 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.imageembedder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imageembedder.ImageEmbedder.ImageEmbedderOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link ImageEmbedder}/ */ +@RunWith(Suite.class) +@SuiteClasses({ImageEmbedderTest.General.class, ImageEmbedderTest.RunningModeTest.class}) +public class ImageEmbedderTest { + private static final String MOBILENET_EMBEDDER = "mobilenet_v3_small_100_224_embedder.tflite"; + private static final String BURGER_IMAGE = "burger.jpg"; + private static final String BURGER_CROP_IMAGE = "burger_crop.jpg"; + private static final String BURGER_ROTATED_IMAGE = "burger_rotated.jpg"; + + private static final double DOUBLE_DIFF_TOLERANCE = 1e-4; + + @RunWith(AndroidJUnit4.class) + public static final class General extends ImageEmbedderTest { + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. 
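+      // (MediaPipe requires the model to be in a direct ByteBuffer or a MappedByteBuffer so the
+      // native layer can read it in place; the heap buffer built below is expected to be
+      // rejected.)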
+ ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedder.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void embed_succeedsWithNoOptions() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithL2Normalization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setEmbedderOptions(embedderOptions) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithQuantization() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setEmbedderOptions(embedderOptions) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true); + // Check similarity. 
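+      // (Cosine similarity is dot(u, v) / (||u|| * ||v||), so 1.0 means identical direction; the
+      // expected value below is a regression target for this specific model and these inputs,
+      // not a universal constant.)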
+ double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.926776); + } + + @Test + public void embed_succeedsWithRegionOfInterest() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 0.833333f, 1.0f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).build(); + ImageEmbedderResult resultRoi = + imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + resultRoi.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999931f); + } + + @Test + public void embed_succeedsWithRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + ImageEmbedderResult resultRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultRotated.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.571648426f); + } + + @Test + public void embed_succeedsWithRegionOfInterestAndRotation() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + // RectF around the region in "burger_rotated.jpg" corresponding to "burger_crop.jpg". + RectF roi = new RectF(0.0f, 0.0f, 1.0f, 0.833333f); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build(); + ImageEmbedderResult resultRoiRotated = + imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. 
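+      // (Rotating and re-cropping does not reproduce the original crop pixel-for-pixel, so the
+      // similarity is noticeably below 1; the expected value is a regression target for this
+      // model.)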
+ double similarity = + ImageEmbedder.cosineSimilarity( + resultRoiRotated.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.62780395f); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends ImageEmbedderTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(mode) + .setResultListener((result, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageEmbedderOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void embed_failsWithCallingWrongApiInImageMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInVideoMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void embed_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + 
.setResultListener((imageClassificationResult, inputImage) -> {}) + .build(); + + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void embed_succeedsWithImageMode() throws Exception { + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromFile( + ApplicationProvider.getApplicationContext(), MOBILENET_EMBEDDER); + ImageEmbedderResult result = imageEmbedder.embed(getImageFromAsset(BURGER_IMAGE)); + ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); + + // Check results. + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + // Check similarity. + double similarity = + ImageEmbedder.cosineSimilarity( + result.embeddingResult().embeddings().get(0), + resultCrop.embeddingResult().embeddings().get(0)); + assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.925272); + } + + @Test + public void embed_succeedsWithVideoMode() throws Exception { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.VIDEO) + .build(); + ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + for (int i = 0; i < 3; ++i) { + ImageEmbedderResult result = + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i); + assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + } + } + + @Test + public void embed_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void embed_succeedsWithLiveStreamMode() throws Exception { + MPImage image = getImageFromAsset(BURGER_IMAGE); + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); + ImageEmbedderOptions options = + ImageEmbedderOptions.builder() + .setBaseOptions(baseOptions) + 
.setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (imageEmbedderResult, inputImage) -> { + assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (ImageEmbedder imageEmbedder = + ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; ++i) { + imageEmbedder.embedAsync(image, /*timestampMs=*/ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertHasOneHeadAndCorrectDimension( + ImageEmbedderResult result, boolean quantized) { + assertThat(result.embeddingResult().embeddings()).hasSize(1); + assertThat(result.embeddingResult().embeddings().get(0).headIndex()).isEqualTo(0); + assertThat(result.embeddingResult().embeddings().get(0).headName().get()).isEqualTo("feature"); + if (quantized) { + assertThat(result.embeddingResult().embeddings().get(0).quantizedEmbedding()).hasLength(1024); + } else { + assertThat(result.embeddingResult().embeddings().get(0).floatEmbedding()).hasLength(1024); + } + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(480); + assertThat(inputImage.getHeight()).isEqualTo(325); + } +} From f14645cb06376cd1a6818a6155118ad0667d2d84 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 10:48:41 -0800 Subject: [PATCH 002/137] Model maker gesture recognizer test changes PiperOrigin-RevId: 488702055 --- .../gesture_recognizer_test.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index eb2b1d171..7e7a1ca30 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import tempfile from unittest import mock as unittest_mock import zipfile @@ -40,30 +41,35 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - self._model_options = gesture_recognizer.ModelOptions() - self._hparams = gesture_recognizer.HParams(epochs=2) - self._gesture_recognizer_options = ( - gesture_recognizer.GestureRecognizerOptions( - model_options=self._model_options, hparams=self._hparams)) all_data = self._load_data() # Splits data, 90% data for training, 10% for testing self._train_data, self._test_data = all_data.split(0.9) def test_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) self._test_accuracy(model) def test_export_gesture_recognizer_model(self): + model_options = gesture_recognizer.ModelOptions() + hparams 
= gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) model.export_model() - model_bundle_file = os.path.join(self._hparams.export_dir, + model_bundle_file = os.path.join(hparams.export_dir, 'gesture_recognizer.task') with zipfile.ZipFile(model_bundle_file) as zf: self.assertEqual( @@ -102,7 +108,7 @@ class GestureRecognizerTest(tf.test.TestCase): 'GestureRecognizerModelOptions', autospec=True, return_value=gesture_recognizer.ModelOptions()) - def test_create_hparams_and_model_options_if_none_in_image_classifier_options( + def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( self, mock_hparams, mock_model_options): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( @@ -113,16 +119,21 @@ class GestureRecognizerTest(tf.test.TestCase): mock_model_options.assert_called_once() def test_continual_training_by_loading_checkpoint(self): + model_options = gesture_recognizer.ModelOptions() + hparams = gesture_recognizer.HParams( + export_dir=tempfile.mkdtemp(), epochs=2) + gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions( + model_options=model_options, hparams=hparams) mock_stdout = io.StringIO() with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, validation_data=self._test_data, - options=self._gesture_recognizer_options) + options=gesture_recognizer_options) self._test_accuracy(model) self.assertRegex(mock_stdout.getvalue(), 'Resuming from') From a94564540bc22af9d02c4df3102a1f0d3424929e Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 15 Nov 2022 11:49:21 -0800 Subject: [PATCH 003/137] Bump up the dependency library pybind11's version to 2.10.1. PiperOrigin-RevId: 488718815 --- WORKSPACE | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 702d1899e..fea96d941 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -212,14 +212,14 @@ http_archive( sha256 = "75922da3a1bdb417d820398eb03d4e9bd067c4905a4246d35a44c01d62154d91", ) -# Point to the commit that deprecates the usage of Eigen::MappedSparseMatrix. +# 2022-10-20 http_archive( name = "pybind11", urls = [ - "https://github.com/pybind/pybind11/archive/70a58c577eaf067748c2ec31bfd0b0a614cffba6.zip", + "https://github.com/pybind/pybind11/archive/v2.10.1.zip", ], - sha256 = "b971842fab1b5b8f3815a2302331782b7d137fef0e06502422bc4bc360f4956c", - strip_prefix = "pybind11-70a58c577eaf067748c2ec31bfd0b0a614cffba6", + sha256 = "fcf94065efcfd0a7a828bacf118fa11c43f6390d0c805e3e6342ac119f2e9976", + strip_prefix = "pybind11-2.10.1", build_file = "@pybind11_bazel//:pybind11.BUILD", ) From 1689112b23fc6038114a143baf0253e0b6c043c6 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 14:02:21 -0800 Subject: [PATCH 004/137] Improve model_util_test code. 
PiperOrigin-RevId: 488752497 --- .../model_maker/python/core/utils/model_util_test.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index bef9c8a97..05c6ffe3f 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from typing import Optional from absl.testing import parameterized import tensorflow as tf @@ -76,8 +77,10 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]), expected_steps_per_epoch=2)) - def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data, - expected_steps_per_epoch): + def test_get_steps_per_epoch(self, steps_per_epoch: Optional[int], + batch_size: Optional[int], + train_data: Optional[tf.data.Dataset], + expected_steps_per_epoch: int): estimated_steps_per_epoch = model_util.get_steps_per_epoch( steps_per_epoch=steps_per_epoch, batch_size=batch_size, @@ -130,7 +133,9 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): testcase_name='float16_quantize', config=quantization.QuantizationConfig.for_float16(), model_size=1468)) - def test_convert_to_tflite_quantized(self, config, model_size): + def test_convert_to_tflite_quantized(self, + config: quantization.QuantizationConfig, + model_size: int): input_dim = 16 num_classes = 2 max_input_value = 5 @@ -157,5 +162,6 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): test_util.test_tflite_file( keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) + if __name__ == '__main__': tf.test.main() From 496720308c66d02832038090e1a6562ca5b6342f Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 15 Nov 2022 14:03:17 -0800 Subject: [PATCH 005/137] Migrate remaining MP Tasks Libraries to ts_declarations PiperOrigin-RevId: 488752799 --- .../tasks/web/audio/audio_classifier/BUILD | 23 ++++++++++++++----- .../audio_classifier/audio_classifier.ts | 3 +++ ...tions.ts => audio_classifier_options.d.ts} | 0 ...result.ts => audio_classifier_result.d.ts} | 0 mediapipe/tasks/web/audio/index.ts | 3 --- mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/index.ts | 4 +--- .../tasks/web/text/text_classifier/BUILD | 22 +++++++++++++----- .../text/text_classifier/text_classifier.ts | 3 +++ ...ptions.ts => text_classifier_options.d.ts} | 0 ..._result.ts => text_classifier_result.d.ts} | 0 mediapipe/tasks/web/text/text_embedder/BUILD | 23 ++++++++++++++----- .../web/text/text_embedder/text_embedder.ts | 2 ++ .../tasks/web/vision/image_embedder/BUILD | 23 ++++++++++++++----- .../vision/image_embedder/image_embedder.ts | 2 ++ ...options.ts => image_embedder_options.d.ts} | 0 ...r_result.ts => image_embedder_result.d.ts} | 0 mediapipe/tasks/web/vision/index.ts | 11 --------- 18 files changed, 79 insertions(+), 41 deletions(-) rename mediapipe/tasks/web/audio/audio_classifier/{audio_classifier_options.ts => audio_classifier_options.d.ts} (100%) rename mediapipe/tasks/web/audio/audio_classifier/{audio_classifier_result.ts => audio_classifier_result.d.ts} (100%) rename mediapipe/tasks/web/text/text_classifier/{text_classifier_options.ts => text_classifier_options.d.ts} (100%) rename mediapipe/tasks/web/text/text_classifier/{text_classifier_result.ts => text_classifier_result.d.ts} (100%) rename 
mediapipe/tasks/web/vision/image_embedder/{image_embedder_options.ts => image_embedder_options.d.ts} (100%) rename mediapipe/tasks/web/vision/image_embedder/{image_embedder_result.ts => image_embedder_result.d.ts} (100%) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 1bc4af309..6a78116c3 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -2,7 +2,7 @@ # # This task takes audio data and outputs the classification result. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,12 +10,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "audio_classifier", - srcs = [ - "audio_classifier.ts", - "audio_classifier_options.ts", - "audio_classifier_result.ts", - ], + srcs = ["audio_classifier.ts"], deps = [ + ":audio_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", @@ -31,3 +28,17 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "audio_classifier_types", + srcs = [ + "audio_classifier_options.d.ts", + "audio_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index e3700cd7a..76b926723 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -29,6 +29,9 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm import {AudioClassifierOptions} from './audio_classifier_options'; import {AudioClassifierResult} from './audio_classifier_result'; +export * from './audio_classifier_options'; +export * from './audio_classifier_result'; + const MEDIAPIPE_GRAPH = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts similarity index 100% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.ts rename to mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts rename to mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.d.ts diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 114a8ceca..a5083b326 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -14,7 +14,4 @@ * limitations under the License. 
*/ -// Audio Classifier -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_options'; -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier_result'; export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index a369d0af0..4b465b0f5 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -9,5 +9,6 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", ], ) diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index dc511a426..d50db209c 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -14,7 +14,5 @@ * limitations under the License. */ -// Text Classifier -export * from '../../../tasks/web/text/text_classifier/text_classifier_options'; -export * from '../../../tasks/web/text/text_classifier/text_classifier_result'; export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 4ebdce18a..7dbbb18ca 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -3,7 +3,7 @@ # This task takes text input performs Natural Language classification (including # BERT-based text classification). -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,12 +11,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_classifier", - srcs = [ - "text_classifier.ts", - "text_classifier_options.ts", - "text_classifier_result.ts", - ], + srcs = ["text_classifier.ts"], deps = [ + ":text_classifier_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", @@ -32,3 +29,16 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "text_classifier_types", + srcs = [ + "text_classifier_options.d.ts", + "text_classifier_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:category", + "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core:classifier_options", + ], +) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index e1d0c9601..d4f413efa 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -29,6 +29,9 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm import {TextClassifierOptions} from './text_classifier_options'; import {TextClassifierResult} from './text_classifier_result'; +export * from './text_classifier_options'; +export * from './text_classifier_result'; + const INPUT_STREAM = 'text_in'; const CLASSIFICATIONS_STREAM = 'classifications_out'; const TEXT_CLASSIFIER_GRAPH = diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts 
b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts similarity index 100% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_options.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts similarity index 100% rename from mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts rename to mediapipe/tasks/web/text/text_classifier/text_classifier_result.d.ts diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 8e397ce6f..bebd612dd 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -3,7 +3,7 @@ # This task takes text input and performs embedding # -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,13 +11,11 @@ licenses(["notice"]) mediapipe_ts_library( name = "text_embedder", - srcs = [ - "text_embedder.ts", - "text_embedder_options.d.ts", - "text_embedder_result.d.ts", - ], + srcs = ["text_embedder.ts"], deps = [ + ":text_embedder_types", "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", @@ -30,3 +28,16 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "text_embedder_types", + srcs = [ + "text_embedder_options.d.ts", + "text_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 65df5df6a..7c631683d 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -29,6 +29,8 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm import {TextEmbedderOptions} from './text_embedder_options'; import {TextEmbedderResult} from './text_embedder_result'; +export * from './text_embedder_options'; +export * from './text_embedder_result'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index d12a05ad9..13ff2e4d6 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -2,7 +2,7 @@ # # This task performs embedding extraction on images. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -10,12 +10,9 @@ licenses(["notice"]) mediapipe_ts_library( name = "image_embedder", - srcs = [ - "image_embedder.ts", - "image_embedder_options.ts", - "image_embedder_result.ts", - ], + srcs = ["image_embedder.ts"], deps = [ + ":image_embedder_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", @@ -31,3 +28,17 @@ mediapipe_ts_library( "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) + +mediapipe_ts_declaration( + name = "image_embedder_types", + srcs = [ + "image_embedder_options.d.ts", + "image_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/vision/core:running_mode", + ], +) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 4184e763c..91d9b5119 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -38,6 +38,8 @@ const EMBEDDINGS_STREAM = 'embeddings_out'; const TEXT_EMBEDDER_CALCULATOR = 'mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph'; +export * from './image_embedder_options'; +export * from './image_embedder_result'; export {ImageSource}; // Used in the public API /** Performs embedding extraction on images. */ diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_embedder/image_embedder_options.ts rename to mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts similarity index 100% rename from mediapipe/tasks/web/vision/image_embedder/image_embedder_result.ts rename to mediapipe/tasks/web/vision/image_embedder/image_embedder_result.d.ts diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 0ea844fc9..d68c00cc7 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -14,19 +14,8 @@ * limitations under the License. 
*/ -// Image Classifier export * from '../../../tasks/web/vision/image_classifier/image_classifier'; - -// Image Embedder -export * from '../../../tasks/web/vision/image_embedder/image_embedder_options'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder_result'; export * from '../../../tasks/web/vision/image_embedder/image_embedder'; - -// Gesture Recognizer export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; - -// Hand Landmarker export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; - -// Object Detector export * from '../../../tasks/web/vision/object_detector/object_detector'; From e65f21e2d85f9f08097e953ed9948de481065024 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 14:34:45 -0800 Subject: [PATCH 006/137] Update the docstring to make it consistent with the model option update. PiperOrigin-RevId: 488761331 --- .../python/vision/image_classifier/image_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 1ff6132b4..df71a8fef 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -177,7 +177,7 @@ class ImageClassifier(classifier.Classifier): Args: model_name: File name to save TFLite model with metadata. The full export - path is {self._hparams.model_dir}/{model_name}. + path is {self._hparams.export_dir}/{model_name}. quantization_config: The configuration for model quantization. """ if not tf.io.gfile.exists(self._hparams.export_dir): From 7a87546c30c347f8fc8d046431dbb27208a0f920 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 14:35:54 -0800 Subject: [PATCH 007/137] Internal change PiperOrigin-RevId: 488761646 --- mediapipe/framework/tool/test_util.cc | 22 +++++++++++++--------- mediapipe/framework/tool/test_util.h | 4 ++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index 6433c93d2..c7ed063e0 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -258,11 +258,8 @@ std::string GetTestFilePath(absl::string_view relative_path) { return file::JoinPath(GetTestRootDir(), relative_path); } -absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( - absl::string_view path, ImageFormat::Format format) { - std::string encoded; - MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); - +absl::StatusOr<std::unique_ptr<ImageFrame>> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format) { // stbi_load determines the output pixel format based on the desired channels. // 0 means "use whatever's in the file". int desired_channels = format == ImageFormat::UNKNOWN ? 
0 @@ -274,10 +271,10 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( << "unsupported output format requested: " << format; int width, height, channels_in_file; - auto data = stbi_load_from_memory(reinterpret_cast<const stbi_uc*>(encoded.data()), - encoded.size(), &width, &height, - &channels_in_file, desired_channels); - RET_CHECK(data) << "failed to decode image data from: " << path; + auto data = stbi_load_from_memory( + reinterpret_cast<const stbi_uc*>(encoded.data()), encoded.size(), &width, + &height, &channels_in_file, desired_channels); + RET_CHECK(data) << "failed to decode image data"; // If we didn't specify a desired format, it will be determined by what the // file contains. @@ -295,6 +292,13 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( format, width, height, width * output_channels, data, stbi_image_free); } +absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( + absl::string_view path, ImageFormat::Format format) { + std::string encoded; + MP_RETURN_IF_ERROR(mediapipe::file::GetContents(path, &encoded)); + return DecodeTestImage(encoded, format); +} + std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path, ImageFormat::Format format) { return nullptr; diff --git a/mediapipe/framework/tool/test_util.h b/mediapipe/framework/tool/test_util.h index 71c096db7..80b768e3d 100644 --- a/mediapipe/framework/tool/test_util.h +++ b/mediapipe/framework/tool/test_util.h @@ -81,6 +81,10 @@ std::string GetTestDataDir(absl::string_view package_base_path); // Loads a binary graph from path. Returns true iff successful. bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path); +// Loads an image from memory. +absl::StatusOr<std::unique_ptr<ImageFrame>> DecodeTestImage( + absl::string_view encoded, ImageFormat::Format format = ImageFormat::SRGBA); + // Loads an image from path. absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage( absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA); From 38b636f7ee6c952832bc869475d47a1bf5e1c453 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 15:10:36 -0800 Subject: [PATCH 008/137] Internal change PiperOrigin-RevId: 488770794 --- mediapipe/framework/deps/BUILD | 1 + mediapipe/framework/deps/registration.h | 39 +++++++++++++------------ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index a39d7476e..95ab21707 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -225,6 +225,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index b39a1e293..1a33b2b24 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -26,10 +26,12 @@ #include "absl/base/macros.h" #include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/meta/type_traits.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/deps/registration_token.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -159,7 +161,7 @@ class FunctionRegistry { FunctionRegistry(const FunctionRegistry&) = delete; FunctionRegistry& operator=(const FunctionRegistry&) = delete; - RegistrationToken Register(const std::string& name, Function func) +
RegistrationToken Register(absl::string_view name, Function func) ABSL_LOCKS_EXCLUDED(lock_) { std::string normalized_name = GetNormalizedName(name); absl::WriterMutexLock lock(&lock_); @@ -189,14 +191,15 @@ absl::enable_if_t<std::is_convertible<std::tuple<Args2...>, std::tuple<Args...>>::value, int> = 0> - ReturnType Invoke(const std::string& name, Args2&&... args) + ReturnType Invoke(absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { Function function; { absl::ReaderMutexLock lock(&lock_); auto it = functions_.find(name); if (it == functions_.end()) { - return absl::NotFoundError("No registered object with name: " + name); + return absl::NotFoundError( + absl::StrCat("No registered object with name: ", name)); } function = it->second; } @@ -206,7 +209,7 @@ // Invokes the specified factory function and returns the result. // Namespaces in |name| and |ns| are separated by kNameSep. template <typename... Args2> - ReturnType Invoke(const std::string& ns, const std::string& name, + ReturnType Invoke(absl::string_view ns, absl::string_view name, Args2&&... args) ABSL_LOCKS_EXCLUDED(lock_) { return Invoke(GetQualifiedName(ns, name), args...); } @@ -214,14 +217,14 @@ // Note that it's possible for registered implementations to be subsequently // unregistered, though this will never happen with registrations made via // MEDIAPIPE_REGISTER_FACTORY_FUNCTION. - bool IsRegistered(const std::string& name) const ABSL_LOCKS_EXCLUDED(lock_) { + bool IsRegistered(absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { absl::ReaderMutexLock lock(&lock_); return functions_.count(name) != 0; } // Returns true if the specified factory function is available. // Namespaces in |name| and |ns| are separated by kNameSep. - bool IsRegistered(const std::string& ns, const std::string& name) const + bool IsRegistered(absl::string_view ns, absl::string_view name) const ABSL_LOCKS_EXCLUDED(lock_) { return IsRegistered(GetQualifiedName(ns, name)); } @@ -244,7 +247,7 @@ // Normalizes a C++ qualified name. Validates the name qualification. // The name must be either unqualified or fully qualified with a leading "::". // The leading "::" in a fully qualified name is stripped. - std::string GetNormalizedName(const std::string& name) { + std::string GetNormalizedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector<std::string> names = absl::StrSplit(name, kCxxSep); if (names[0].empty()) { @@ -259,8 +262,8 @@ // Returns the registry key for a name specified within a namespace. // Namespaces are separated by kNameSep. - std::string GetQualifiedName(const std::string& ns, - const std::string& name) const { + std::string GetQualifiedName(absl::string_view ns, + absl::string_view name) const { using ::mediapipe::registration_internal::kCxxSep; using ::mediapipe::registration_internal::kNameSep; std::vector<std::string> names = absl::StrSplit(name, kNameSep); @@ -287,10 +290,10 @@ private: mutable absl::Mutex lock_; - std::unordered_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); + absl::flat_hash_map<std::string, Function> functions_ ABSL_GUARDED_BY(lock_); // For names included in NamespaceAllowlist, strips the namespace. 
- std::string GetAdjustedName(const std::string& name) { + std::string GetAdjustedName(absl::string_view name) { using ::mediapipe::registration_internal::kCxxSep; std::vector<std::string> names = absl::StrSplit(name, kCxxSep); std::string base_name = names.back(); @@ -299,10 +302,10 @@ if (NamespaceAllowlist::TopNamespaces().count(ns)) { return base_name; } - return name; + return std::string(name); } - void Unregister(const std::string& name) { + void Unregister(absl::string_view name) { absl::WriterMutexLock lock(&lock_); std::string adjusted_name = GetAdjustedName(name); if (adjusted_name != name) { @@ -317,7 +320,7 @@ class GlobalFactoryRegistry { using Functions = FunctionRegistry<Args...>; public: - static RegistrationToken Register(const std::string& name, + static RegistrationToken Register(absl::string_view name, typename Functions::Function func) { return functions()->Register(name, std::move(func)); } @@ -326,7 +329,7 @@ // If using namespaces with this registry, the variant with a namespace // argument should be used. template <typename... Args2> - static typename Functions::ReturnType CreateByName(const std::string& name, + static typename Functions::ReturnType CreateByName(absl::string_view name, Args2&&... args) { return functions()->Invoke(name, std::forward<Args2>(args)...); } @@ -334,7 +337,7 @@ // Returns true if the specified factory function is available. // If using namespaces with this registry, the variant with a namespace // argument should be used. - static bool IsRegistered(const std::string& name) { + static bool IsRegistered(absl::string_view name) { return functions()->IsRegistered(name); } @@ -350,13 +353,13 @@ std::tuple<Args...>>::value, int> = 0> static typename Functions::ReturnType CreateByNameInNamespace( - const std::string& ns, const std::string& name, Args2&&... args) { + absl::string_view ns, absl::string_view name, Args2&&... args) { return functions()->Invoke(ns, name, std::forward<Args2>(args)...); } // Returns true if the specified factory function is available. // Namespaces in |name| and |ns| are separated by kNameSep. - static bool IsRegistered(const std::string& ns, const std::string& name) { + static bool IsRegistered(absl::string_view ns, absl::string_view name) { return functions()->IsRegistered(ns, name); } From a67069156e8d42f18403d5c47aa6219f4379b00d Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:16:11 -0800 Subject: [PATCH 009/137] Use flat_hash_map in ResourceCache This is the recommended hashmap in most cases. 
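A minimal sketch of the pattern behind this change (illustrative names, not the actual ResourceCache members): absl::flat_hash_map may relocate its values when it rehashes, so an entry whose address is held elsewhere has to live behind a unique_ptr, which is why the cache map in the diff below stores std::unique_ptr rather than the entry by value.

    #include <memory>
    #include "absl/container/flat_hash_map.h"

    struct Entry {
      int request_count = 0;
    };

    int main() {
      // flat_hash_map gives no pointer stability for its values, so each
      // Entry lives on the heap: the unique_ptr may move on rehash, but
      // the Entry it owns does not.
      absl::flat_hash_map<int, std::unique_ptr<Entry>> cache;
      auto [it, inserted] = cache.try_emplace(42, std::make_unique<Entry>());
      Entry* stable = it->second.get();  // stays valid across rehashes
      stable->request_count++;
      return 0;
    }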
PiperOrigin-RevId: 488772031 --- mediapipe/util/BUILD | 1 + mediapipe/util/resource_cache.h | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index ab3390e0a..15835aea5 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -228,6 +228,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", ], ) diff --git a/mediapipe/util/resource_cache.h b/mediapipe/util/resource_cache.h index 4cd869f6a..2b3ccbc7d 100644 --- a/mediapipe/util/resource_cache.h +++ b/mediapipe/util/resource_cache.h @@ -17,6 +17,7 @@ #include <unordered_map> + #include "absl/container/flat_hash_map.h" #include "absl/functional/function_ref.h" #include "mediapipe/framework/port/logging.h" @@ -26,7 +27,8 @@ namespace mediapipe { // resource (e.g., image dimension for an image pool) is described by the `Key` // type. The `Value` type must include an unset value, with implicit conversion // to bool reflecting set/unset state. -template <typename Key, typename Value> +template <typename Key, typename Value, + typename KeyHash = typename std::unordered_map<Key, Value>::hasher> class ResourceCache { public: Value Lookup( @@ -36,15 +38,14 @@ class ResourceCache { Entry* entry; if (map_it == map_.end()) { std::tie(map_it, std::ignore) = - map_.emplace(std::piecewise_construct, std::forward_as_tuple(key), - std::forward_as_tuple(key)); - entry = &map_it->second; + map_.try_emplace(key, std::make_unique<Entry>(key)); + entry = map_it->second.get(); CHECK_EQ(entry->request_count, 0); entry->request_count = 1; entry_list_.Append(entry); if (entry->prev != nullptr) CHECK_GE(entry->prev->request_count, 1); } else { - entry = &map_it->second; + entry = map_it->second.get(); ++entry->request_count; Entry* larger = entry->prev; while (larger != nullptr && @@ -171,7 +172,7 @@ size_t size_ = 0; }; - std::unordered_map<Key, Entry> map_; + absl::flat_hash_map<Key, std::unique_ptr<Entry>, KeyHash> map_; EntryList entry_list_; int total_request_count_ = 0; }; From 3c71c64be12409ed2019ac16a02263d3ebf96335 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:30:59 -0800 Subject: [PATCH 010/137] Remove shared_ptr from SimplePool definition This makes the types more explicit and will help with factoring out platform-specific code. 
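As a reduced illustration of what this does to the signatures (a sketch with simplified toy types, not the real header):

    #include <memory>

    struct BufferSpec { int width = 0, height = 0; };
    class GlTextureBufferPool { /* elided */ };

    // Before, the alias itself was an owning pointer, so every signature
    // mentioning SimplePool silently implied shared ownership:
    //   using SimplePool = std::shared_ptr<GlTextureBufferPool>;

    // After, the alias names the pool type; ownership is spelled out where
    // it exists, and code that only uses a pool can borrow it by reference.
    using SimplePool = GlTextureBufferPool;
    std::shared_ptr<SimplePool> RequestPool(const BufferSpec& spec);
    void GetBufferFromSimplePool(const BufferSpec& spec, SimplePool& pool);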
PiperOrigin-RevId: 488775470 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 26 +++++++++++++------------- mediapipe/gpu/gpu_buffer_multi_pool.h | 13 +++++++------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 6e4fd38ea..5e8ce06b9 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -90,8 +90,8 @@ std::string CvPixelBufferPoolWrapper::GetDebugString() const { void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec) { +std::shared_ptr<GpuBufferMultiPool::SimplePool> +GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { return std::make_shared<CvPixelBufferPoolWrapper>(spec, kMaxInactiveBufferAge); } @@ -123,7 +123,7 @@ void GpuBufferMultiPool::FlushTextureCaches() { #define FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR 0 GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { + BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { #if TARGET_IPHONE_SIMULATOR && FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR // On the simulator, syncing the texture with the pixelbuffer does not work, // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not @@ -134,14 +134,14 @@ GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( // pool to give us contiguous data. return GetBufferWithoutPool(spec); #else - return pool->GetBuffer([this]() { FlushTextureCaches(); }); + return pool.GetBuffer([this]() { FlushTextureCaches(); }); #endif // TARGET_IPHONE_SIMULATOR } #else -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::MakeSimplePool( - const BufferSpec& spec) { +std::shared_ptr<GpuBufferMultiPool::SimplePool> +GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, kKeepCount); } @@ -152,16 +152,16 @@ GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { } GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, const GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool->GetBuffer()); + BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { + return GpuBuffer(pool.GetBuffer()); } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -GpuBufferMultiPool::SimplePool GpuBufferMultiPool::RequestPool( +std::shared_ptr<GpuBufferMultiPool::SimplePool> GpuBufferMultiPool::RequestPool( const BufferSpec& spec) { - SimplePool pool; - std::vector<SimplePool> evicted; + std::shared_ptr<SimplePool> pool; + std::vector<std::shared_ptr<SimplePool>> evicted; { absl::MutexLock lock(&mutex_); pool = @@ -180,10 +180,10 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, GpuBufferFormat format) { BufferSpec key(width, height, format); - SimplePool pool = RequestPool(key); + std::shared_ptr<SimplePool> pool = RequestPool(key); if (pool) { // Note: we release our multipool lock before accessing the simple pool. 
- return GetBufferFromSimplePool(key, pool); + return GetBufferFromSimplePool(key, *pool); } else { return GetBufferWithoutPool(key); } diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 5ea6e314f..287b3b2a7 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -83,22 +83,23 @@ class GpuBufferMultiPool { private: #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - using SimplePool = std::shared_ptr<CvPixelBufferPoolWrapper>; + using SimplePool = CvPixelBufferPoolWrapper; #else - using SimplePool = std::shared_ptr<GlTextureBufferPool>; + using SimplePool = GlTextureBufferPool; #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - SimplePool MakeSimplePool(const BufferSpec& spec); + std::shared_ptr<SimplePool> MakeSimplePool(const BufferSpec& spec); // Requests a simple buffer pool for the given spec. This may return nullptr // if we have not yet reached a sufficient number of requests to allocate a // pool, in which case the caller should invoke GetBufferWithoutPool instead // of GetBufferFromSimplePool. - SimplePool RequestPool(const BufferSpec& spec); - GpuBuffer GetBufferFromSimplePool(BufferSpec spec, const SimplePool& pool); + std::shared_ptr<SimplePool> RequestPool(const BufferSpec& spec); + GpuBuffer GetBufferFromSimplePool(BufferSpec spec, SimplePool& pool); GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); absl::Mutex mutex_; - mediapipe::ResourceCache<BufferSpec, SimplePool, absl::Hash<BufferSpec>> + mediapipe::ResourceCache<BufferSpec, std::shared_ptr<SimplePool>, + absl::Hash<BufferSpec>> cache_ ABSL_GUARDED_BY(mutex_); #ifdef __APPLE__ From a520d6cc38dd13c68bf7fac24a919ec8b0bfcdfe Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:39:41 -0800 Subject: [PATCH 011/137] Remove FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR This workaround code is no longer necessary, as per the comment. PiperOrigin-RevId: 488777606 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 5e8ce06b9..2bceb1c05 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -117,25 +117,9 @@ void GpuBufferMultiPool::FlushTextureCaches() { } } -// Turning this on disables the pixel buffer pools when using the simulator. -// It is no longer necessary, since the helper code now supports non-contiguous -// buffers. We leave the code in for now for the sake of documentation. -#define FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR 0 GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { -#if TARGET_IPHONE_SIMULATOR && FORCE_CONTIGUOUS_PIXEL_BUFFER_ON_IPHONE_SIMULATOR - // On the simulator, syncing the texture with the pixelbuffer does not work, - // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not - // available in OpenGL ES 2, we should create the buffer so the pixels are - // contiguous. - // - // TODO: verify if we can use kIOSurfaceBytesPerRow to force the - // pool to give us contiguous data. 
- return GetBufferWithoutPool(spec); -#else return pool.GetBuffer([this]() { FlushTextureCaches(); }); -#endif // TARGET_IPHONE_SIMULATOR } #else From fae55910f44370b86bd04f0cea106cec43be5374 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Tue, 15 Nov 2022 15:56:36 -0800 Subject: [PATCH 012/137] Enable absl::string_view kCalculatorName PiperOrigin-RevId: 488781493 --- mediapipe/framework/api2/builder.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 6d3323b97..19273bf44 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -412,11 +412,11 @@ using GenericNode = Node; template class Node : public NodeBase { public: - Node() : NodeBase(Calc::kCalculatorName) {} + Node() : NodeBase(std::string(Calc::kCalculatorName)) {} // Overrides the built-in calculator type string with the provided argument. // Can be used to create nodes from pure interfaces. // TODO: only use this for pure interfaces - Node(const std::string& type_override) : NodeBase(type_override) {} + Node(std::string type_override) : NodeBase(std::move(type_override)) {} // These methods only allow access to ports declared in the contract. // The argument must be a tag object created with the MPP_TAG macro. From ab2dd779e73a6756bce09d107fc9a738d9e09edd Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:57:43 -0800 Subject: [PATCH 013/137] Factor out CvTextureCacheManager This is a platform-specific component that is only used with CVPixelBufferPool. PiperOrigin-RevId: 488781757 --- mediapipe/gpu/BUILD | 16 +++++++ mediapipe/gpu/cv_texture_cache_manager.cc | 55 +++++++++++++++++++++++ mediapipe/gpu/cv_texture_cache_manager.h | 49 ++++++++++++++++++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 40 +---------------- mediapipe/gpu/gpu_buffer_multi_pool.h | 28 +++--------- mediapipe/gpu/gpu_shared_data_internal.cc | 19 +++++--- mediapipe/gpu/gpu_shared_data_internal.h | 3 ++ 7 files changed, 143 insertions(+), 67 deletions(-) create mode 100644 mediapipe/gpu/cv_texture_cache_manager.cc create mode 100644 mediapipe/gpu/cv_texture_cache_manager.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9c2f47469..93527b565 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -344,6 +344,18 @@ cc_library( ], ) +cc_library( + name = "cv_texture_cache_manager", + srcs = ["cv_texture_cache_manager.cc"], + hdrs = ["cv_texture_cache_manager.h"], + deps = [ + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_buffer_storage_image_frame", hdrs = ["gpu_buffer_storage_image_frame.h"], @@ -440,6 +452,7 @@ objc_library( ":gpu_buffer_multi_pool", ":gpu_shared_data_header", ":graph_support", + ":cv_texture_cache_manager", "//mediapipe/gpu:gl_context_options_cc_proto", "//mediapipe/framework:calculator_context", "//mediapipe/framework/port:ret_check", @@ -555,6 +568,7 @@ cc_library( "//conditions:default": [], "//mediapipe:apple": [ ":MPPGraphGPUData", + ":cv_texture_cache_manager", ], }), ) @@ -617,11 +631,13 @@ cc_library( ":gl_texture_buffer_pool", ], "//mediapipe:ios": [ + ":cv_texture_cache_manager", ":pixel_buffer_pool_util", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", ], "//mediapipe:macos": [ + ":cv_texture_cache_manager", ":pixel_buffer_pool_util", ":gl_texture_buffer", ":gl_texture_buffer_pool", diff --git 
a/mediapipe/gpu/cv_texture_cache_manager.cc b/mediapipe/gpu/cv_texture_cache_manager.cc new file mode 100644 index 000000000..b977a8993 --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.cc @@ -0,0 +1,55 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/gpu/cv_texture_cache_manager.h" + +#include "mediapipe/framework/port/logging.h" + +namespace mediapipe { + +void CvTextureCacheManager::FlushTextureCaches() { + absl::MutexLock lock(&mutex_); + for (const auto& cache : texture_caches_) { +#if TARGET_OS_OSX + CVOpenGLTextureCacheFlush(*cache, 0); +#else + CVOpenGLESTextureCacheFlush(*cache, 0); +#endif // TARGET_OS_OSX + } +} + +void CvTextureCacheManager::RegisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == + texture_caches_.end()) + << "Attempting to register a texture cache twice"; + texture_caches_.emplace_back(cache); +} + +void CvTextureCacheManager::UnregisterTextureCache(CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_); + + auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); + CHECK(it != texture_caches_.end()) + << "Attempting to unregister an unknown texture cache"; + texture_caches_.erase(it); +} + +CvTextureCacheManager::~CvTextureCacheManager() { + CHECK_EQ(texture_caches_.size(), 0) + << "Failed to unregister texture caches before deleting manager"; +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_texture_cache_manager.h b/mediapipe/gpu/cv_texture_cache_manager.h new file mode 100644 index 000000000..17e44fc6e --- /dev/null +++ b/mediapipe/gpu/cv_texture_cache_manager.h @@ -0,0 +1,49 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ +#define MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +class CvTextureCacheManager { + public: + ~CvTextureCacheManager(); + + // TODO: add tests for the texture cache registration. + + // Inform the pool of a cache that should be flushed when it is low on + // reusable buffers. + void RegisterTextureCache(CVTextureCacheType cache); + + // Remove a texture cache from the list of caches to be flushed. 
+ void UnregisterTextureCache(CVTextureCacheType cache); + + void FlushTextureCaches(); + + private: + absl::Mutex mutex_; + std::vector<CFHolder<CVTextureCacheType>> texture_caches_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_CV_TEXTURE_CACHE_MANAGER_H_ diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 2bceb1c05..f76833f24 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -106,20 +106,9 @@ GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { return GpuBuffer(MakeCFHolderAdopting(buffer)); } -void GpuBufferMultiPool::FlushTextureCaches() { - absl::MutexLock lock(&mutex_); - for (const auto& cache : texture_caches_) { -#if TARGET_OS_OSX - CVOpenGLTextureCacheFlush(*cache, 0); -#else - CVOpenGLESTextureCacheFlush(*cache, 0); -#endif // TARGET_OS_OSX - } -} - GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return pool.GetBuffer([this]() { FlushTextureCaches(); }); + return pool.GetBuffer(flush_platform_caches_); } #else @@ -173,31 +162,4 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, } } -GpuBufferMultiPool::~GpuBufferMultiPool() { -#ifdef __APPLE__ - CHECK_EQ(texture_caches_.size(), 0) - << "Failed to unregister texture caches before deleting pool"; -#endif // defined(__APPLE__) -} - -#ifdef __APPLE__ -void GpuBufferMultiPool::RegisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == - texture_caches_.end()) - << "Attempting to register a texture cache twice"; - texture_caches_.emplace_back(cache); -} - -void GpuBufferMultiPool::UnregisterTextureCache(CVTextureCacheType cache) { - absl::MutexLock lock(&mutex_); - - auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); - CHECK(it != texture_caches_.end()) - << "Attempting to unregister an unknown texture cache"; - texture_caches_.erase(it); -} -#endif // defined(__APPLE__) - } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 287b3b2a7..7317ac60e 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -43,25 +43,14 @@ class CvPixelBufferPoolWrapper; class GpuBufferMultiPool { public: GpuBufferMultiPool() {} - explicit GpuBufferMultiPool(void* ignored) {} - ~GpuBufferMultiPool(); // Obtains a buffer. May either be reused or created anew. GpuBuffer GetBuffer(int width, int height, GpuBufferFormat format = GpuBufferFormat::kBGRA32); -#ifdef __APPLE__ - // TODO: add tests for the texture cache registration. - - // Inform the pool of a cache that should be flushed when it is low on - // reusable buffers. - void RegisterTextureCache(CVTextureCacheType cache); - - // Remove a texture cache from the list of caches to be flushed. - void UnregisterTextureCache(CVTextureCacheType cache); - - void FlushTextureCaches(); -#endif // defined(__APPLE__) + void SetFlushPlatformCaches(std::function<void(void)> flush_platform_caches) { + flush_platform_caches_ = flush_platform_caches; + } // This class is not intended as part of the public api of this class. 
It is // public only because it is used as a map key type, and the map @@ -98,15 +87,10 @@ class GpuBufferMultiPool { GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); absl::Mutex mutex_; - mediapipe::ResourceCache<BufferSpec, std::shared_ptr<SimplePool>, - absl::Hash<BufferSpec>> - cache_ ABSL_GUARDED_BY(mutex_); - -#ifdef __APPLE__ - // Texture caches used with this pool. - std::vector<CFHolder<CVTextureCacheType>> texture_caches_ + mediapipe::ResourceCache<BufferSpec, std::shared_ptr<SimplePool>> cache_ ABSL_GUARDED_BY(mutex_); -#endif // defined(__APPLE__) + // This is used to hook up the TextureCacheManager on Apple platforms. + std::function<void(void)> flush_platform_caches_; }; #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index a8bf0c3a3..457b04fd3 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -85,7 +85,12 @@ GpuResources::GpuResources(std::shared_ptr<GlContext> gl_context) { named_executors_[kGpuExecutorName] = std::make_shared<GlContextExecutor>(gl_context.get()); #if __APPLE__ - gpu_buffer_pool().RegisterTextureCache(gl_context->cv_texture_cache()); + texture_caches_ = std::make_shared<CvTextureCacheManager>(); + gpu_buffer_pool().SetFlushPlatformCaches( + [tc = texture_caches_] { tc->FlushTextureCaches(); }); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get() multiPool:&gpu_buffer_pool_]; #endif // __APPLE__ @@ -98,10 +103,12 @@ GpuResources::~GpuResources() { #if !__has_feature(objc_arc) #error This file must be built with ARC. #endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER for (auto& kv : gl_key_context_) { - gpu_buffer_pool().UnregisterTextureCache(kv.second->cv_texture_cache()); + texture_caches_->UnregisterTextureCache(kv.second->cv_texture_cache()); } -#endif +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // __APPLE__ } absl::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { @@ -174,9 +181,9 @@ GlContext::StatusOrGlContext GpuResources::GetOrCreateGlContext( GlContext::Create(*gl_key_context_[SharedContextKey()], kGlContextUseDedicatedThread)); it = gl_key_context_.emplace(key, new_context).first; -#if __APPLE__ - gpu_buffer_pool_.RegisterTextureCache(it->second->cv_texture_cache()); -#endif +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + texture_caches_->RegisterTextureCache(it->second->cv_texture_cache()); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } return it->second; } diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 62d6bb27e..12a7a1296 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -30,6 +30,7 @@ #include "mediapipe/gpu/gpu_buffer_multi_pool.h" #ifdef __APPLE__ +#include "mediapipe/gpu/cv_texture_cache_manager.h" #ifdef __OBJC__ @class MPPGraphGPUData; #else @@ -91,6 +92,8 @@ class GpuResources { GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ + std::shared_ptr<CvTextureCacheManager> texture_caches_; + // Note that this is an Objective-C object. MPPGraphGPUData* ios_gpu_data_; #endif // defined(__APPLE__) From 0d273dd11aac9701c241f1097377614b80690fc3 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:58:32 -0800 Subject: [PATCH 014/137] Factor out CvPixelBufferPoolWrapper This is platform-specific and does not need to live in the main multi_pool sources. 
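Taken together with the previous extraction, the net effect is that the multipool picks one concrete pool type per platform at compile time; schematically (a condensed sketch of the selection, not the verbatim header):

    // Each platform-specific pool lives in its own header; the multipool
    // chooses one implementation through a type alias.
    #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
    #include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h"
    using SimplePool = CvPixelBufferPoolWrapper;
    #else
    #include "mediapipe/gpu/gl_texture_buffer_pool.h"
    using SimplePool = GlTextureBufferPool;
    #endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER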
PiperOrigin-RevId: 488781934 --- mediapipe/gpu/BUILD | 22 ++++++ mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 71 +++++++++++++++++++ mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 50 +++++++++++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 49 +------------ mediapipe/gpu/gpu_buffer_multi_pool.h | 24 ++----- 5 files changed, 149 insertions(+), 67 deletions(-) create mode 100644 mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc create mode 100644 mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 93527b565..26df167c4 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -356,6 +356,26 @@ cc_library( ], ) +cc_library( + name = "cv_pixel_buffer_pool_wrapper", + srcs = ["cv_pixel_buffer_pool_wrapper.cc"], + hdrs = ["cv_pixel_buffer_pool_wrapper.h"], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", + ], + }), + deps = [ + ":gpu_buffer", + ":pixel_buffer_pool_util", + "//mediapipe/framework/port:logging", + "//mediapipe/objc:CFHolder", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "gpu_buffer_storage_image_frame", hdrs = ["gpu_buffer_storage_image_frame.h"], @@ -631,12 +651,14 @@ cc_library( ":gl_texture_buffer_pool", ], "//mediapipe:ios": [ + ":cv_pixel_buffer_pool_wrapper", ":cv_texture_cache_manager", ":pixel_buffer_pool_util", "//mediapipe/objc:CFHolder", "//mediapipe/objc:util", ], "//mediapipe:macos": [ + ":cv_pixel_buffer_pool_wrapper", ":cv_texture_cache_manager", ":pixel_buffer_pool_util", ":gl_texture_buffer", diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc new file mode 100644 index 000000000..3293b0238 --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -0,0 +1,71 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" + +#include + +#include "CoreFoundation/CFBase.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper(int width, int height, + GpuBufferFormat format, + CFTimeInterval maxAge) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + pool_ = MakeCFHolderAdopting( + /* keep count is 0 because the age param keeps buffers around anyway */ + CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); +} + +GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { + CVPixelBufferRef buffer; + int threshold = 1; + NSMutableDictionary* auxAttributes = + [NSMutableDictionary dictionaryWithCapacity:1]; + CVReturn err; + bool tried_flushing = false; + while (1) { + auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold); + err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( + kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, + &buffer); + if (err != kCVReturnWouldExceedAllocationThreshold) break; + if (flush && !tried_flushing) { + // Call the flush function to potentially release old holds on buffers + // and try again to create a pixel buffer. + // This is used to flush CV texture caches, which may retain buffers until + // flushed. + flush(); + tried_flushing = true; + } else { + ++threshold; + } + } + CHECK(!err) << "Error creating pixel buffer: " << err; + count_ = threshold; + return GpuBuffer(MakeCFHolderAdopting(buffer)); +} + +std::string CvPixelBufferPoolWrapper::GetDebugString() const { + auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); + return [(__bridge NSString*)*description UTF8String]; +} + +void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } + +} // namespace mediapipe diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h new file mode 100644 index 000000000..081df4676 --- /dev/null +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class lets calculators allocate GpuBuffers of various sizes, caching +// and reusing them as needed. It does so by automatically creating and using +// platform-specific buffer pools for the requested sizes. +// +// This class is not meant to be used directly by calculators, but is instead +// used by GlCalculatorHelper to allocate buffers. 
+ +#ifndef MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ +#define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ + +#include "CoreFoundation/CFBase.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#include "mediapipe/objc/CFHolder.h" + +namespace mediapipe { + +class CvPixelBufferPoolWrapper { + public: + CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, + CFTimeInterval maxAge); + GpuBuffer GetBuffer(std::function flush); + + int GetBufferCount() const { return count_; } + std::string GetDebugString() const; + + void Flush(); + + private: + CFHolder pool_; + int count_ = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index f76833f24..1909d116e 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -45,55 +45,10 @@ static constexpr int kRequestCountScrubInterval = 50; #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( - const GpuBufferMultiPool::BufferSpec& spec, CFTimeInterval maxAge) { - OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; - pool_ = MakeCFHolderAdopting( - /* keep count is 0 because the age param keeps buffers around anyway */ - CreateCVPixelBufferPool(spec.width, spec.height, cv_format, 0, maxAge)); -} - -GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { - CVPixelBufferRef buffer; - int threshold = 1; - NSMutableDictionary* auxAttributes = - [NSMutableDictionary dictionaryWithCapacity:1]; - CVReturn err; - bool tried_flushing = false; - while (1) { - auxAttributes[(id)kCVPixelBufferPoolAllocationThresholdKey] = @(threshold); - err = CVPixelBufferPoolCreatePixelBufferWithAuxAttributes( - kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, - &buffer); - if (err != kCVReturnWouldExceedAllocationThreshold) break; - if (flush && !tried_flushing) { - // Call the flush function to potentially release old holds on buffers - // and try again to create a pixel buffer. - // This is used to flush CV texture caches, which may retain buffers until - // flushed. 
- flush(); - tried_flushing = true; - } else { - ++threshold; - } - } - CHECK(!err) << "Error creating pixel buffer: " << err; - count_ = threshold; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - -std::string CvPixelBufferPoolWrapper::GetDebugString() const { - auto description = MakeCFHolderAdopting(CFCopyDescription(*pool_)); - return [(__bridge NSString*)*description UTF8String]; -} - -void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } - std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { - return std::make_shared(spec, - kMaxInactiveBufferAge); + return std::make_shared( + spec.width, spec.height, spec.format, kMaxInactiveBufferAge); } GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 7317ac60e..f48577854 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -31,9 +31,11 @@ #include "mediapipe/gpu/pixel_buffer_pool_util.h" #endif // __APPLE__ -#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" +#else #include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER namespace mediapipe { @@ -93,24 +95,6 @@ class GpuBufferMultiPool { std::function flush_platform_caches_; }; -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -class CvPixelBufferPoolWrapper { - public: - CvPixelBufferPoolWrapper(const GpuBufferMultiPool::BufferSpec& spec, - CFTimeInterval maxAge); - GpuBuffer GetBuffer(std::function flush); - - int GetBufferCount() const { return count_; } - std::string GetDebugString() const; - - void Flush(); - - private: - CFHolder pool_; - int count_ = 0; -}; -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - // BufferSpec equality operators inline bool operator==(const GpuBufferMultiPool::BufferSpec& lhs, const GpuBufferMultiPool::BufferSpec& rhs) { From a4fe3eb0941e9571bbc4ade95147c4959f8aa67f Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:59:01 -0800 Subject: [PATCH 015/137] Add CreateBufferWithoutPool method to base pools This may not fit exactly in a pool class, but it makes it easy for the multi-pool to find the appropriate method by depending only on the type of the base pool. For the CVPixelBuffer case, the buffer type is CFHolder, and it seems even less appropriate to specialize that template to add such a method there. An alternative would be to allow defining a creation function separately. 
PiperOrigin-RevId: 488782054 --- mediapipe/gpu/BUILD | 3 ++- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 17 ++++++++++++-- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 7 ++++-- mediapipe/gpu/gl_texture_buffer_pool.h | 5 +++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 22 +++++-------------- 5 files changed, 33 insertions(+), 21 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 26df167c4..2f06fe1d5 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -368,10 +368,11 @@ cc_library( ], }), deps = [ - ":gpu_buffer", + ":gpu_buffer_format", ":pixel_buffer_pool_util", "//mediapipe/framework/port:logging", "//mediapipe/objc:CFHolder", + "//mediapipe/objc:util", "@com_google_absl//absl/synchronization", ], ) diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index 3293b0238..c97268307 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -19,6 +19,7 @@ #include "CoreFoundation/CFBase.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/objc/CFHolder.h" +#include "mediapipe/objc/util.h" namespace mediapipe { @@ -32,7 +33,8 @@ CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper(int width, int height, CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); } -GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { +CFHolder CvPixelBufferPoolWrapper::GetBuffer( + std::function flush) { CVPixelBufferRef buffer; int threshold = 1; NSMutableDictionary* auxAttributes = @@ -58,7 +60,7 @@ GpuBuffer CvPixelBufferPoolWrapper::GetBuffer(std::function flush) { } CHECK(!err) << "Error creating pixel buffer: " << err; count_ = threshold; - return GpuBuffer(MakeCFHolderAdopting(buffer)); + return MakeCFHolderAdopting(buffer); } std::string CvPixelBufferPoolWrapper::GetDebugString() const { @@ -68,4 +70,15 @@ std::string CvPixelBufferPoolWrapper::GetDebugString() const { void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } +CFHolder CvPixelBufferPoolWrapper::CreateBufferWithoutPool( + int width, int height, GpuBufferFormat format) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(format); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + CVPixelBufferRef buffer; + CVReturn err = + CreateCVPixelBufferWithoutPool(width, height, cv_format, &buffer); + CHECK(!err) << "Error creating pixel buffer: " << err; + return MakeCFHolderAdopting(buffer); +} + } // namespace mediapipe diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 081df4676..7412b776f 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -23,7 +23,7 @@ #define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ #include "CoreFoundation/CFBase.h" -#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/pixel_buffer_pool_util.h" #include "mediapipe/objc/CFHolder.h" @@ -33,13 +33,16 @@ class CvPixelBufferPoolWrapper { public: CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, CFTimeInterval maxAge); - GpuBuffer GetBuffer(std::function flush); + CFHolder GetBuffer(std::function flush); int GetBufferCount() const { return count_; } std::string GetDebugString() const; void Flush(); + static CFHolder CreateBufferWithoutPool( + int width, int height, GpuBufferFormat format); + private: CFHolder pool_; int count_ = 0; diff --git 
a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index 4dcad305e..cd755b4aa 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -51,6 +51,11 @@ class GlTextureBufferPool // This method is meant for testing. std::pair GetInUseAndAvailableCounts(); + static GlTextureBufferSharedPtr CreateBufferWithoutPool( + int width, int height, GpuBufferFormat format) { + return GlTextureBuffer::Create(width, height, format); + } + private: GlTextureBufferPool(int width, int height, GpuBufferFormat format, int keep_count); diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 1909d116e..fdff3e692 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -51,19 +51,9 @@ GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { spec.width, spec.height, spec.format, kMaxInactiveBufferAge); } -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); - CHECK_NE(cv_format, -1) << "unsupported pixel format"; - CVPixelBufferRef buffer; - CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, - cv_format, &buffer); - CHECK(!err) << "Error creating pixel buffer: " << err; - return GpuBuffer(MakeCFHolderAdopting(buffer)); -} - GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return pool.GetBuffer(flush_platform_caches_); + return GpuBuffer(pool.GetBuffer(flush_platform_caches_)); } #else @@ -74,11 +64,6 @@ GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { kKeepCount); } -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - return GpuBuffer( - GlTextureBuffer::Create(spec.width, spec.height, spec.format)); -} - GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { return GpuBuffer(pool.GetBuffer()); @@ -117,4 +102,9 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, } } +GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { + return GpuBuffer(SimplePool::CreateBufferWithoutPool(spec.width, spec.height, + spec.format)); +} + } // namespace mediapipe From 0c4522cb9fb7ce7fc940581ae2553f7282b419ca Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 15:59:33 -0800 Subject: [PATCH 016/137] Move flush hook to CvPixelBufferPoolWrapper constructor This unifies the implementation of GpuBufferMultiPool::GetBufferFromSimplePool. 
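A sketch of the resulting control flow, with hypothetical names (PoolWithFlushHook and TryAllocate stand in for the CVPixelBufferPool machinery): the callback is injected once at construction, and GetBuffer() retries a failed allocation once after flushing caches that may still pin old buffers.

    #include <functional>
    #include <utility>

    class PoolWithFlushHook {
     public:
      explicit PoolWithFlushHook(std::function<void(void)> flush_caches)
          : flush_caches_(std::move(flush_caches)) {}

      // Returns true on success. Mirrors the retry-after-flush handling of
      // kCVReturnWouldExceedAllocationThreshold in the real code.
      bool GetBuffer() {
        bool tried_flushing = false;
        while (!TryAllocate()) {
          if (flush_caches_ && !tried_flushing) {
            flush_caches_();  // may release old holds on buffers
            tried_flushing = true;
          } else {
            return false;  // the real code raises the allocation threshold
          }
        }
        return true;
      }

     private:
      // Stand-in allocator: succeeds while capacity remains.
      bool TryAllocate() { return capacity_-- > 0; }

      std::function<void(void)> flush_caches_;
      int capacity_ = 1;
    };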
PiperOrigin-RevId: 488782173 --- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 14 +++++++------- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 6 ++++-- mediapipe/gpu/gpu_buffer_multi_pool.cc | 18 +++++++----------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index c97268307..b1c135afa 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -23,18 +23,18 @@ namespace mediapipe { -CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper(int width, int height, - GpuBufferFormat format, - CFTimeInterval maxAge) { +CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( + int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, + std::function flush_texture_caches) { OSType cv_format = CVPixelFormatForGpuBufferFormat(format); CHECK_NE(cv_format, -1) << "unsupported pixel format"; pool_ = MakeCFHolderAdopting( /* keep count is 0 because the age param keeps buffers around anyway */ CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); + flush_texture_caches_ = std::move(flush_texture_caches); } -CFHolder CvPixelBufferPoolWrapper::GetBuffer( - std::function flush) { +CFHolder CvPixelBufferPoolWrapper::GetBuffer() { CVPixelBufferRef buffer; int threshold = 1; NSMutableDictionary* auxAttributes = @@ -47,12 +47,12 @@ CFHolder CvPixelBufferPoolWrapper::GetBuffer( kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, &buffer); if (err != kCVReturnWouldExceedAllocationThreshold) break; - if (flush && !tried_flushing) { + if (flush_texture_caches_ && !tried_flushing) { // Call the flush function to potentially release old holds on buffers // and try again to create a pixel buffer. // This is used to flush CV texture caches, which may retain buffers until // flushed. 
- flush(); + flush_texture_caches_(); tried_flushing = true; } else { ++threshold; diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 7412b776f..9d9328ca1 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -32,8 +32,9 @@ namespace mediapipe { class CvPixelBufferPoolWrapper { public: CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, - CFTimeInterval maxAge); - CFHolder GetBuffer(std::function flush); + CFTimeInterval maxAge, + std::function flush_texture_caches); + CFHolder GetBuffer(); int GetBufferCount() const { return count_; } std::string GetDebugString() const; @@ -46,6 +47,7 @@ class CvPixelBufferPoolWrapper { private: CFHolder pool_; int count_ = 0; + std::function flush_texture_caches_; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index fdff3e692..9c3c9a33e 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -48,12 +48,8 @@ static constexpr int kRequestCountScrubInterval = 50; std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { return std::make_shared( - spec.width, spec.height, spec.format, kMaxInactiveBufferAge); -} - -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool.GetBuffer(flush_platform_caches_)); + spec.width, spec.height, spec.format, kMaxInactiveBufferAge, + flush_platform_caches_); } #else @@ -64,11 +60,6 @@ GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { kKeepCount); } -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool.GetBuffer()); -} - #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER std::shared_ptr GpuBufferMultiPool::RequestPool( @@ -102,6 +93,11 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, } } +GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( + BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { + return GpuBuffer(pool.GetBuffer()); +} + GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { return GpuBuffer(SimplePool::CreateBufferWithoutPool(spec.width, spec.height, spec.format)); From f13903b7c5ba53cf383f8f3c67816274f9307db0 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:01:08 -0800 Subject: [PATCH 017/137] Call SimplePool methods directly This removes redundant helper functions in GpuBufferMultiPool. PiperOrigin-RevId: 488782516 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 21 +++------------------ mediapipe/gpu/gpu_buffer_multi_pool.h | 5 +---- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 9c3c9a33e..d03ae06aa 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -21,12 +21,6 @@ #include "mediapipe/framework/port/logging.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -#include "CoreFoundation/CFBase.h" -#include "mediapipe/objc/CFHolder.h" -#include "mediapipe/objc/util.h" -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - namespace mediapipe { // Keep this many buffers allocated for a given frame size. 
@@ -87,20 +81,11 @@ GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, std::shared_ptr pool = RequestPool(key); if (pool) { // Note: we release our multipool lock before accessing the simple pool. - return GetBufferFromSimplePool(key, *pool); + return GpuBuffer(pool->GetBuffer()); } else { - return GetBufferWithoutPool(key); + return GpuBuffer( + SimplePool::CreateBufferWithoutPool(width, height, format)); } } -GpuBuffer GpuBufferMultiPool::GetBufferFromSimplePool( - BufferSpec spec, GpuBufferMultiPool::SimplePool& pool) { - return GpuBuffer(pool.GetBuffer()); -} - -GpuBuffer GpuBufferMultiPool::GetBufferWithoutPool(const BufferSpec& spec) { - return GpuBuffer(SimplePool::CreateBufferWithoutPool(spec.width, spec.height, - spec.format)); -} - } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index f48577854..7feb39ad4 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -82,11 +82,8 @@ class GpuBufferMultiPool { std::shared_ptr MakeSimplePool(const BufferSpec& spec); // Requests a simple buffer pool for the given spec. This may return nullptr // if we have not yet reached a sufficient number of requests to allocate a - // pool, in which case the caller should invoke GetBufferWithoutPool instead - // of GetBufferFromSimplePool. + // pool, in which case the caller should invoke CreateBufferWithoutPool. std::shared_ptr RequestPool(const BufferSpec& spec); - GpuBuffer GetBufferFromSimplePool(BufferSpec spec, SimplePool& pool); - GpuBuffer GetBufferWithoutPool(const BufferSpec& spec); absl::Mutex mutex_; mediapipe::ResourceCache> cache_ From 7ef3185ecbb84567c8759350f1baa30907756c02 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:01:56 -0800 Subject: [PATCH 018/137] Allow customizing MultiPool options These don't need to be constants. PiperOrigin-RevId: 488782713 --- mediapipe/gpu/gpu_buffer_multi_pool.cc | 23 +++++------------------ mediapipe/gpu/gpu_buffer_multi_pool.h | 22 +++++++++++++++++++++- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index d03ae06aa..44f1d40df 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -23,26 +23,12 @@ namespace mediapipe { -// Keep this many buffers allocated for a given frame size. -static constexpr int kKeepCount = 2; -// The maximum size of the GpuBufferMultiPool. When the limit is reached, the -// oldest BufferSpec will be dropped. -static constexpr int kMaxPoolCount = 10; -// Time in seconds after which an inactive buffer can be dropped from the pool. -// Currently only used with CVPixelBufferPool. -static constexpr float kMaxInactiveBufferAge = 0.25; -// Skip allocating a buffer pool until at least this many requests have been -// made for a given BufferSpec. -static constexpr int kMinRequestsBeforePool = 2; -// Do a deeper flush every this many requests. 
-static constexpr int kRequestCountScrubInterval = 50; - #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { return std::make_shared( - spec.width, spec.height, spec.format, kMaxInactiveBufferAge, + spec.width, spec.height, spec.format, options_.max_inactive_buffer_age, flush_platform_caches_); } @@ -51,7 +37,7 @@ GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { std::shared_ptr GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - kKeepCount); + options_.keep_count); } #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER @@ -64,11 +50,12 @@ std::shared_ptr GpuBufferMultiPool::RequestPool( absl::MutexLock lock(&mutex_); pool = cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { - return (request_count >= kMinRequestsBeforePool) + return (request_count >= options_.min_requests_before_pool) ? MakeSimplePool(spec) : nullptr; }); - evicted = cache_.Evict(kMaxPoolCount, kRequestCountScrubInterval); + evicted = cache_.Evict(options_.max_pool_count, + options_.request_count_scrub_interval); } // Evicted pools, and their buffers, will be released without holding the // lock. diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 7feb39ad4..1396bcdb3 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -42,9 +42,28 @@ namespace mediapipe { struct GpuSharedData; class CvPixelBufferPoolWrapper; +struct MultiPoolOptions { + // Keep this many buffers allocated for a given frame size. + int keep_count = 2; + // The maximum size of the GpuBufferMultiPool. When the limit is reached, the + // oldest BufferSpec will be dropped. + int max_pool_count = 10; + // Time in seconds after which an inactive buffer can be dropped from the + // pool. Currently only used with CVPixelBufferPool. + float max_inactive_buffer_age = 0.25; + // Skip allocating a buffer pool until at least this many requests have been + // made for a given BufferSpec. + int min_requests_before_pool = 2; + // Do a deeper flush every this many requests. + int request_count_scrub_interval = 50; +}; + +static constexpr MultiPoolOptions kDefaultMultiPoolOptions; + class GpuBufferMultiPool { public: - GpuBufferMultiPool() {} + GpuBufferMultiPool(MultiPoolOptions options = kDefaultMultiPoolOptions) + : options_(options) {} // Obtains a buffer. May either be reused or created anew. GpuBuffer GetBuffer(int width, int height, @@ -85,6 +104,7 @@ class GpuBufferMultiPool { // pool, in which case the caller should invoke CreateBufferWithoutPool. std::shared_ptr RequestPool(const BufferSpec& spec); + MultiPoolOptions options_; absl::Mutex mutex_; mediapipe::ResourceCache> cache_ ABSL_GUARDED_BY(mutex_); From 267476657d18598dc993dc6bb7f5f084a951d8ff Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:02:32 -0800 Subject: [PATCH 019/137] MultiPool options header refactoring Passing MultiPool options to the base pool factories means that we don't have to specialize which options we pass to them. 
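A usage sketch based on the constructor and option fields added above (the values are illustrative, not recommendations):

    MultiPoolOptions options;
    options.keep_count = 4;                // retain more buffers per frame size
    options.max_pool_count = 20;           // track more distinct BufferSpecs
    options.min_requests_before_pool = 1;  // allocate a pool on first request

    GpuBufferMultiPool pool(options);
    GpuBuffer buffer = pool.GetBuffer(/*width=*/1920, /*height=*/1080);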
PiperOrigin-RevId: 488782861 --- mediapipe/gpu/BUILD | 8 ++++ mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 11 +++++ mediapipe/gpu/gl_texture_buffer_pool.h | 7 +++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 23 ++++------ mediapipe/gpu/gpu_buffer_multi_pool.h | 24 +++------- mediapipe/gpu/multi_pool.h | 47 ++++++++++++++++++++ 6 files changed, 86 insertions(+), 34 deletions(-) create mode 100644 mediapipe/gpu/multi_pool.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 2f06fe1d5..b94623ca5 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -369,6 +369,7 @@ cc_library( }), deps = [ ":gpu_buffer_format", + ":multi_pool", ":pixel_buffer_pool_util", "//mediapipe/framework/port:logging", "//mediapipe/objc:CFHolder", @@ -604,6 +605,7 @@ cc_library( ":gl_texture_buffer", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -612,6 +614,11 @@ cc_library( ], ) +cc_library( + name = "multi_pool", + hdrs = ["multi_pool.h"], +) + cc_library( name = "gpu_buffer_multi_pool", srcs = ["gpu_buffer_multi_pool.cc"], @@ -639,6 +646,7 @@ cc_library( ":gl_base", ":gpu_buffer", ":gpu_shared_data_header", + ":multi_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 9d9328ca1..185ba37c6 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -24,6 +24,7 @@ #include "CoreFoundation/CFBase.h" #include "mediapipe/gpu/gpu_buffer_format.h" +#include "mediapipe/gpu/multi_pool.h" #include "mediapipe/gpu/pixel_buffer_pool_util.h" #include "mediapipe/objc/CFHolder.h" @@ -34,6 +35,16 @@ class CvPixelBufferPoolWrapper { CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, std::function flush_texture_caches); + + static std::shared_ptr Create( + int width, int height, GpuBufferFormat format, + const MultiPoolOptions& options, + std::function flush_texture_caches = nullptr) { + return std::make_shared( + width, height, format, options.max_inactive_buffer_age, + flush_texture_caches); + } + CFHolder GetBuffer(); int GetBufferCount() const { return count_; } diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index cd755b4aa..fee46915e 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -23,6 +23,7 @@ #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gl_texture_buffer.h" +#include "mediapipe/gpu/multi_pool.h" namespace mediapipe { @@ -40,6 +41,12 @@ class GlTextureBufferPool new GlTextureBufferPool(width, height, format, keep_count)); } + static std::shared_ptr Create( + int width, int height, GpuBufferFormat format, + const MultiPoolOptions& options) { + return Create(width, height, format, options.keep_count); + } + // Obtains a buffers. May either be reused or created anew. // A GlContext must be current when this is called. 
GlTextureBufferSharedPtr GetBuffer(); diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 44f1d40df..df228b7dd 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -23,24 +23,17 @@ namespace mediapipe { +std::shared_ptr +GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options) { #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - -std::shared_ptr -GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec) { - return std::make_shared( - spec.width, spec.height, spec.format, options_.max_inactive_buffer_age, - flush_platform_caches_); -} - + return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, spec.format, + options, flush_platform_caches_); #else - -std::shared_ptr -GpuBufferMultiPool::MakeSimplePool(const BufferSpec& spec) { return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - options_.keep_count); -} - + options); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +} std::shared_ptr GpuBufferMultiPool::RequestPool( const BufferSpec& spec) { @@ -51,7 +44,7 @@ std::shared_ptr GpuBufferMultiPool::RequestPool( pool = cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { return (request_count >= options_.min_requests_before_pool) - ? MakeSimplePool(spec) + ? MakeSimplePool(spec, options_) : nullptr; }); evicted = cache_.Evict(options_.max_pool_count, diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 1396bcdb3..3ea299f78 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -25,6 +25,7 @@ #include "absl/hash/hash.h" #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/multi_pool.h" #include "mediapipe/util/resource_cache.h" #ifdef __APPLE__ @@ -42,24 +43,6 @@ namespace mediapipe { struct GpuSharedData; class CvPixelBufferPoolWrapper; -struct MultiPoolOptions { - // Keep this many buffers allocated for a given frame size. - int keep_count = 2; - // The maximum size of the GpuBufferMultiPool. When the limit is reached, the - // oldest BufferSpec will be dropped. - int max_pool_count = 10; - // Time in seconds after which an inactive buffer can be dropped from the - // pool. Currently only used with CVPixelBufferPool. - float max_inactive_buffer_age = 0.25; - // Skip allocating a buffer pool until at least this many requests have been - // made for a given BufferSpec. - int min_requests_before_pool = 2; - // Do a deeper flush every this many requests. - int request_count_scrub_interval = 50; -}; - -static constexpr MultiPoolOptions kDefaultMultiPoolOptions; - class GpuBufferMultiPool { public: GpuBufferMultiPool(MultiPoolOptions options = kDefaultMultiPoolOptions) @@ -98,7 +81,10 @@ class GpuBufferMultiPool { using SimplePool = GlTextureBufferPool; #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - std::shared_ptr MakeSimplePool(const BufferSpec& spec); + std::shared_ptr MakeSimplePool( + const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options); + // Requests a simple buffer pool for the given spec. This may return nullptr // if we have not yet reached a sufficient number of requests to allocate a // pool, in which case the caller should invoke CreateBufferWithoutPool. 
diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h new file mode 100644 index 000000000..e504fc820 --- /dev/null +++ b/mediapipe/gpu/multi_pool.h @@ -0,0 +1,47 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class lets calculators allocate GpuBuffers of various sizes, caching +// and reusing them as needed. It does so by automatically creating and using +// platform-specific buffer pools for the requested sizes. +// +// This class is not meant to be used directly by calculators, but is instead +// used by GlCalculatorHelper to allocate buffers. + +#ifndef MEDIAPIPE_GPU_MULTI_POOL_H_ +#define MEDIAPIPE_GPU_MULTI_POOL_H_ + +namespace mediapipe { + +struct MultiPoolOptions { + // Keep this many buffers allocated for a given frame size. + int keep_count = 2; + // The maximum size of the GpuBufferMultiPool. When the limit is reached, the + // oldest BufferSpec will be dropped. + int max_pool_count = 10; + // Time in seconds after which an inactive buffer can be dropped from the + // pool. Currently only used with CVPixelBufferPool. + float max_inactive_buffer_age = 0.25; + // Skip allocating a buffer pool until at least this many requests have been + // made for a given BufferSpec. + int min_requests_before_pool = 2; + // Do a deeper flush every this many requests. + int request_count_scrub_interval = 50; +}; + +static constexpr MultiPoolOptions kDefaultMultiPoolOptions; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_MULTI_POOL_H_ From b9fa2e3496ad0879556162a738f4f608ebe1bb5b Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:03:03 -0800 Subject: [PATCH 020/137] Make it possible to override the SimplePool factory used by MultiPool This means MultiPool no longer needs a SetFlushPlatformCaches method, which was too specific to the CVPixelBufferPool. 
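With this hook, any callable of the right shape can replace the default factory, for example to inject platform dependencies or a test double. A sketch for the non-Apple build, where SimplePool is GlTextureBufferPool, assuming a GpuBufferMultiPool named pool:

    pool.SetSimplePoolFactory(
        [](const GpuBufferMultiPool::BufferSpec& spec,
           const MultiPoolOptions& options) {
          return GlTextureBufferPool::Create(spec.width, spec.height,
                                             spec.format, options.keep_count);
        });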
PiperOrigin-RevId: 488783003 --- mediapipe/gpu/BUILD | 1 + mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 8 ++++---- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 10 +++++----- mediapipe/gpu/gpu_buffer_multi_pool.cc | 15 +++++---------- mediapipe/gpu/gpu_buffer_multi_pool.h | 18 ++++++++++-------- mediapipe/gpu/gpu_shared_data_internal.cc | 8 ++++++-- 6 files changed, 31 insertions(+), 29 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index b94623ca5..36527736b 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -368,6 +368,7 @@ cc_library( ], }), deps = [ + ":cv_texture_cache_manager", ":gpu_buffer_format", ":multi_pool", ":pixel_buffer_pool_util", diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index b1c135afa..d8155f5cf 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -25,13 +25,13 @@ namespace mediapipe { CvPixelBufferPoolWrapper::CvPixelBufferPoolWrapper( int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, - std::function flush_texture_caches) { + CvTextureCacheManager* texture_caches) { OSType cv_format = CVPixelFormatForGpuBufferFormat(format); CHECK_NE(cv_format, -1) << "unsupported pixel format"; pool_ = MakeCFHolderAdopting( /* keep count is 0 because the age param keeps buffers around anyway */ CreateCVPixelBufferPool(width, height, cv_format, 0, maxAge)); - flush_texture_caches_ = std::move(flush_texture_caches); + texture_caches_ = texture_caches; } CFHolder CvPixelBufferPoolWrapper::GetBuffer() { @@ -47,12 +47,12 @@ CFHolder CvPixelBufferPoolWrapper::GetBuffer() { kCFAllocatorDefault, *pool_, (__bridge CFDictionaryRef)auxAttributes, &buffer); if (err != kCVReturnWouldExceedAllocationThreshold) break; - if (flush_texture_caches_ && !tried_flushing) { + if (texture_caches_ && !tried_flushing) { // Call the flush function to potentially release old holds on buffers // and try again to create a pixel buffer. // This is used to flush CV texture caches, which may retain buffers until // flushed. 
- flush_texture_caches_(); + texture_caches_->FlushTextureCaches(); tried_flushing = true; } else { ++threshold; diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 185ba37c6..7d0aec4eb 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -23,6 +23,7 @@ #define MEDIAPIPE_GPU_CV_PIXEL_BUFFER_POOL_WRAPPER_H_ #include "CoreFoundation/CFBase.h" +#include "mediapipe/gpu/cv_texture_cache_manager.h" #include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/multi_pool.h" #include "mediapipe/gpu/pixel_buffer_pool_util.h" @@ -34,15 +35,14 @@ class CvPixelBufferPoolWrapper { public: CvPixelBufferPoolWrapper(int width, int height, GpuBufferFormat format, CFTimeInterval maxAge, - std::function flush_texture_caches); + CvTextureCacheManager* texture_caches); static std::shared_ptr Create( int width, int height, GpuBufferFormat format, const MultiPoolOptions& options, - std::function flush_texture_caches = nullptr) { + CvTextureCacheManager* texture_caches = nullptr) { return std::make_shared( - width, height, format, options.max_inactive_buffer_age, - flush_texture_caches); + width, height, format, options.max_inactive_buffer_age, texture_caches); } CFHolder GetBuffer(); @@ -58,7 +58,7 @@ class CvPixelBufferPoolWrapper { private: CFHolder pool_; int count_ = 0; - std::function flush_texture_caches_; + CvTextureCacheManager* texture_caches_; }; } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index df228b7dd..744ccea2d 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -24,15 +24,10 @@ namespace mediapipe { std::shared_ptr -GpuBufferMultiPool::MakeSimplePool(const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, spec.format, - options, flush_platform_caches_); -#else - return GlTextureBufferPool::Create(spec.width, spec.height, spec.format, - options); -#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +GpuBufferMultiPool::DefaultMakeSimplePool( + const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options) { + return SimplePool::Create(spec.width, spec.height, spec.format, options); } std::shared_ptr GpuBufferMultiPool::RequestPool( @@ -44,7 +39,7 @@ std::shared_ptr GpuBufferMultiPool::RequestPool( pool = cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { return (request_count >= options_.min_requests_before_pool) - ? MakeSimplePool(spec, options_) + ? create_simple_pool_(spec, options_) : nullptr; }); evicted = cache_.Evict(options_.max_pool_count, diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 3ea299f78..88428d053 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -52,10 +52,6 @@ class GpuBufferMultiPool { GpuBuffer GetBuffer(int width, int height, GpuBufferFormat format = GpuBufferFormat::kBGRA32); - void SetFlushPlatformCaches(std::function flush_platform_caches) { - flush_platform_caches_ = flush_platform_caches; - } - // This class is not intended as part of the public api of this class. It is // public only because it is used as a map key type, and the map // implementation needs access to, e.g., the equality operator. 
@@ -74,14 +70,21 @@ class GpuBufferMultiPool { mediapipe::GpuBufferFormat format; }; - private: #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER using SimplePool = CvPixelBufferPoolWrapper; #else using SimplePool = GlTextureBufferPool; #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - std::shared_ptr MakeSimplePool( + using SimplePoolFactory = std::function( + const BufferSpec& spec, const MultiPoolOptions& options)>; + + void SetSimplePoolFactory(SimplePoolFactory create_simple_pool) { + create_simple_pool_ = create_simple_pool; + } + + private: + static std::shared_ptr DefaultMakeSimplePool( const GpuBufferMultiPool::BufferSpec& spec, const MultiPoolOptions& options); @@ -94,8 +97,7 @@ class GpuBufferMultiPool { absl::Mutex mutex_; mediapipe::ResourceCache> cache_ ABSL_GUARDED_BY(mutex_); - // This is used to hook up the TextureCacheManager on Apple platforms. - std::function flush_platform_caches_; + SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool; }; // BufferSpec equality operators diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 457b04fd3..6633c2f00 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -86,8 +86,12 @@ GpuResources::GpuResources(std::shared_ptr gl_context) { std::make_shared(gl_context.get()); #if __APPLE__ texture_caches_ = std::make_shared(); - gpu_buffer_pool().SetFlushPlatformCaches( - [tc = texture_caches_] { tc->FlushTextureCaches(); }); + gpu_buffer_pool().SetSimplePoolFactory( + [tc = texture_caches_](const GpuBufferMultiPool::BufferSpec& spec, + const MultiPoolOptions& options) { + return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, + spec.format, options, tc.get()); + }); #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER From 53d015af08c96d39ecee97bdfa11cc5b5a882cec Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:03:41 -0800 Subject: [PATCH 021/137] Generic MultiPool template PiperOrigin-RevId: 488783176 --- mediapipe/gpu/BUILD | 1 + mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc | 8 +- mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h | 8 +- mediapipe/gpu/gl_texture_buffer_pool.h | 9 +- mediapipe/gpu/gpu_buffer_format.h | 28 +++++++ mediapipe/gpu/gpu_buffer_multi_pool.cc | 46 +--------- mediapipe/gpu/gpu_buffer_multi_pool.h | 77 ++--------------- mediapipe/gpu/gpu_shared_data_internal.cc | 18 ++-- mediapipe/gpu/gpu_shared_data_internal.h | 6 +- mediapipe/gpu/multi_pool.h | 84 +++++++++++++++++-- 10 files changed, 142 insertions(+), 143 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 36527736b..1efe75b52 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -618,6 +618,7 @@ cc_library( cc_library( name = "multi_pool", hdrs = ["multi_pool.h"], + deps = ["//mediapipe/util:resource_cache"], ) cc_library( diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc index d8155f5cf..6e077ae6e 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.cc @@ -71,12 +71,12 @@ std::string CvPixelBufferPoolWrapper::GetDebugString() const { void CvPixelBufferPoolWrapper::Flush() { CVPixelBufferPoolFlush(*pool_, 0); } CFHolder CvPixelBufferPoolWrapper::CreateBufferWithoutPool( - int width, int height, GpuBufferFormat format) { - OSType cv_format = 
CVPixelFormatForGpuBufferFormat(format); + const internal::GpuBufferSpec& spec) { + OSType cv_format = CVPixelFormatForGpuBufferFormat(spec.format); CHECK_NE(cv_format, -1) << "unsupported pixel format"; CVPixelBufferRef buffer; - CVReturn err = - CreateCVPixelBufferWithoutPool(width, height, cv_format, &buffer); + CVReturn err = CreateCVPixelBufferWithoutPool(spec.width, spec.height, + cv_format, &buffer); CHECK(!err) << "Error creating pixel buffer: " << err; return MakeCFHolderAdopting(buffer); } diff --git a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h index 7d0aec4eb..4d71adbf2 100644 --- a/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h +++ b/mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h @@ -38,11 +38,11 @@ class CvPixelBufferPoolWrapper { CvTextureCacheManager* texture_caches); static std::shared_ptr Create( - int width, int height, GpuBufferFormat format, - const MultiPoolOptions& options, + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options, CvTextureCacheManager* texture_caches = nullptr) { return std::make_shared( - width, height, format, options.max_inactive_buffer_age, texture_caches); + spec.width, spec.height, spec.format, options.max_inactive_buffer_age, + texture_caches); } CFHolder GetBuffer(); @@ -53,7 +53,7 @@ class CvPixelBufferPoolWrapper { void Flush(); static CFHolder CreateBufferWithoutPool( - int width, int height, GpuBufferFormat format); + const internal::GpuBufferSpec& spec); private: CFHolder pool_; diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h index fee46915e..29fc3c01c 100644 --- a/mediapipe/gpu/gl_texture_buffer_pool.h +++ b/mediapipe/gpu/gl_texture_buffer_pool.h @@ -42,9 +42,8 @@ class GlTextureBufferPool } static std::shared_ptr Create( - int width, int height, GpuBufferFormat format, - const MultiPoolOptions& options) { - return Create(width, height, format, options.keep_count); + const internal::GpuBufferSpec& spec, const MultiPoolOptions& options) { + return Create(spec.width, spec.height, spec.format, options.keep_count); } // Obtains a buffers. May either be reused or created anew. 
@@ -59,8 +58,8 @@ class GlTextureBufferPool std::pair GetInUseAndAvailableCounts(); static GlTextureBufferSharedPtr CreateBufferWithoutPool( - int width, int height, GpuBufferFormat format) { - return GlTextureBuffer::Create(width, height, format); + const internal::GpuBufferSpec& spec) { + return GlTextureBuffer::Create(spec.width, spec.height, spec.format); } private: diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index 45f054d31..06c5a0439 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -153,6 +153,34 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) { #endif // __APPLE__ +namespace internal { + +struct GpuBufferSpec { + GpuBufferSpec(int w, int h, GpuBufferFormat f) + : width(w), height(h), format(f) {} + + template + friend H AbslHashValue(H h, const GpuBufferSpec& spec) { + return H::combine(std::move(h), spec.width, spec.height, + static_cast(spec.format)); + } + + int width; + int height; + GpuBufferFormat format; +}; + +// BufferSpec equality operators +inline bool operator==(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return lhs.width == rhs.width && lhs.height == rhs.height && + lhs.format == rhs.format; +} +inline bool operator!=(const GpuBufferSpec& lhs, const GpuBufferSpec& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace internal + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_FORMAT_H_ diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.cc b/mediapipe/gpu/gpu_buffer_multi_pool.cc index 744ccea2d..e2ed523e4 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.cc +++ b/mediapipe/gpu/gpu_buffer_multi_pool.cc @@ -16,51 +16,7 @@ #include -#include "absl/memory/memory.h" #include "absl/synchronization/mutex.h" #include "mediapipe/framework/port/logging.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -namespace mediapipe { - -std::shared_ptr -GpuBufferMultiPool::DefaultMakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options) { - return SimplePool::Create(spec.width, spec.height, spec.format, options); -} - -std::shared_ptr GpuBufferMultiPool::RequestPool( - const BufferSpec& spec) { - std::shared_ptr pool; - std::vector> evicted; - { - absl::MutexLock lock(&mutex_); - pool = - cache_.Lookup(spec, [this](const BufferSpec& spec, int request_count) { - return (request_count >= options_.min_requests_before_pool) - ? create_simple_pool_(spec, options_) - : nullptr; - }); - evicted = cache_.Evict(options_.max_pool_count, - options_.request_count_scrub_interval); - } - // Evicted pools, and their buffers, will be released without holding the - // lock. - return pool; -} - -GpuBuffer GpuBufferMultiPool::GetBuffer(int width, int height, - GpuBufferFormat format) { - BufferSpec key(width, height, format); - std::shared_ptr pool = RequestPool(key); - if (pool) { - // Note: we release our multipool lock before accessing the simple pool. 
- return GpuBuffer(pool->GetBuffer()); - } else { - return GpuBuffer( - SimplePool::CreateBufferWithoutPool(width, height, format)); - } -} - -} // namespace mediapipe +namespace mediapipe {} // namespace mediapipe diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 88428d053..827cf514a 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -22,15 +22,9 @@ #ifndef MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ #define MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ -#include "absl/hash/hash.h" #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/multi_pool.h" -#include "mediapipe/util/resource_cache.h" - -#ifdef __APPLE__ -#include "mediapipe/gpu/pixel_buffer_pool_util.h" -#endif // __APPLE__ #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #include "mediapipe/gpu/cv_pixel_buffer_pool_wrapper.h" @@ -40,77 +34,24 @@ namespace mediapipe { -struct GpuSharedData; class CvPixelBufferPoolWrapper; -class GpuBufferMultiPool { - public: - GpuBufferMultiPool(MultiPoolOptions options = kDefaultMultiPoolOptions) - : options_(options) {} - - // Obtains a buffer. May either be reused or created anew. - GpuBuffer GetBuffer(int width, int height, - GpuBufferFormat format = GpuBufferFormat::kBGRA32); - - // This class is not intended as part of the public api of this class. It is - // public only because it is used as a map key type, and the map - // implementation needs access to, e.g., the equality operator. - struct BufferSpec { - BufferSpec(int w, int h, mediapipe::GpuBufferFormat f) - : width(w), height(h), format(f) {} - - template - friend H AbslHashValue(H h, const BufferSpec& spec) { - return H::combine(std::move(h), spec.width, spec.height, - static_cast(spec.format)); - } - - int width; - int height; - mediapipe::GpuBufferFormat format; - }; - +class GpuBufferMultiPool : public MultiPool< #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - using SimplePool = CvPixelBufferPoolWrapper; + CvPixelBufferPoolWrapper, #else - using SimplePool = GlTextureBufferPool; + GlTextureBufferPool, #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + internal::GpuBufferSpec, GpuBuffer> { + public: + using MultiPool::MultiPool; - using SimplePoolFactory = std::function( - const BufferSpec& spec, const MultiPoolOptions& options)>; - - void SetSimplePoolFactory(SimplePoolFactory create_simple_pool) { - create_simple_pool_ = create_simple_pool; + GpuBuffer GetBuffer(int width, int height, + GpuBufferFormat format = GpuBufferFormat::kBGRA32) { + return Get(internal::GpuBufferSpec(width, height, format)); } - - private: - static std::shared_ptr DefaultMakeSimplePool( - const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options); - - // Requests a simple buffer pool for the given spec. This may return nullptr - // if we have not yet reached a sufficient number of requests to allocate a - // pool, in which case the caller should invoke CreateBufferWithoutPool. 
- std::shared_ptr RequestPool(const BufferSpec& spec); - - MultiPoolOptions options_; - absl::Mutex mutex_; - mediapipe::ResourceCache> cache_ - ABSL_GUARDED_BY(mutex_); - SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool; }; -// BufferSpec equality operators -inline bool operator==(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return lhs.width == rhs.width && lhs.height == rhs.height && - lhs.format == rhs.format; -} -inline bool operator!=(const GpuBufferMultiPool::BufferSpec& lhs, - const GpuBufferMultiPool::BufferSpec& rhs) { - return !operator==(lhs, rhs); -} - } // namespace mediapipe #endif // MEDIAPIPE_GPU_GPU_BUFFER_MULTI_POOL_H_ diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 6633c2f00..52db88633 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -80,18 +80,20 @@ GpuResources::StatusOrGpuResources GpuResources::Create( return gpu_resources; } -GpuResources::GpuResources(std::shared_ptr gl_context) { +GpuResources::GpuResources(std::shared_ptr gl_context) +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + : texture_caches_(std::make_shared()), + gpu_buffer_pool_( + [tc = texture_caches_](const internal::GpuBufferSpec& spec, + const MultiPoolOptions& options) { + return CvPixelBufferPoolWrapper::Create(spec, options, tc.get()); + }) +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +{ gl_key_context_[SharedContextKey()] = gl_context; named_executors_[kGpuExecutorName] = std::make_shared(gl_context.get()); #if __APPLE__ - texture_caches_ = std::make_shared(); - gpu_buffer_pool().SetSimplePoolFactory( - [tc = texture_caches_](const GpuBufferMultiPool::BufferSpec& spec, - const MultiPoolOptions& options) { - return CvPixelBufferPoolWrapper::Create(spec.width, spec.height, - spec.format, options, tc.get()); - }); #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache()); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index 12a7a1296..4fe6ba04e 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -87,13 +87,15 @@ class GpuResources { std::map node_key_; std::map> gl_key_context_; +#ifdef MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + std::shared_ptr texture_caches_; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + // The pool must be destructed before the gl_context, but after the // ios_gpu_data, so the declaration order is important. GpuBufferMultiPool gpu_buffer_pool_; #ifdef __APPLE__ - std::shared_ptr texture_caches_; - // Note that this is an Objective-C object. MPPGraphGPUData* ios_gpu_data_; #endif // defined(__APPLE__) diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h index e504fc820..8a3cf6be0 100644 --- a/mediapipe/gpu/multi_pool.h +++ b/mediapipe/gpu/multi_pool.h @@ -12,16 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This class lets calculators allocate GpuBuffers of various sizes, caching -// and reusing them as needed. It does so by automatically creating and using -// platform-specific buffer pools for the requested sizes. -// -// This class is not meant to be used directly by calculators, but is instead -// used by GlCalculatorHelper to allocate buffers. 
-
 #ifndef MEDIAPIPE_GPU_MULTI_POOL_H_
 #define MEDIAPIPE_GPU_MULTI_POOL_H_
 
+#include "mediapipe/util/resource_cache.h"
+
 namespace mediapipe {
 
 struct MultiPoolOptions {
@@ -42,6 +37,81 @@ struct MultiPoolOptions {
 
 static constexpr MultiPoolOptions kDefaultMultiPoolOptions;
 
+// MultiPool is a generic class for vending reusable resources of type Item,
+// which are assumed to be relatively expensive to create, so that reusing them
+// is beneficial.
+// Items are classified by Spec; when an item with a given Spec is requested,
+// an old Item with the same Spec can be reused, if available; otherwise a new
+// Item will be created. When user code is done with an Item, it is returned
+// to the pool for reuse.
+// In order to manage this, a MultiPool contains a map of Specs to SimplePool;
+// each SimplePool manages Items with the same Spec, which are thus considered
+// interchangeable.
+// Item retention and eviction policies are controlled by options.
+// A concrete example would be a pool of GlTextureBuffer, grouped by dimensions
+// and format.
+template <class SimplePool, class Spec, class Item>
+class MultiPool {
+ public:
+  using SimplePoolFactory = std::function<std::shared_ptr<SimplePool>(
+      const Spec& spec, const MultiPoolOptions& options)>;
+
+  MultiPool(SimplePoolFactory factory = DefaultMakeSimplePool,
+            MultiPoolOptions options = kDefaultMultiPoolOptions)
+      : create_simple_pool_(factory), options_(options) {}
+
+  // Obtains an item. May either be reused or created anew.
+  Item Get(const Spec& spec);
+
+ private:
+  static std::shared_ptr<SimplePool> DefaultMakeSimplePool(
+      const Spec& spec, const MultiPoolOptions& options) {
+    return SimplePool::Create(spec, options);
+  }
+
+  // Requests a simple buffer pool for the given spec. This may return nullptr
+  // if we have not yet reached a sufficient number of requests to allocate a
+  // pool, in which case the caller should invoke CreateBufferWithoutPool.
+  std::shared_ptr<SimplePool> RequestPool(const Spec& spec);
+
+  absl::Mutex mutex_;
+  mediapipe::ResourceCache<Spec, std::shared_ptr<SimplePool>> cache_
+      ABSL_GUARDED_BY(mutex_);
+  SimplePoolFactory create_simple_pool_ = DefaultMakeSimplePool;
+  MultiPoolOptions options_;
+};
+
+template <class SimplePool, class Spec, class Item>
+std::shared_ptr<SimplePool> MultiPool<SimplePool, Spec, Item>::RequestPool(
+    const Spec& spec) {
+  std::shared_ptr<SimplePool> pool;
+  std::vector<std::shared_ptr<SimplePool>> evicted;
+  {
+    absl::MutexLock lock(&mutex_);
+    pool = cache_.Lookup(spec, [this](const Spec& spec, int request_count) {
+      return (request_count >= options_.min_requests_before_pool)
+                 ? create_simple_pool_(spec, options_)
+                 : nullptr;
+    });
+    evicted = cache_.Evict(options_.max_pool_count,
+                           options_.request_count_scrub_interval);
+  }
+  // Evicted pools, and their buffers, will be released without holding the
+  // lock.
+  return pool;
+}
+
+template <class SimplePool, class Spec, class Item>
+Item MultiPool<SimplePool, Spec, Item>::Get(const Spec& spec) {
+  std::shared_ptr<SimplePool> pool = RequestPool(spec);
+  if (pool) {
+    // Note: we release our multipool lock before accessing the simple pool.
+ return Item(pool->GetBuffer()); + } else { + return Item(SimplePool::CreateBufferWithoutPool(spec)); + } +} + } // namespace mediapipe #endif // MEDIAPIPE_GPU_MULTI_POOL_H_ From ab074a579a206164384f36e76581e784b8a65bd3 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:04:11 -0800 Subject: [PATCH 022/137] Internal change PiperOrigin-RevId: 488783325 --- WORKSPACE | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index fea96d941..d43394883 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -26,7 +26,7 @@ versions.check(minimum_bazel_version = "3.7.2") http_archive( name = "com_google_absl", urls = [ - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20210324.2.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz", ], # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved. patches = [ @@ -35,8 +35,8 @@ http_archive( patch_args = [ "-p1", ], - strip_prefix = "abseil-cpp-20210324.2", - sha256 = "59b862f50e710277f8ede96f083a5bb8d7c9595376146838b9580be90374ee1f" + strip_prefix = "abseil-cpp-20220623.1", + sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8" ) http_archive( From 583d27636b346ae2e69c4b12f1346e2a8c32401c Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 16:04:45 -0800 Subject: [PATCH 023/137] Factor out ReusablePool PiperOrigin-RevId: 488783477 --- mediapipe/gpu/BUILD | 11 ++ mediapipe/gpu/gl_texture_buffer.h | 5 + mediapipe/gpu/gl_texture_buffer_pool.cc | 77 +------------ mediapipe/gpu/gl_texture_buffer_pool.h | 52 +++------ mediapipe/gpu/reusable_pool.h | 145 ++++++++++++++++++++++++ 5 files changed, 178 insertions(+), 112 deletions(-) create mode 100644 mediapipe/gpu/reusable_pool.h diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 1efe75b52..747d131ba 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -607,6 +607,7 @@ cc_library( ":gpu_buffer", ":gpu_shared_data_header", ":multi_pool", + ":reusable_pool", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", @@ -615,6 +616,16 @@ cc_library( ], ) +cc_library( + name = "reusable_pool", + hdrs = ["reusable_pool.h"], + deps = [ + ":multi_pool", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "multi_pool", hdrs = ["multi_pool.h"], diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index 124a0ec2f..a770163b5 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -71,6 +71,11 @@ class GlTextureBuffer // Create a texture with a copy of the data in image_frame. static std::unique_ptr Create(const ImageFrame& image_frame); + static std::unique_ptr Create( + const internal::GpuBufferSpec& spec) { + return Create(spec.width, spec.height, spec.format); + } + // Wraps an existing texture, but does not take ownership of it. // deletion_callback is invoked when the GlTextureBuffer is released, so // the caller knows that the texture is no longer in use. 
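The core idiom being factored out into ReusablePool (shown in full below) is a shared_ptr deleter that holds only a weak_ptr to the pool, so an item that outlives its pool is freed rather than returned. A standalone sketch with hypothetical names; unlike the real class, it is not thread-safe (ReusablePool guards its state with an absl::Mutex):

    #include <memory>
    #include <utility>
    #include <vector>

    // Must be owned by a std::shared_ptr (e.g. std::make_shared<TinyPool>())
    // for shared_from_this() to be valid.
    class TinyPool : public std::enable_shared_from_this<TinyPool> {
     public:
      std::shared_ptr<int> Get() {
        std::unique_ptr<int> item = TakeOrCreate();
        std::weak_ptr<TinyPool> weak_pool(shared_from_this());
        return std::shared_ptr<int>(item.release(), [weak_pool](int* p) {
          if (auto pool = weak_pool.lock()) {
            pool->available_.emplace_back(p);  // return the item for reuse
          } else {
            delete p;  // the pool is gone; just free the item
          }
        });
      }

     private:
      std::unique_ptr<int> TakeOrCreate() {
        if (available_.empty()) return std::make_unique<int>(0);
        std::unique_ptr<int> item = std::move(available_.back());
        available_.pop_back();
        return item;
      }

      std::vector<std::unique_ptr<int>> available_;
    };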
diff --git a/mediapipe/gpu/gl_texture_buffer_pool.cc b/mediapipe/gpu/gl_texture_buffer_pool.cc
index 3d5a8cdaa..599381a34 100644
--- a/mediapipe/gpu/gl_texture_buffer_pool.cc
+++ b/mediapipe/gpu/gl_texture_buffer_pool.cc
@@ -16,79 +16,4 @@
 #include "absl/synchronization/mutex.h"

-namespace mediapipe {
-
-GlTextureBufferPool::GlTextureBufferPool(int width, int height,
-                                         GpuBufferFormat format,
-                                         int keep_count)
-    : width_(width),
-      height_(height),
-      format_(format),
-      keep_count_(keep_count) {}
-
-GlTextureBufferSharedPtr GlTextureBufferPool::GetBuffer() {
-  std::unique_ptr<GlTextureBuffer> buffer;
-  bool reuse = false;
-
-  {
-    absl::MutexLock lock(&mutex_);
-    if (available_.empty()) {
-      buffer = GlTextureBuffer::Create(width_, height_, format_);
-      if (!buffer) return nullptr;
-    } else {
-      buffer = std::move(available_.back());
-      available_.pop_back();
-      reuse = true;
-    }
-
-    ++in_use_count_;
-  }
-
-  // This needs to wait on consumer sync points, therefore it should not be
-  // done while holding the mutex.
-  if (reuse) {
-    buffer->Reuse();
-  }
-
-  // Return a shared_ptr with a custom deleter that adds the buffer back
-  // to our available list.
-  std::weak_ptr<GlTextureBufferPool> weak_pool(shared_from_this());
-  return std::shared_ptr<GlTextureBuffer>(
-      buffer.release(), [weak_pool](GlTextureBuffer* buf) {
-        auto pool = weak_pool.lock();
-        if (pool) {
-          pool->Return(absl::WrapUnique(buf));
-        } else {
-          delete buf;
-        }
-      });
-}
-
-std::pair<int, int> GlTextureBufferPool::GetInUseAndAvailableCounts() {
-  absl::MutexLock lock(&mutex_);
-  return {in_use_count_, available_.size()};
-}
-
-void GlTextureBufferPool::Return(std::unique_ptr<GlTextureBuffer> buf) {
-  std::vector<std::unique_ptr<GlTextureBuffer>> trimmed;
-  {
-    absl::MutexLock lock(&mutex_);
-    --in_use_count_;
-    available_.emplace_back(std::move(buf));
-    TrimAvailable(&trimmed);
-  }
-  // The trimmed buffers will be released without holding the lock.
-}
-
-void GlTextureBufferPool::TrimAvailable(
-    std::vector<std::unique_ptr<GlTextureBuffer>>* trimmed) {
-  int keep = std::max(keep_count_ - in_use_count_, 0);
-  if (available_.size() > keep) {
-    auto trim_it = std::next(available_.begin(), keep);
-    if (trimmed) {
-      std::move(trim_it, available_.end(), std::back_inserter(*trimmed));
-    }
-    available_.erase(trim_it, available_.end());
-  }
-}
-
-}  // namespace mediapipe
+namespace mediapipe {}  // namespace mediapipe
diff --git a/mediapipe/gpu/gl_texture_buffer_pool.h b/mediapipe/gpu/gl_texture_buffer_pool.h
index 29fc3c01c..726d0528d 100644
--- a/mediapipe/gpu/gl_texture_buffer_pool.h
+++ b/mediapipe/gpu/gl_texture_buffer_pool.h
@@ -24,11 +24,11 @@
 #include "absl/synchronization/mutex.h"
 #include "mediapipe/gpu/gl_texture_buffer.h"
 #include "mediapipe/gpu/multi_pool.h"
+#include "mediapipe/gpu/reusable_pool.h"

 namespace mediapipe {

-class GlTextureBufferPool
-    : public std::enable_shared_from_this<GlTextureBufferPool> {
+class GlTextureBufferPool : public ReusablePool<GlTextureBuffer> {
  public:
  // Creates a pool. This pool will manage buffers of the specified dimensions,
  // and will keep keep_count buffers around for reuse.
@@ -37,52 +37,32 @@ class GlTextureBufferPool
   static std::shared_ptr<GlTextureBufferPool> Create(int width, int height,
                                                      GpuBufferFormat format,
                                                      int keep_count) {
-    return std::shared_ptr<GlTextureBufferPool>(
-        new GlTextureBufferPool(width, height, format, keep_count));
+    return Create({width, height, format}, {.keep_count = keep_count});
   }

   static std::shared_ptr<GlTextureBufferPool> Create(
       const internal::GpuBufferSpec& spec, const MultiPoolOptions& options) {
-    return Create(spec.width, spec.height, spec.format, options.keep_count);
+    return std::shared_ptr<GlTextureBufferPool>(
+        new GlTextureBufferPool(spec, options));
   }

-  // Obtains a buffers. May either be reused or created anew.
-  // A GlContext must be current when this is called.
-  GlTextureBufferSharedPtr GetBuffer();
-
-  int width() const { return width_; }
-  int height() const { return height_; }
-  GpuBufferFormat format() const { return format_; }
-
-  // This method is meant for testing.
-  std::pair<int, int> GetInUseAndAvailableCounts();
+  int width() const { return spec_.width; }
+  int height() const { return spec_.height; }
+  GpuBufferFormat format() const { return spec_.format; }

   static GlTextureBufferSharedPtr CreateBufferWithoutPool(
       const internal::GpuBufferSpec& spec) {
-    return GlTextureBuffer::Create(spec.width, spec.height, spec.format);
+    return GlTextureBuffer::Create(spec);
   }

- private:
-  GlTextureBufferPool(int width, int height, GpuBufferFormat format,
-                      int keep_count);
+ protected:
+  GlTextureBufferPool(const internal::GpuBufferSpec& spec,
+                      const MultiPoolOptions& options)
+      : ReusablePool<GlTextureBuffer>(
+            [this] { return GlTextureBuffer::Create(spec_); }, options),
+        spec_(spec) {}

-  // Return a buffer to the pool.
-  void Return(std::unique_ptr<GlTextureBuffer> buf);
-
-  // If the total number of buffers is greater than keep_count, destroys any
-  // surplus buffers that are no longer in use.
-  void TrimAvailable(std::vector<std::unique_ptr<GlTextureBuffer>>* trimmed)
-      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
-  const int width_;
-  const int height_;
-  const GpuBufferFormat format_;
-  const int keep_count_;
-
-  absl::Mutex mutex_;
-  int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0;
-  std::vector<std::unique_ptr<GlTextureBuffer>> available_
-      ABSL_GUARDED_BY(mutex_);
+  const internal::GpuBufferSpec spec_;
 };

 }  // namespace mediapipe
diff --git a/mediapipe/gpu/reusable_pool.h b/mediapipe/gpu/reusable_pool.h
new file mode 100644
index 000000000..ddeaa5ba7
--- /dev/null
+++ b/mediapipe/gpu/reusable_pool.h
@@ -0,0 +1,145 @@
+// Copyright 2019 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Consider this file an implementation detail. None of this is part of the
+// public API.
+
+#ifndef MEDIAPIPE_GPU_REUSABLE_POOL_H_
+#define MEDIAPIPE_GPU_REUSABLE_POOL_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/functional/any_invocable.h"
+#include "absl/synchronization/mutex.h"
+#include "mediapipe/gpu/multi_pool.h"
+
+namespace mediapipe {
+
+template <class Item>
+class ReusablePool : public std::enable_shared_from_this<ReusablePool<Item>> {
+ public:
+  using ItemFactory = absl::AnyInvocable<std::unique_ptr<Item>() const>;
+
+  // Creates a pool. This pool will manage buffers of the specified dimensions,
+  // and will keep keep_count buffers around for reuse.
+  // We enforce creation as a shared_ptr so that we can use a weak reference in
+  // the buffers' deleters.
+  static std::shared_ptr<ReusablePool<Item>> Create(
+      ItemFactory item_factory, const MultiPoolOptions& options) {
+    return std::shared_ptr<ReusablePool<Item>>(
+        new ReusablePool<Item>(std::move(item_factory), options));
+  }
+
+  // Obtains a buffer. May either be reused or created anew.
+  // A GlContext must be current when this is called.
+  std::shared_ptr<Item> GetBuffer();
+
+  // This method is meant for testing.
+  std::pair<int, int> GetInUseAndAvailableCounts();
+
+ protected:
+  ReusablePool(ItemFactory item_factory, const MultiPoolOptions& options)
+      : item_factory_(std::move(item_factory)),
+        keep_count_(options.keep_count) {}
+
+ private:
+  // Return a buffer to the pool.
+  void Return(std::unique_ptr<Item> buf);
+
+  // If the total number of buffers is greater than keep_count, destroys any
+  // surplus buffers that are no longer in use.
+  void TrimAvailable(std::vector<std::unique_ptr<Item>>* trimmed)
+      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+  const ItemFactory item_factory_;
+  const int keep_count_;
+
+  absl::Mutex mutex_;
+  int in_use_count_ ABSL_GUARDED_BY(mutex_) = 0;
+  std::vector<std::unique_ptr<Item>> available_ ABSL_GUARDED_BY(mutex_);
+};
+
+template <class Item>
+inline std::shared_ptr<Item> ReusablePool<Item>::GetBuffer() {
+  std::unique_ptr<Item> buffer;
+  bool reuse = false;
+
+  {
+    absl::MutexLock lock(&mutex_);
+    if (available_.empty()) {
+      buffer = item_factory_();
+      if (!buffer) return nullptr;
+    } else {
+      buffer = std::move(available_.back());
+      available_.pop_back();
+      reuse = true;
+    }
+
+    ++in_use_count_;
+  }
+
+  // This needs to wait on consumer sync points, therefore it should not be
+  // done while holding the mutex.
+  if (reuse) {
+    buffer->Reuse();
+  }
+
+  // Return a shared_ptr with a custom deleter that adds the buffer back
+  // to our available list.
+  std::weak_ptr<ReusablePool<Item>> weak_pool(this->shared_from_this());
+  return std::shared_ptr<Item>(buffer.release(), [weak_pool](Item* buf) {
+    auto pool = weak_pool.lock();
+    if (pool) {
+      pool->Return(absl::WrapUnique(buf));
+    } else {
+      delete buf;
+    }
+  });
+}
+
+template <class Item>
+inline std::pair<int, int> ReusablePool<Item>::GetInUseAndAvailableCounts() {
+  absl::MutexLock lock(&mutex_);
+  return {in_use_count_, available_.size()};
+}
+
+template <class Item>
+void ReusablePool<Item>::Return(std::unique_ptr<Item> buf) {
+  std::vector<std::unique_ptr<Item>> trimmed;
+  {
+    absl::MutexLock lock(&mutex_);
+    --in_use_count_;
+    available_.emplace_back(std::move(buf));
+    TrimAvailable(&trimmed);
+  }
+  // The trimmed buffers will be released without holding the lock.
+}
+
+template <class Item>
+void ReusablePool<Item>::TrimAvailable(
+    std::vector<std::unique_ptr<Item>>* trimmed) {
+  int keep = std::max(keep_count_ - in_use_count_, 0);
+  if (available_.size() > keep) {
+    auto trim_it = std::next(available_.begin(), keep);
+    if (trimmed) {
+      std::move(trim_it, available_.end(), std::back_inserter(*trimmed));
+    }
+    available_.erase(trim_it, available_.end());
+  }
+}
+
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_GPU_REUSABLE_POOL_H_

From 1beca6165057a4d198a09cfb9becca9252529895 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 15 Nov 2022 16:06:16 -0800
Subject: [PATCH 024/137] Register GlTextureBuffer pool with GpuBuffer

First crack at hooking up pools with the GpuBufferStorage system. Will most
likely be superseded later, but for now this works with minimal code impact:
just overwrite the factory for a storage type with one that uses the pool.
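To make the effect concrete, here is a rough calculator-side sketch. It is
illustrative only, not part of this patch: it assumes a framework-managed GL
context is current, and that GpuBuffer's size/format constructor defers
allocation to the storage registry (which this patch points at the pool).

    // With the pooled factory registered, this allocation is served from the
    // graph's shared buffer pool whenever the GPU service is available, and
    // falls back to a freshly created GlTextureBuffer otherwise.
    GpuBuffer output(/*width=*/640, /*height=*/480, GpuBufferFormat::kBGRA32);
    // Requesting a write view is what triggers the (now pool-backed) factory.
    auto view = output.GetWriteView<GlTextureView>(/*plane=*/0);
    // ... render into view.name() on view.target() ...

When the last reference to the buffer is dropped, the pool's custom deleter
returns the GlTextureBuffer to the available list instead of destroying it,
as long as the pool itself is still alive.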
PiperOrigin-RevId: 488783854
---
 mediapipe/gpu/gpu_buffer_storage.h        | 20 +++++++++++----
 mediapipe/gpu/gpu_shared_data_internal.cc | 30 +++++++++++++++++++++++
 2 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h
index 3d872eb66..214f506c0 100644
--- a/mediapipe/gpu/gpu_buffer_storage.h
+++ b/mediapipe/gpu/gpu_buffer_storage.h
@@ -74,13 +74,17 @@ class GpuBufferStorageRegistry {
   template <class Storage>
   RegistryToken Register() {
-    return Register(
+    return RegisterFactory<Storage>(
         [](int width, int height, GpuBufferFormat format)
            -> std::shared_ptr<Storage> {
          return CreateStorage<Storage>(overload_priority<10>{}, width, height,
                                        format);
-        },
-        Storage::GetProviderTypes());
+        });
+  }
+
+  template <class Storage, class F>
+  RegistryToken RegisterFactory(F&& factory) {
+    return Register(factory, Storage::GetProviderTypes());
   }

   template <class StorageFrom, class StorageTo, class F>
@@ -148,6 +152,13 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... {
     return kHashes;
   }

+  // Exposing this as a function allows dependent initializers to call this to
+  // ensure proper ordering.
+  static GpuBufferStorageRegistry::RegistryToken RegisterOnce() {
+    static auto registration = GpuBufferStorageRegistry::Get().Register<T>();
+    return registration;
+  }
+
  private:
   virtual const void* down_cast(TypeId to) const override {
     return down_cast_impl(to, types<T, U...>{});
@@ -161,8 +172,7 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... {
     return down_cast_impl(to, types<W...>);
   }

-  inline static auto registration =
-      GpuBufferStorageRegistry::Get().Register<T>();
+  inline static auto registration = RegisterOnce();
   using RequireStatics = ForceStaticInstantiation<&registration>;
 };

diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc
index 52db88633..91723a7d1 100644
--- a/mediapipe/gpu/gpu_shared_data_internal.cc
+++ b/mediapipe/gpu/gpu_shared_data_internal.cc
@@ -200,4 +200,34 @@ GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {}
 MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; }
 #endif  // __APPLE__

+extern const GraphService<GpuResources> kGpuService;
+
+#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+static std::shared_ptr<GlTextureBuffer> GetGlTextureBufferFromPool(
+    int width, int height, GpuBufferFormat format) {
+  std::shared_ptr<GlTextureBuffer> texture_buffer;
+  const auto cc = LegacyCalculatorSupport::Scoped<CalculatorContext>::current();
+
+  if (cc && cc->Service(kGpuService).IsAvailable()) {
+    GpuBufferMultiPool* pool =
+        &cc->Service(kGpuService).GetObject().gpu_buffer_pool();
+    // Note that the "gpu_buffer_pool" serves GlTextureBuffers on non-Apple
+    // platforms. TODO: refactor into storage pools.
+    texture_buffer = pool->GetBuffer(width, height, format)
+                         .internal_storage<GlTextureBuffer>();
+  } else {
+    texture_buffer = GlTextureBuffer::Create(width, height, format);
+  }
+  return texture_buffer;
+}
+
+static auto kGlTextureBufferPoolRegistration = [] {
+  // Ensure that the GlTextureBuffer's own factory is already registered, so we
+  // can override it.
+  GlTextureBuffer::RegisterOnce();
+  return internal::GpuBufferStorageRegistry::Get()
+      .RegisterFactory<GlTextureBuffer>(GetGlTextureBufferFromPool);
+}();
+#endif  // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+
 }  // namespace mediapipe

From 7e19bbe35c85e77ba1d99a9824ecb60d06869f52 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 15 Nov 2022 16:57:46 -0800
Subject: [PATCH 025/137] Internal change

PiperOrigin-RevId: 488795920
---
 mediapipe/gpu/gl_texture_buffer.h  |  4 ++++
 mediapipe/gpu/gpu_buffer_storage.h | 20 ++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h
index a770163b5..1be24a86b 100644
--- a/mediapipe/gpu/gl_texture_buffer.h
+++ b/mediapipe/gpu/gl_texture_buffer.h
@@ -143,6 +143,10 @@ class GlTextureBuffer
     return producer_context_;
   }

+#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+  static constexpr bool kDisableGpuBufferRegistration = true;
+#endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+
  private:
  // Creates a texture of dimensions width x height and allocates space for it.
  // If data is provided, it is uploaded to the texture; otherwise, it can be
diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h
index 214f506c0..0da5f236a 100644
--- a/mediapipe/gpu/gpu_buffer_storage.h
+++ b/mediapipe/gpu/gpu_buffer_storage.h
@@ -84,11 +84,17 @@ class GpuBufferStorageRegistry {
   template <class Storage, class F>
   RegistryToken RegisterFactory(F&& factory) {
+    if constexpr (kDisableRegistration<Storage>) {
+      return {};
+    }
     return Register(factory, Storage::GetProviderTypes());
   }

   template <class StorageFrom, class StorageTo, class F>
   RegistryToken RegisterConverter(F&& converter) {
+    if constexpr (kDisableRegistration<StorageTo>) {
+      return {};
+    }
     return Register(
         [converter](std::shared_ptr<GpuBufferStorage> source)
            -> std::shared_ptr<GpuBufferStorage> {
@@ -119,6 +125,13 @@
     return std::make_shared<Storage>(args...);
   }

+  // Temporary workaround: a Storage class can define a static constexpr
+  // kDisableGpuBufferRegistration member set to true to prevent registering
+  // any factory or converter that would produce it.
+  // TODO: better solution for storage priorities.
+  template <class Storage, class = void>
+  static constexpr bool kDisableRegistration = false;
+
   RegistryToken Register(StorageFactory factory,
                          std::vector<TypeId> provider_hashes);
   RegistryToken Register(StorageConverter converter,
@@ -130,6 +143,13 @@
       converter_for_view_provider_and_existing_storage_;
 };

+// Putting this outside the class body to work around a GCC bug.
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=71954
+template <class Storage>
+constexpr bool GpuBufferStorageRegistry::kDisableRegistration<
+    Storage, std::void_t<decltype(Storage::kDisableGpuBufferRegistration)>> =
+    Storage::kDisableGpuBufferRegistration;
+
 // Defining a member of this type causes P to be ODR-used, which forces its
 // instantiation if it's a static member of a template.
template From 6702ef3d57570e66101a7e4535a04b0a75cdb6bb Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Tue, 15 Nov 2022 16:58:38 -0800 Subject: [PATCH 026/137] Internal change PiperOrigin-RevId: 488796090 --- docs/BUILD | 1 + docs/build_java_api_docs.py | 33 ++++++++++++++++----------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/docs/BUILD b/docs/BUILD index ad08df66a..8e85dbf86 100644 --- a/docs/BUILD +++ b/docs/BUILD @@ -17,6 +17,7 @@ py_binary( name = "build_java_api_docs", srcs = ["build_java_api_docs.py"], data = [ + "//third_party/android/sdk:api/26.txt", "//third_party/java/doclava/current:doclava.jar", "//third_party/java/jsilver:jsilver_jar", ], diff --git a/docs/build_java_api_docs.py b/docs/build_java_api_docs.py index e96e1fd83..b13e8d1df 100644 --- a/docs/build_java_api_docs.py +++ b/docs/build_java_api_docs.py @@ -20,10 +20,6 @@ from absl import flags from tensorflow_docs.api_generator import gen_java -_JAVA_ROOT = flags.DEFINE_string('java_src', None, - 'Override the Java source path.', - required=False) - _OUT_DIR = flags.DEFINE_string('output_dir', '/tmp/mp_java/', 'Write docs here.') @@ -37,27 +33,30 @@ _ = flags.DEFINE_bool( 'search_hints', True, '[UNUSED] Include metadata search hints in the generated files') +_ANDROID_SDK = pathlib.Path('android/sdk/api/26.txt') + def main(_) -> None: - if not (java_root := _JAVA_ROOT.value): - # Default to using a relative path to find the Java source. - mp_root = pathlib.Path(__file__) - while (mp_root := mp_root.parent).name != 'mediapipe': - # Find the nearest `mediapipe` dir. - pass + # Default to using a relative path to find the Java source. + mp_root = pathlib.Path(__file__) + while (mp_root := mp_root.parent).name != 'mediapipe': + # Find the nearest `mediapipe` dir. + pass - # Externally, parts of the repo are nested inside a mediapipe/ directory - # that does not exist internally. Support both. - if (mp_root / 'mediapipe').exists(): - mp_root = mp_root / 'mediapipe' + # Find the root from which all packages are relative. + root = mp_root.parent - java_root = mp_root / 'tasks/java' + # Externally, parts of the repo are nested inside a mediapipe/ directory + # that does not exist internally. Support both. 
+  if (mp_root / 'mediapipe').exists():
+    mp_root = mp_root / 'mediapipe'

   gen_java.gen_java_docs(
       package='com.google.mediapipe',
-      source_path=pathlib.Path(java_root),
+      source_path=mp_root / 'tasks/java',
       output_dir=pathlib.Path(_OUT_DIR.value),
-      site_path=pathlib.Path(_SITE_PATH.value))
+      site_path=pathlib.Path(_SITE_PATH.value),
+      federated_docs={'https://developer.android.com': root / _ANDROID_SDK})

 if __name__ == '__main__':

From 77b3edbb6757f6afe3446bd237297d62dc14832d Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 15 Nov 2022 17:04:39 -0800
Subject: [PATCH 027/137] Internal change

PiperOrigin-RevId: 488797407
---
 mediapipe/gpu/gpu_buffer.cc | 47 +++++++++++++++++++++++--------------
 mediapipe/gpu/gpu_buffer.h  | 14 +++++++----
 2 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc
index e570ce8ba..35a73fd8f 100644
--- a/mediapipe/gpu/gpu_buffer.cc
+++ b/mediapipe/gpu/gpu_buffer.cc
@@ -1,6 +1,7 @@
 #include "mediapipe/gpu/gpu_buffer.h"

 #include <memory>
+#include <utility>

 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -29,7 +30,7 @@ std::string GpuBuffer::DebugString() const {
                       "]");
 }

-internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
+internal::GpuBufferStorage* GpuBuffer::GetStorageForView(
     TypeId view_provider_type, bool for_writing) const {
   const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr;
@@ -45,38 +46,48 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView(
   // TODO: choose best conversion.
   if (!chosen_storage) {
     for (const auto& s : storages_) {
-      auto converter = internal::GpuBufferStorageRegistry::Get()
-                           .StorageConverterForViewProvider(view_provider_type,
-                                                            s->storage_type());
-      if (converter) {
-        storages_.push_back(converter(s));
-        chosen_storage = &storages_.back();
+      if (auto converter = internal::GpuBufferStorageRegistry::Get()
+                               .StorageConverterForViewProvider(
+                                   view_provider_type, s->storage_type())) {
+        if (auto new_storage = converter(s)) {
+          storages_.push_back(new_storage);
+          chosen_storage = &storages_.back();
+          break;
+        }
       }
     }
   }

   if (for_writing) {
-    if (!chosen_storage) {
-      // Allocate a new storage supporting the requested view.
-      auto factory = internal::GpuBufferStorageRegistry::Get()
-                         .StorageFactoryForViewProvider(view_provider_type);
-      if (factory) {
-        storages_ = {factory(width(), height(), format())};
-        chosen_storage = &storages_.back();
-      }
-    } else {
+    if (chosen_storage) {
       // Discard all other storages.
       storages_ = {*chosen_storage};
       chosen_storage = &storages_.back();
+    } else {
+      // Allocate a new storage supporting the requested view.
+      if (auto factory =
+              internal::GpuBufferStorageRegistry::Get()
+                  .StorageFactoryForViewProvider(view_provider_type)) {
+        if (auto new_storage = factory(width(), height(), format())) {
+          storages_ = {std::move(new_storage)};
+          chosen_storage = &storages_.back();
+        }
+      }
     }
   }

+  return chosen_storage ? chosen_storage->get() : nullptr;
+}

+internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
+    TypeId view_provider_type, bool for_writing) const {
+  auto* chosen_storage =
+      GpuBuffer::GetStorageForView(view_provider_type, for_writing);
   CHECK(chosen_storage) << "no view provider found for requested view "
                         << view_provider_type.name()
                         << "; storages available: "
                         << absl::StrJoin(storages_, ", ",
                                          StorageTypeFormatter());
-  DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type));
-  return **chosen_storage;
+  DCHECK(chosen_storage->can_down_cast_to(view_provider_type));
+  return *chosen_storage;
 }

 #if !MEDIAPIPE_DISABLE_GPU
diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h
index 57e077151..ad5c130b5 100644
--- a/mediapipe/gpu/gpu_buffer.h
+++ b/mediapipe/gpu/gpu_buffer.h
@@ -105,7 +105,7 @@ class GpuBuffer {
   // specific view type; see the corresponding ViewProvider.
   template <class View, class... Args>
   decltype(auto) GetReadView(Args... args) const {
-    return GetViewProvider<View>(false)->GetReadView(
+    return GetViewProviderOrDie<View>(false).GetReadView(
         internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
         std::forward<Args>(args)...);
   }
@@ -114,7 +114,7 @@
   // specific view type; see the corresponding ViewProvider.
   template <class View, class... Args>
   decltype(auto) GetWriteView(Args... args) {
-    return GetViewProvider<View>(true)->GetWriteView(
+    return GetViewProviderOrDie<View>(true).GetWriteView(
         internal::types<View>{}, std::make_shared<GpuBuffer>(*this),
         std::forward<Args>(args)...);
   }
@@ -147,13 +147,17 @@
     GpuBufferFormat format_ = GpuBufferFormat::kUnknown;
   };

-  internal::GpuBufferStorage& GetStorageForView(TypeId view_provider_type,
+  internal::GpuBufferStorage* GetStorageForView(TypeId view_provider_type,
                                                 bool for_writing) const;

+  internal::GpuBufferStorage& GetStorageForViewOrDie(TypeId view_provider_type,
+                                                     bool for_writing) const;
+
   template <class View>
-  internal::ViewProvider<View>* GetViewProvider(bool for_writing) const {
+  internal::ViewProvider<View>& GetViewProviderOrDie(bool for_writing) const {
     using VP = internal::ViewProvider<View>;
-    return GetStorageForView(kTypeId<VP>, for_writing).template down_cast<VP>();
+    return *GetStorageForViewOrDie(kTypeId<VP>, for_writing)
+                .template down_cast<VP>();
   }

   std::shared_ptr<internal::GpuBufferStorage>& no_storage() const {

From 4bda012bba8fa7b5e0b4a04ebdfae8519329bc32 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 15 Nov 2022 17:07:26 -0800
Subject: [PATCH 028/137] Factor out gl_texture_util

PiperOrigin-RevId: 488797985
---
 mediapipe/gpu/BUILD              | 11 +++++++
 mediapipe/gpu/gl_texture_util.cc | 30 ++++++++++++++++++
 mediapipe/gpu/gl_texture_util.h  | 34 +++++++++++++++++++++
 mediapipe/gpu/gpu_buffer_test.cc | 52 ++++----------------------------
 4 files changed, 81 insertions(+), 46 deletions(-)
 create mode 100644 mediapipe/gpu/gl_texture_util.cc
 create mode 100644 mediapipe/gpu/gl_texture_util.h

diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD
index 747d131ba..68e788c52 100644
--- a/mediapipe/gpu/BUILD
+++ b/mediapipe/gpu/BUILD
@@ -689,6 +689,17 @@ cc_library(
     }),
 )

+cc_library(
+    name = "gl_texture_util",
+    srcs = ["gl_texture_util.cc"],
+    hdrs = ["gl_texture_util.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":gl_base",
+        ":gl_texture_view",
+    ],
+)
+
 cc_library(
     name = "shader_util",
     srcs = ["shader_util.cc"],
diff --git a/mediapipe/gpu/gl_texture_util.cc b/mediapipe/gpu/gl_texture_util.cc
new file mode 100644
index 000000000..603e82a46
--- /dev/null
+++ b/mediapipe/gpu/gl_texture_util.cc
@@ -0,0 +1,30 @@
+#include "mediapipe/gpu/gl_texture_util.h"
+
+namespace mediapipe {
+
+void CopyGlTexture(const GlTextureView&
src, GlTextureView& dst) { + glViewport(0, 0, src.width(), src.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), + src.name(), 0); + + glActiveTexture(GL_TEXTURE0); + glBindTexture(dst.target(), dst.name()); + glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); + + glBindTexture(dst.target(), 0); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, + 0); +} + +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, + float a) { + glViewport(0, 0, view.width(), view.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), + view.name(), 0); + glClearColor(r, g, b, a); + glClear(GL_COLOR_BUFFER_BIT); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, + 0); +} + +} // namespace mediapipe diff --git a/mediapipe/gpu/gl_texture_util.h b/mediapipe/gpu/gl_texture_util.h new file mode 100644 index 000000000..73ac37ade --- /dev/null +++ b/mediapipe/gpu/gl_texture_util.h @@ -0,0 +1,34 @@ +#ifndef MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ +#define MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ + +#include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gl_texture_view.h" + +namespace mediapipe { + +// Copies a texture to another. +// Assumes a framebuffer is already set up +void CopyGlTexture(const GlTextureView& src, GlTextureView& dst); + +// Fills a texture with a color. +void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, float a); + +// RAII class to set up a temporary framebuffer. Mainly for test use. +class TempGlFramebuffer { + public: + TempGlFramebuffer() { + glGenFramebuffers(1, &framebuffer_); + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + } + ~TempGlFramebuffer() { + glBindFramebuffer(GL_FRAMEBUFFER, 0); + glDeleteFramebuffers(1, &framebuffer_); + } + + private: + GLuint framebuffer_; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_GPU_GL_TEXTURE_UTIL_H_ diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 3fd519b21..796cb1d9d 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -18,6 +18,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gl_texture_util.h" #include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_test_base.h" @@ -41,47 +42,6 @@ void FillImageFrameRGBA(ImageFrame& image, uint8 r, uint8 g, uint8 b, uint8 a) { } } -// Assumes a framebuffer is already set up -void CopyGlTexture(const GlTextureView& src, GlTextureView& dst) { - glViewport(0, 0, src.width(), src.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), - src.name(), 0); - - glActiveTexture(GL_TEXTURE0); - glBindTexture(dst.target(), dst.name()); - glCopyTexSubImage2D(dst.target(), 0, 0, 0, 0, 0, dst.width(), dst.height()); - - glBindTexture(dst.target(), 0); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, src.target(), 0, - 0); -} - -void FillGlTextureRgba(GlTextureView& view, float r, float g, float b, - float a) { - glViewport(0, 0, view.width(), view.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), - view.name(), 0); - glClearColor(r, g, b, a); - glClear(GL_COLOR_BUFFER_BIT); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(), 0, - 0); -} - -class TempGlFramebuffer 
{ - public: - TempGlFramebuffer() { - glGenFramebuffers(1, &framebuffer_); - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - } - ~TempGlFramebuffer() { - glBindFramebuffer(GL_FRAMEBUFFER, 0); - glDeleteFramebuffers(1, &framebuffer_); - } - - private: - GLuint framebuffer_; -}; - class GpuBufferTest : public GpuTestBase {}; TEST_F(GpuBufferTest, BasicTest) { @@ -127,7 +87,7 @@ TEST_F(GpuBufferTest, GlTextureView) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "gltv_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "gltv_red_view")); } @@ -162,7 +122,7 @@ TEST_F(GpuBufferTest, ImageFrame) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "if_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "if_red_view")); } @@ -196,7 +156,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame red(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(red, 255, 0, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, red, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, red, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(red, "ow_red_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_red_view")); } @@ -230,7 +190,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame green(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(green, 0, 255, 0, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, green, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, green, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(green, "ow_green_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_green_view")); } @@ -240,7 +200,7 @@ TEST_F(GpuBufferTest, Overwrite) { ImageFrame blue(ImageFormat::SRGBA, 300, 200); FillImageFrameRGBA(blue, 0, 0, 255, 255); - EXPECT_TRUE(mediapipe::CompareImageFrames(*view, blue, 0.0, 0.0)); + EXPECT_TRUE(CompareImageFrames(*view, blue, 0.0, 0.0)); MP_EXPECT_OK(SavePngTestOutput(blue, "ow_blue_gold")); MP_EXPECT_OK(SavePngTestOutput(*view, "ow_blue_view")); } From b308c0dd5e114cbf803dd2864f67589be048b7a0 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 17:08:37 -0800 Subject: [PATCH 029/137] Implement CVPixelBufferRef access as a view. 
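In practice this means iOS-side code can fetch the underlying pixel buffer
through the generic view mechanism rather than a storage-specific accessor.
A rough usage sketch (illustrative only; it mirrors the
GetReadView<CVPixelBufferRef> plumbing added in the diff below, and assumes
the buffer is backed by, or convertible to, a GpuBufferStorageCvPixelBuffer):

    // The returned CFHolder retains the CVPixelBuffer, keeping it valid
    // while the holder is in scope.
    CFHolder<CVPixelBufferRef> holder =
        gpu_buffer.GetReadView<CVPixelBufferRef>();
    CVPixelBufferRef pixel_buffer = *holder;
    // pixel_buffer can now be passed to CoreVideo or Metal interop APIs.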
PiperOrigin-RevId: 488798216
---
 mediapipe/gpu/gpu_buffer.cc                   |  7 ++--
 mediapipe/gpu/gpu_buffer.h                    |  4 +++
 .../gpu/gpu_buffer_storage_cv_pixel_buffer.h  | 35 ++++++++++++++++++-
 3 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/mediapipe/gpu/gpu_buffer.cc b/mediapipe/gpu/gpu_buffer.cc
index 35a73fd8f..388960b11 100644
--- a/mediapipe/gpu/gpu_buffer.cc
+++ b/mediapipe/gpu/gpu_buffer.cc
@@ -93,8 +93,11 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForViewOrDie(
 #if !MEDIAPIPE_DISABLE_GPU
 #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer) {
-  auto p = buffer.internal_storage<GpuBufferStorageCvPixelBuffer>();
-  if (p) return **p;
+  if (buffer.GetStorageForView(
+          kTypeId<internal::ViewProvider<CVPixelBufferRef>>,
+          /*for_writing=*/false) != nullptr) {
+    return *buffer.GetReadView<CVPixelBufferRef>();
+  }
   return nullptr;
 }
 #endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h
index ad5c130b5..45146a322 100644
--- a/mediapipe/gpu/gpu_buffer.h
+++ b/mediapipe/gpu/gpu_buffer.h
@@ -179,6 +179,10 @@ class GpuBuffer {
   // This is mutable because view methods that do not change the contents may
   // still need to allocate new storages.
   mutable std::vector<std::shared_ptr<internal::GpuBufferStorage>> storages_;
+
+#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+  friend CVPixelBufferRef GetCVPixelBufferRef(const GpuBuffer& buffer);
+#endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 };

 inline bool GpuBuffer::operator==(std::nullptr_t other) const {
diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
index 017771dc7..e5bc5de43 100644
--- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
+++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
@@ -12,10 +12,27 @@ namespace mediapipe {

 class GlContext;

+namespace internal {
+
+template <>
+class ViewProvider<CVPixelBufferRef> {
+ public:
+  virtual ~ViewProvider() = default;
+  virtual CFHolder<CVPixelBufferRef> GetReadView(
+      internal::types<CVPixelBufferRef>,
+      std::shared_ptr<GpuBuffer> gpu_buffer) const = 0;
+  virtual CFHolder<CVPixelBufferRef> GetWriteView(
+      internal::types<CVPixelBufferRef>,
+      std::shared_ptr<GpuBuffer> gpu_buffer) = 0;
+};
+
+}  // namespace internal
+
 class GpuBufferStorageCvPixelBuffer
     : public internal::GpuBufferStorageImpl<
           GpuBufferStorageCvPixelBuffer, internal::ViewProvider<GlTextureView>,
-          internal::ViewProvider<ImageFrame>>,
+          internal::ViewProvider<ImageFrame>,
+          internal::ViewProvider<CVPixelBufferRef>>,
       public CFHolder<CVPixelBufferRef> {
  public:
   using CFHolder<CVPixelBufferRef>::CFHolder;
@@ -44,6 +61,12 @@ class GpuBufferStorageCvPixelBuffer
   std::shared_ptr<ImageFrame> GetWriteView(
       internal::types<ImageFrame>,
       std::shared_ptr<GpuBuffer> gpu_buffer) override;
+  CFHolder<CVPixelBufferRef> GetReadView(
+      internal::types<CVPixelBufferRef>,
+      std::shared_ptr<GpuBuffer> gpu_buffer) const override;
+  CFHolder<CVPixelBufferRef> GetWriteView(
+      internal::types<CVPixelBufferRef>,
+      std::shared_ptr<GpuBuffer> gpu_buffer) override;

  private:
   GlTextureView GetTexture(std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
@@ -51,6 +74,16 @@
   void ViewDoneWriting(const GlTextureView& view);
 };

+inline CFHolder<CVPixelBufferRef> GpuBufferStorageCvPixelBuffer::GetReadView(
+    internal::types<CVPixelBufferRef>,
+    std::shared_ptr<GpuBuffer> gpu_buffer) const {
+  return *this;
+}
+inline CFHolder<CVPixelBufferRef> GpuBufferStorageCvPixelBuffer::GetWriteView(
+    internal::types<CVPixelBufferRef>, std::shared_ptr<GpuBuffer> gpu_buffer) {
+  return *this;
+}
+
 namespace internal {
 // These functions enable backward-compatible construction of a GpuBuffer from
 // CVPixelBufferRef without having to expose that type in the main GpuBuffer

From 2f77bf44e3f3a53ff187bd9a39f9cbc413b4e413 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 15 Nov 2022 18:08:31 -0800
Subject: [PATCH 030/137] Use train_data to evaluate accuracy of unit test for
 gesture_recognizer due to limited dataset
size. PiperOrigin-RevId: 488808942 --- .../gesture_recognizer_test.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 7e7a1ca30..9bac22133 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -42,8 +42,8 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() all_data = self._load_data() - # Splits data, 90% data for training, 10% for testing - self._train_data, self._test_data = all_data.split(0.9) + # Splits data, 90% data for training, 10% for validation + self._train_data, self._validation_data = all_data.split(0.9) def test_gesture_recognizer_model(self): model_options = gesture_recognizer.ModelOptions() @@ -53,7 +53,7 @@ class GestureRecognizerTest(tf.test.TestCase): model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) self._test_accuracy(model) @@ -66,7 +66,7 @@ class GestureRecognizerTest(tf.test.TestCase): model_options=model_options, hparams=hparams) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) model.export_model() model_bundle_file = os.path.join(hparams.export_dir, @@ -94,8 +94,9 @@ class GestureRecognizerTest(tf.test.TestCase): size=[1, model.embedding_size]) def _test_accuracy(self, model, threshold=0.5): - _, accuracy = model.evaluate(self._test_data) - tf.compat.v1.logging.info(f'accuracy: {accuracy}') + # Test on _train_data because of our limited dataset size + _, accuracy = model.evaluate(self._train_data) + tf.compat.v1.logging.info(f'train accuracy: {accuracy}') self.assertGreaterEqual(accuracy, threshold) @unittest_mock.patch.object( @@ -113,7 +114,7 @@ class GestureRecognizerTest(tf.test.TestCase): options = gesture_recognizer.GestureRecognizerOptions() gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=options) mock_hparams.assert_called_once() mock_model_options.assert_called_once() @@ -128,11 +129,11 @@ class GestureRecognizerTest(tf.test.TestCase): with mock.patch('sys.stdout', mock_stdout): model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) model = gesture_recognizer.GestureRecognizer.create( train_data=self._train_data, - validation_data=self._test_data, + validation_data=self._validation_data, options=gesture_recognizer_options) self._test_accuracy(model) From fe66de37149bbd8a706b78e33b210bde5c3a021c Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:30:58 -0800 Subject: [PATCH 031/137] Internal change PiperOrigin-RevId: 488812677 --- mediapipe/gpu/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 68e788c52..9cb27d2f1 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -221,11 +221,11 @@ cc_library( ":gpu_buffer_format", 
":gpu_buffer_storage", ":gpu_buffer_storage_image_frame", + "@com_google_absl//absl/memory", # TODO: remove this dependency. Some other teams' tests # depend on having an indirect image_frame dependency, need to be # fixed first. "//mediapipe/framework/formats:image_frame", - "@com_google_absl//absl/memory", ], ) From 4c874fe4cd7f8c2fe4afdf1ac7630450264c3eba Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:31:27 -0800 Subject: [PATCH 032/137] Allow conversion of GlTextureBuffer to CVPixelBufferRef This means that, if an iOS application sends in a GlTextureBuffer but expects a CVPixelBufferRef as output, everything will work even if the graph just forwards the same input. Also, access by Metal calculators will also work transparently. PiperOrigin-RevId: 488812748 --- mediapipe/gpu/BUILD | 27 ++++++++++++++++++++++++++- mediapipe/gpu/gl_texture_buffer.cc | 29 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9cb27d2f1..196de3076 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -226,7 +226,13 @@ cc_library( # depend on having an indirect image_frame dependency, need to be # fixed first. "//mediapipe/framework/formats:image_frame", - ], + ] + select({ + "//conditions:default": [], + ":platform_ios_with_gpu": [ + ":gl_texture_util", + ":gpu_buffer_storage_cv_pixel_buffer", + ], + }), ) cc_library( @@ -344,6 +350,25 @@ cc_library( ], ) +mediapipe_cc_test( + name = "gpu_buffer_storage_cv_pixel_buffer_test", + size = "small", + timeout = "moderate", + srcs = ["gpu_buffer_storage_cv_pixel_buffer_test.cc"], + platforms = ["ios"], + deps = [ + ":gl_texture_buffer", + ":gl_texture_util", + ":gpu_buffer", + ":gpu_buffer_storage_cv_pixel_buffer", + ":gpu_test_base", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:test_util", + "//mediapipe/objc:util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "cv_texture_cache_manager", srcs = ["cv_texture_cache_manager.cc"], diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index fbb91a8f5..4c2f15a8d 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -18,6 +18,11 @@ #include "mediapipe/gpu/gl_texture_view.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/gl_texture_util.h" +#include "mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h" +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + namespace mediapipe { std::unique_ptr GlTextureBuffer::Wrap( @@ -380,4 +385,28 @@ static auto kConverterRegistration2 = .RegisterConverter( ConvertFromImageFrame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +static std::shared_ptr ConvertToCvPixelBuffer( + std::shared_ptr buf) { + auto output = absl::make_unique( + buf->width(), buf->height(), buf->format()); + buf->GetProducerContext()->Run([buf, &output] { + TempGlFramebuffer framebuffer; + auto src = buf->GetReadView(internal::types{}, nullptr, 0); + auto dst = + output->GetWriteView(internal::types{}, nullptr, 0); + CopyGlTexture(src, dst); + glFlush(); + }); + return output; +} + +static auto kConverterRegistrationCvpb = + internal::GpuBufferStorageRegistry::Get() + .RegisterConverter( + ConvertToCvPixelBuffer); + +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + } // namespace mediapipe From 767cc2ee3cbec8472fcacedbd890def1d9c0b63f Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 
Nov 2022 18:31:54 -0800 Subject: [PATCH 033/137] More comments on gpu_buffer_storage This gives a basic explanation of the role of storages and views, and provides some details on how to implement a new storage type. PiperOrigin-RevId: 488812807 --- mediapipe/gpu/gpu_buffer_storage.h | 45 +++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index 0da5f236a..b15c9c843 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -22,13 +22,27 @@ struct types {}; template class ViewProvider; -// Interface for a backing storage for GpuBuffer. +// Generic interface for a backing storage for GpuBuffer. +// +// GpuBuffer is an opaque handle to an image. Its contents are handled by +// Storage classes. Application code does not interact with the storages +// directly; to access the data, it asks the GpuBuffer for a View, and in turn +// GpuBuffer looks for a storage that can provide that view. +// This architecture decouples application code from the underlying storage, +// making it possible to use platform-specific optimized storage systems, e.g. +// for zero-copy data sharing between CPU and GPU. +// +// Storage implementations should inherit from GpuBufferStorageImpl. See that +// class for details. class GpuBufferStorage { public: virtual ~GpuBufferStorage() = default; + + // Concrete storage types should override the following three accessors. virtual int width() const = 0; virtual int height() const = 0; virtual GpuBufferFormat format() const = 0; + // We can't use dynamic_cast since we want to support building without RTTI. // The public methods delegate to the type-erased private virtual method. template @@ -72,6 +86,8 @@ class GpuBufferStorageRegistry { return *registry; } + // Registers a storage type by automatically creating a factory for it. + // This is normally called by GpuBufferImpl. template RegistryToken Register() { return RegisterFactory( @@ -82,6 +98,7 @@ class GpuBufferStorageRegistry { }); } + // Registers a new factory for a storage type. template RegistryToken RegisterFactory(F&& factory) { if constexpr (kDisableRegistration) { @@ -90,6 +107,7 @@ class GpuBufferStorageRegistry { return Register(factory, Storage::GetProviderTypes()); } + // Registers a new converter from storage type StorageFrom to StorageTo. template RegistryToken RegisterConverter(F&& converter) { if constexpr (kDisableRegistration) { @@ -162,14 +180,26 @@ struct ForceStaticInstantiation { #endif // _MSC_VER }; -// T: storage type -// U...: ViewProvider +// Inherit from this class to define a new storage type. The storage type itself +// should be passed as the first template argument (CRTP), followed by one or +// more specializations of ViewProvider. +// +// Concrete storage types should implement the basic accessors from +// GpuBufferStorage, plus the view read/write getters for each ViewProvider they +// implement. This class handles the rest. +// +// Arguments: +// T: storage type +// U...: ViewProvider +// Example: +// class MyStorage : public GpuBufferStorageImpl< +// MyStorage, ViewProvider> template class GpuBufferStorageImpl : public GpuBufferStorage, public U... 
{ public: static const std::vector& GetProviderTypes() { - static std::vector kHashes{kTypeId...}; - return kHashes; + static std::vector kProviderIds{kTypeId...}; + return kProviderIds; } // Exposing this as a function allows dependent initializers to call this to @@ -180,10 +210,11 @@ class GpuBufferStorageImpl : public GpuBufferStorage, public U... { } private: - virtual const void* down_cast(TypeId to) const override { + // Allows a down_cast to any of the view provider types in U. + const void* down_cast(TypeId to) const final { return down_cast_impl(to, types{}); } - TypeId storage_type() const override { return kTypeId; } + TypeId storage_type() const final { return kTypeId; } const void* down_cast_impl(TypeId to, types<>) const { return nullptr; } template From 1c0a1d0aab81bdea369ef912f3f0739cfe84ad81 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:32:27 -0800 Subject: [PATCH 034/137] Remove shared_ptr member from GlTextureView This only exists to support GlTexture's GetFrame API. It can be moved into GlTexture. PiperOrigin-RevId: 488812896 --- mediapipe/gpu/gl_calculator_helper.h | 8 ++++++-- mediapipe/gpu/gl_calculator_helper_impl_common.cc | 15 +++------------ mediapipe/gpu/gl_texture_view.cc | 1 - mediapipe/gpu/gl_texture_view.h | 4 ---- 4 files changed, 9 insertions(+), 19 deletions(-) diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index e44523202..0a0cc16cb 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -201,9 +201,13 @@ class GlTexture { void Release() { view_ = std::make_shared(); } private: - explicit GlTexture(GlTextureView view) - : view_(std::make_shared(std::move(view))) {} + explicit GlTexture(GlTextureView view, GpuBuffer gpu_buffer) + : gpu_buffer_(std::move(gpu_buffer)), + view_(std::make_shared(std::move(view))) {} friend class GlCalculatorHelperImpl; + // We store the GpuBuffer to support GetFrame, and to ensure that the storage + // outlives the view. + GpuBuffer gpu_buffer_; std::shared_ptr view_; }; diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc index c5c028d4f..6311d8905 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ b/mediapipe/gpu/gl_calculator_helper_impl_common.cc @@ -101,7 +101,7 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, glBindTexture(view.target(), 0); } - return GlTexture(std::move(view)); + return GlTexture(std::move(view), gpu_buffer); } GlTexture GlCalculatorHelperImpl::CreateSourceTexture( @@ -143,7 +143,7 @@ template <> std::unique_ptr GlTexture::GetFrame() const { view_->DoneWriting(); std::shared_ptr view = - view_->gpu_buffer().GetReadView(); + gpu_buffer_.GetReadView(); auto copy = absl::make_unique(); copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); return copy; @@ -151,17 +151,8 @@ std::unique_ptr GlTexture::GetFrame() const { template <> std::unique_ptr GlTexture::GetFrame() const { - auto gpu_buffer = view_->gpu_buffer(); -#ifdef __EMSCRIPTEN__ - // When WebGL is used, the GL context may be spontaneously lost which can - // cause GpuBuffer allocations to fail. In that case, return a dummy buffer - // to allow processing of the current frame complete. 
- if (!gpu_buffer) { - return std::make_unique(); - } -#endif // __EMSCRIPTEN__ view_->DoneWriting(); - return absl::make_unique(gpu_buffer); + return absl::make_unique(gpu_buffer_); } GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( diff --git a/mediapipe/gpu/gl_texture_view.cc b/mediapipe/gpu/gl_texture_view.cc index 5d1862ddc..cae4039a4 100644 --- a/mediapipe/gpu/gl_texture_view.cc +++ b/mediapipe/gpu/gl_texture_view.cc @@ -7,7 +7,6 @@ void GlTextureView::Release() { if (detach_) detach_(*this); detach_ = nullptr; gl_context_ = nullptr; - gpu_buffer_ = nullptr; plane_ = 0; name_ = 0; width_ = 0; diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index 8b47d620b..d6734ed71 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -43,7 +43,6 @@ class GlTextureView { name_ = other.name_; width_ = other.width_; height_ = other.height_; - gpu_buffer_ = std::move(other.gpu_buffer_); plane_ = other.plane_; detach_ = std::exchange(other.detach_, nullptr); done_writing_ = std::exchange(other.done_writing_, nullptr); @@ -55,7 +54,6 @@ class GlTextureView { int height() const { return height_; } GLenum target() const { return target_; } GLuint name() const { return name_; } - const GpuBuffer& gpu_buffer() const { return *gpu_buffer_; } int plane() const { return plane_; } using DetachFn = std::function; @@ -74,7 +72,6 @@ class GlTextureView { name_(name), width_(width), height_(height), - gpu_buffer_(std::move(gpu_buffer)), plane_(plane), detach_(std::move(detach)), done_writing_(std::move(done_writing)) {} @@ -93,7 +90,6 @@ class GlTextureView { // Note: when scale is not 1, we still give the nominal size of the image. int width_ = 0; int height_ = 0; - std::shared_ptr gpu_buffer_; // using shared_ptr temporarily int plane_ = 0; DetachFn detach_; mutable DoneWritingFn done_writing_; From 13b4b825d74672d69a69d501dac4caf41e3ed098 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:33:04 -0800 Subject: [PATCH 035/137] Remove std::shared_ptr argument from GetRead/WriteView PiperOrigin-RevId: 488813004 --- mediapipe/gpu/gl_texture_buffer.cc | 23 +++++++--------- mediapipe/gpu/gl_texture_buffer.h | 2 -- mediapipe/gpu/gl_texture_view.h | 12 +++------ mediapipe/gpu/gpu_buffer.h | 6 ++--- .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 24 +++++++---------- .../gpu/gpu_buffer_storage_cv_pixel_buffer.h | 27 +++++++------------ .../gpu/gpu_buffer_storage_image_frame.h | 6 ++--- mediapipe/gpu/image_frame_view.h | 5 ++-- 8 files changed, 38 insertions(+), 67 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 4c2f15a8d..e57195a46 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -255,9 +255,8 @@ void GlTextureBuffer::WaitForConsumersOnGpu() { // precisely, on only one GL context. 
} -GlTextureView GlTextureBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer, - int plane) const { +GlTextureView GlTextureBuffer::GetReadView(internal::types, + int plane) const { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); @@ -269,13 +268,11 @@ GlTextureView GlTextureBuffer::GetReadView( DidRead(texture.gl_context()->CreateSyncToken()); }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, std::move(detach), - nullptr); + plane, std::move(detach), nullptr); } -GlTextureView GlTextureBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer, - int plane) { +GlTextureView GlTextureBuffer::GetWriteView(internal::types, + int plane) { auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); @@ -286,8 +283,7 @@ GlTextureView GlTextureBuffer::GetWriteView( GlTextureView::DoneWritingFn done_writing = [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), - std::move(gpu_buffer), plane, nullptr, - std::move(done_writing)); + plane, nullptr, std::move(done_writing)); } void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) { @@ -364,7 +360,7 @@ static std::shared_ptr ConvertToImageFrame( absl::make_unique(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); buf->GetProducerContext()->Run([buf, &output] { - auto view = buf->GetReadView(internal::types{}, nullptr, 0); + auto view = buf->GetReadView(internal::types{}, 0); ReadTexture(view, buf->format(), output->MutablePixelData(), output->PixelDataSize()); }); @@ -393,9 +389,8 @@ static std::shared_ptr ConvertToCvPixelBuffer( buf->width(), buf->height(), buf->format()); buf->GetProducerContext()->Run([buf, &output] { TempGlFramebuffer framebuffer; - auto src = buf->GetReadView(internal::types{}, nullptr, 0); - auto dst = - output->GetWriteView(internal::types{}, nullptr, 0); + auto src = buf->GetReadView(internal::types{}, 0); + auto dst = output->GetWriteView(internal::types{}, 0); CopyGlTexture(src, dst); glFlush(); }); diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index 1be24a86b..c7643fd1b 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -95,10 +95,8 @@ class GlTextureBuffer GpuBufferFormat format() const { return format_; } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; // If this texture is going to be used outside of the context that produced diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index d6734ed71..b8ead2708 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -65,8 +65,8 @@ class GlTextureView { friend class GpuBufferStorageCvPixelBuffer; friend class GpuBufferStorageAhwb; GlTextureView(GlContext* context, GLenum target, GLuint name, int width, - int height, std::shared_ptr gpu_buffer, int plane, - DetachFn detach, DoneWritingFn done_writing) + int height, int plane, DetachFn detach, + DoneWritingFn done_writing) : gl_context_(context), target_(target), name_(name), @@ -108,12 +108,8 @@ class ViewProvider { // the same view implement the same signature. 
// Note that we allow different views to have custom signatures, providing // additional view-specific arguments that may be needed. - virtual GlTextureView GetReadView(types, - std::shared_ptr gpu_buffer, - int plane) const = 0; - virtual GlTextureView GetWriteView(types, - std::shared_ptr gpu_buffer, - int plane) = 0; + virtual GlTextureView GetReadView(types, int plane) const = 0; + virtual GlTextureView GetWriteView(types, int plane) = 0; }; } // namespace internal diff --git a/mediapipe/gpu/gpu_buffer.h b/mediapipe/gpu/gpu_buffer.h index 45146a322..56507d92f 100644 --- a/mediapipe/gpu/gpu_buffer.h +++ b/mediapipe/gpu/gpu_buffer.h @@ -106,8 +106,7 @@ class GpuBuffer { template decltype(auto) GetReadView(Args... args) const { return GetViewProviderOrDie(false).GetReadView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + internal::types{}, std::forward(args)...); } // Gets a write view of the specified type. The arguments depend on the @@ -115,8 +114,7 @@ class GpuBuffer { template decltype(auto) GetWriteView(Args... args) { return GetViewProviderOrDie(true).GetWriteView( - internal::types{}, std::make_shared(*this), - std::forward(args)...); + internal::types{}, std::forward(args)...); } // Attempts to access an underlying storage object of the specified type. diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc index d68ac0db0..f3954a6e4 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc @@ -26,8 +26,7 @@ GpuBufferStorageCvPixelBuffer::GpuBufferStorageCvPixelBuffer( } GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( - std::shared_ptr gpu_buffer, int plane, - GlTextureView::DoneWritingFn done_writing) const { + int plane, GlTextureView::DoneWritingFn done_writing) const { CVReturn err; auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); @@ -60,33 +59,30 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetTexture( cv_texture.adopt(cv_texture_temp); return GlTextureView( gl_context.get(), CVOpenGLESTextureGetTarget(*cv_texture), - CVOpenGLESTextureGetName(*cv_texture), width(), height(), - std::move(gpu_buffer), plane, + CVOpenGLESTextureGetName(*cv_texture), width(), height(), plane, [cv_texture](mediapipe::GlTextureView&) { /* only retains cv_texture */ }, done_writing); #endif // TARGET_OS_OSX } GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer, - int plane) const { - return GetTexture(std::move(gpu_buffer), plane, nullptr); + internal::types, int plane) const { + return GetTexture(plane, nullptr); } GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer, - int plane) { - return GetTexture( - std::move(gpu_buffer), plane, - [this](const mediapipe::GlTextureView& view) { ViewDoneWriting(view); }); + internal::types, int plane) { + return GetTexture(plane, [this](const mediapipe::GlTextureView& view) { + ViewDoneWriting(view); + }); } std::shared_ptr GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, std::shared_ptr gpu_buffer) const { + internal::types) const { return CreateImageFrameForCVPixelBuffer(**this); } std::shared_ptr GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer) { + internal::types) { return CreateImageFrameForCVPixelBuffer(**this); } diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h 
b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h index e5bc5de43..a9389ab8a 100644 --- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h +++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h @@ -19,11 +19,9 @@ class ViewProvider { public: virtual ~ViewProvider() = default; virtual CFHolder GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const = 0; + internal::types) const = 0; virtual CFHolder GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) = 0; + internal::types) = 0; }; } // namespace internal @@ -50,37 +48,30 @@ class GpuBufferStorageCvPixelBuffer CVPixelBufferGetPixelFormatType(**this)); } GlTextureView GetReadView(internal::types, - std::shared_ptr gpu_buffer, int plane) const override; GlTextureView GetWriteView(internal::types, - std::shared_ptr gpu_buffer, int plane) override; std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override; + internal::types) const override; std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override; + internal::types) override; CFHolder GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override; + internal::types) const override; CFHolder GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override; + internal::types) override; private: - GlTextureView GetTexture(std::shared_ptr gpu_buffer, int plane, + GlTextureView GetTexture(int plane, GlTextureView::DoneWritingFn done_writing) const; void ViewDoneWriting(const GlTextureView& view); }; inline CFHolder GpuBufferStorageCvPixelBuffer::GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const { + internal::types) const { return *this; } inline CFHolder GpuBufferStorageCvPixelBuffer::GetWriteView( - internal::types, std::shared_ptr gpu_buffer) { + internal::types) { return *this; } diff --git a/mediapipe/gpu/gpu_buffer_storage_image_frame.h b/mediapipe/gpu/gpu_buffer_storage_image_frame.h index 2cea3445e..ab547b9ea 100644 --- a/mediapipe/gpu/gpu_buffer_storage_image_frame.h +++ b/mediapipe/gpu/gpu_buffer_storage_image_frame.h @@ -29,13 +29,11 @@ class GpuBufferStorageImageFrame std::shared_ptr image_frame() const { return image_frame_; } std::shared_ptr image_frame() { return image_frame_; } std::shared_ptr GetReadView( - internal::types, - std::shared_ptr gpu_buffer) const override { + internal::types) const override { return image_frame_; } std::shared_ptr GetWriteView( - internal::types, - std::shared_ptr gpu_buffer) override { + internal::types) override { return image_frame_; } diff --git a/mediapipe/gpu/image_frame_view.h b/mediapipe/gpu/image_frame_view.h index 2fc6f2495..b7e58a824 100644 --- a/mediapipe/gpu/image_frame_view.h +++ b/mediapipe/gpu/image_frame_view.h @@ -12,9 +12,8 @@ class ViewProvider { public: virtual ~ViewProvider() = default; virtual std::shared_ptr GetReadView( - types, std::shared_ptr gpu_buffer) const = 0; - virtual std::shared_ptr GetWriteView( - types, std::shared_ptr gpu_buffer) = 0; + types) const = 0; + virtual std::shared_ptr GetWriteView(types) = 0; }; } // namespace internal From a28ccb0964e327c5041c40cc26769275b46ce3b7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:33:32 -0800 Subject: [PATCH 036/137] Remove unnecessary forward declarations PiperOrigin-RevId: 488813066 --- mediapipe/gpu/gl_texture_view.h | 3 --- mediapipe/gpu/gpu_buffer_storage.h | 1 - 2 files changed, 4 deletions(-) diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index b8ead2708..8a257cf53 100644 --- 
a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -25,8 +25,6 @@ namespace mediapipe { class GlContext; -class GlTextureViewManager; -class GpuBuffer; class GlTextureView { public: @@ -60,7 +58,6 @@ class GlTextureView { using DoneWritingFn = std::function; private: - friend class GpuBuffer; friend class GlTextureBuffer; friend class GpuBufferStorageCvPixelBuffer; friend class GpuBufferStorageAhwb; diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index b15c9c843..55bb418cf 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -13,7 +13,6 @@ #include "mediapipe/gpu/gpu_buffer_format.h" namespace mediapipe { -class GpuBuffer; namespace internal { template From 8b319e963a4aa46db4f9c9d34c29bdf035f8f9a5 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:34:07 -0800 Subject: [PATCH 037/137] Add comment explaining ViewProvider This was only documented via examples (e.g. ViewProvider), but it's better to explain it properly in the header where the base case is defined. PiperOrigin-RevId: 488813144 --- mediapipe/gpu/gpu_buffer_storage.h | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mediapipe/gpu/gpu_buffer_storage.h b/mediapipe/gpu/gpu_buffer_storage.h index 55bb418cf..19661d930 100644 --- a/mediapipe/gpu/gpu_buffer_storage.h +++ b/mediapipe/gpu/gpu_buffer_storage.h @@ -18,6 +18,28 @@ namespace internal { template struct types {}; +// This template must be specialized for each view type V. Each specialization +// should define a pair of virtual methods called GetReadView and GetWriteView, +// whose first argument is a types tag object. The result type and optional +// further arguments will depend on the view type. +// +// Example: +// template <> +// class ViewProvider { +// public: +// virtual ~ViewProvider() = default; +// virtual MyView GetReadView(types) const = 0; +// virtual MyView GetWriteView(types) = 0; +// }; +// +// The additional arguments and result type are reflected in GpuBuffer's +// GetReadView and GetWriteView methods. +// +// Using a type tag for the first argument allows the methods to be overloaded, +// so that a single storage can implement provider methods for multiple views. +// Since these methods are not template methods, they can (and should) be +// virtual, which allows storage classes to override them, enforcing that all +// storages providing a given view type implement the same interface. template class ViewProvider; From 1979801a92ad5a4d12bf2dd1ae6611c39de3096a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:34:35 -0800 Subject: [PATCH 038/137] Remove GlCalculatorHelperImpl; merge with GlCalculatorHelper Originally, there were multiple implementations of GlCalculatorHelperImpl, depending on the platform and underlying GL APIs. These have all been refactored into other components, and the remaining code in this class is unified and much reduced in size. We can get rid of this implementation detail now. 
PiperOrigin-RevId: 488813220 --- mediapipe/gpu/BUILD | 2 - mediapipe/gpu/gl_calculator_helper.cc | 163 +++++++++++++---- mediapipe/gpu/gl_calculator_helper.h | 26 ++- mediapipe/gpu/gl_calculator_helper_impl.h | 82 --------- .../gpu/gl_calculator_helper_impl_common.cc | 169 ------------------ 5 files changed, 146 insertions(+), 296 deletions(-) delete mode 100644 mediapipe/gpu/gl_calculator_helper_impl.h delete mode 100644 mediapipe/gpu/gl_calculator_helper_impl_common.cc diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 196de3076..b0c1c22b2 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -749,11 +749,9 @@ cc_library( name = "gl_calculator_helper", srcs = [ "gl_calculator_helper.cc", - "gl_calculator_helper_impl_common.cc", ], hdrs = [ "gl_calculator_helper.h", - "gl_calculator_helper_impl.h", ], linkopts = select({ "//conditions:default": [], diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index ba1423977..7d317e0f1 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -20,18 +20,32 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_service.h" namespace mediapipe { -// The constructor and destructor need to be defined here so that -// std::unique_ptr can see the full definition of GlCalculatorHelperImpl. -// In the header, it is an incomplete type. GlCalculatorHelper::GlCalculatorHelper() {} -GlCalculatorHelper::~GlCalculatorHelper() {} +GlCalculatorHelper::~GlCalculatorHelper() { + if (!Initialized()) return; + RunInGlContext( + [this] { + if (framebuffer_) { + glDeleteFramebuffers(1, &framebuffer_); + framebuffer_ = 0; + } + return absl::OkStatus(); + }, + /*calculator_context=*/nullptr) + .IgnoreError(); +} + +void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc, + GpuResources* gpu_resources) { + gpu_resources_ = gpu_resources; + gl_context_ = gpu_resources_->gl_context(cc); +} absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { CHECK(cc); @@ -39,19 +53,16 @@ absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { RET_CHECK(gpu_service.IsAvailable()) << "GPU service not available. 
Did you forget to call " "GlCalculatorHelper::UpdateContract?"; - // TODO return error from impl_ (needs two-stage init) - impl_ = - absl::make_unique(cc, &gpu_service.GetObject()); + InitializeInternal(cc, &gpu_service.GetObject()); return absl::OkStatus(); } void GlCalculatorHelper::InitializeForTest(GpuSharedData* gpu_shared) { - impl_ = absl::make_unique( - nullptr, gpu_shared->gpu_resources.get()); + InitializeInternal(nullptr, gpu_shared->gpu_resources.get()); } void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { - impl_ = absl::make_unique(nullptr, gpu_resources); + InitializeInternal(nullptr, gpu_resources); } // static @@ -88,44 +99,109 @@ absl::Status GlCalculatorHelper::SetupInputSidePackets( return absl::OkStatus(); } +absl::Status GlCalculatorHelper::RunInGlContext( + std::function gl_func, + CalculatorContext* calculator_context) { + if (calculator_context) { + return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), + calculator_context->InputTimestamp()); + } else { + return gl_context_->Run(std::move(gl_func)); + } +} + absl::Status GlCalculatorHelper::RunInGlContext( std::function gl_func) { - if (!impl_) return absl::InternalError("helper not initialized"); + if (!Initialized()) return absl::InternalError("helper not initialized"); // TODO: Remove LegacyCalculatorSupport from MediaPipe OSS. auto calculator_context = LegacyCalculatorSupport::Scoped::current(); - return impl_->RunInGlContext(gl_func, calculator_context); + return RunInGlContext(gl_func, calculator_context); } -GLuint GlCalculatorHelper::framebuffer() const { return impl_->framebuffer(); } +GLuint GlCalculatorHelper::framebuffer() const { return framebuffer_; } + +void GlCalculatorHelper::CreateFramebuffer() { + // Our framebuffer will have a color attachment but no depth attachment, + // so it's important that the depth test be off. It is disabled by default, + // but we wanted to be explicit. + // TODO: move this to glBindFramebuffer? + glDisable(GL_DEPTH_TEST); + glGenFramebuffers(1, &framebuffer_); +} void GlCalculatorHelper::BindFramebuffer(const GlTexture& dst) { - return impl_->BindFramebuffer(dst); +#ifdef __ANDROID__ + // On (some?) Android devices, attaching a new texture to the frame buffer + // does not seem to detach the old one. As a result, using that texture + // for texturing can produce incorrect output. See b/32091368 for details. + // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 + // or glFramebufferTexture2D with a texture ID of 0. + glBindFramebuffer(GL_FRAMEBUFFER, 0); +#endif + if (!framebuffer_) { + CreateFramebuffer(); + } + glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); + glViewport(0, 0, dst.width(), dst.height()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), + dst.name(), 0); + +#ifndef NDEBUG + GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + if (status != GL_FRAMEBUFFER_COMPLETE) { + VLOG(2) << "incomplete framebuffer: " << status; + } +#endif } -GlTexture GlCalculatorHelper::CreateSourceTexture( - const GpuBuffer& pixel_buffer) { - return impl_->CreateSourceTexture(pixel_buffer); +GlTexture GlCalculatorHelper::MapGpuBuffer(const GpuBuffer& gpu_buffer, + GlTextureView view) { + if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { + // TODO: do the params need to be reset here?? 
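+    // The bind/unbind below only applies the context's standard sampling
+    // parameters for this buffer format and plane; the returned GlTexture
+    // keeps the GpuBuffer alive so the view outlives this call.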
+ glBindTexture(view.target(), view.name()); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + gpu_buffer.format(), view.plane(), GetGlVersion()); + gl_context_->SetStandardTextureParams(view.target(), + info.gl_internal_format); + glBindTexture(view.target(), 0); + } + + return GlTexture(std::move(view), gpu_buffer); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer) { + return CreateSourceTexture(gpu_buffer, 0); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& gpu_buffer, + int plane) { + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(plane)); } GlTexture GlCalculatorHelper::CreateSourceTexture( const ImageFrame& image_frame) { - return impl_->CreateSourceTexture(image_frame); -} - -GlTexture GlCalculatorHelper::CreateSourceTexture(const GpuBuffer& pixel_buffer, - int plane) { - return impl_->CreateSourceTexture(pixel_buffer, plane); + auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(0)); } GpuBuffer GlCalculatorHelper::GpuBufferWithImageFrame( std::shared_ptr image_frame) { - return impl_->GpuBufferWithImageFrame(std::move(image_frame)); + return GpuBuffer( + std::make_shared(std::move(image_frame))); } GpuBuffer GlCalculatorHelper::GpuBufferCopyingImageFrame( const ImageFrame& image_frame) { - return impl_->GpuBufferCopyingImageFrame(image_frame); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); + // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only + // deals with absl::Status in MediaPipe OSS. + CHECK_OK(maybe_buffer.status()); + return GpuBuffer(std::move(maybe_buffer).value()); +#else + return GpuBuffer(GlTextureBuffer::Create(image_frame)); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER } void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, @@ -136,23 +212,36 @@ void GlCalculatorHelper::GetGpuBufferDimensions(const GpuBuffer& pixel_buffer, *height = pixel_buffer.height(); } -GlTexture GlCalculatorHelper::CreateDestinationTexture(int output_width, - int output_height, +GlTexture GlCalculatorHelper::CreateDestinationTexture(int width, int height, GpuBufferFormat format) { - return impl_->CreateDestinationTexture(output_width, output_height, format); -} + if (!framebuffer_) { + CreateFramebuffer(); + } -GlContext& GlCalculatorHelper::GetGlContext() const { - return impl_->GetGlContext(); -} - -GlVersion GlCalculatorHelper::GetGlVersion() const { - return impl_->GetGlVersion(); + GpuBuffer gpu_buffer = + gpu_resources_->gpu_buffer_pool().GetBuffer(width, height, format); + return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView(0)); } GlTexture GlCalculatorHelper::CreateSourceTexture( const mediapipe::Image& image) { - return impl_->CreateSourceTexture(image.GetGpuBuffer()); + return CreateSourceTexture(image.GetGpuBuffer()); +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + view_->DoneWriting(); + std::shared_ptr view = + gpu_buffer_.GetReadView(); + auto copy = absl::make_unique(); + copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); + return copy; +} + +template <> +std::unique_ptr GlTexture::GetFrame() const { + view_->DoneWriting(); + return absl::make_unique(gpu_buffer_); } template <> diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index 0a0cc16cb..727be7826 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h 
@@ -33,7 +33,6 @@ namespace mediapipe { -class GlCalculatorHelperImpl; class GlTexture; class GpuResources; struct GpuSharedData; @@ -161,15 +160,30 @@ class GlCalculatorHelper { // TODO: do we need an unbind method too? void BindFramebuffer(const GlTexture& dst); - GlContext& GetGlContext() const; + GlContext& GetGlContext() const { return *gl_context_; } - GlVersion GetGlVersion() const; + GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } // Check if the calculator helper has been previously initialized. - bool Initialized() { return impl_ != nullptr; } + bool Initialized() { return gpu_resources_ != nullptr; } private: - std::unique_ptr impl_; + void InitializeInternal(CalculatorContext* cc, GpuResources* gpu_resources); + + absl::Status RunInGlContext(std::function gl_func, + CalculatorContext* calculator_context); + + // Makes a GpuBuffer accessible as a texture in the GL context. + GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); + + // Create the framebuffer for rendering. + void CreateFramebuffer(); + + std::shared_ptr gl_context_; + + GLuint framebuffer_ = 0; + + GpuResources* gpu_resources_ = nullptr; }; // Represents an OpenGL texture, and is a 'view' into the memory pool. @@ -204,7 +218,7 @@ class GlTexture { explicit GlTexture(GlTextureView view, GpuBuffer gpu_buffer) : gpu_buffer_(std::move(gpu_buffer)), view_(std::make_shared(std::move(view))) {} - friend class GlCalculatorHelperImpl; + friend class GlCalculatorHelper; // We store the GpuBuffer to support GetFrame, and to ensure that the storage // outlives the view. GpuBuffer gpu_buffer_; diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h deleted file mode 100644 index 72b3265fe..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ -#define MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ - -#include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" - -#ifdef __OBJC__ -#import -#import -#endif // __OBJC__ - -#ifdef __ANDROID__ -#include "mediapipe/gpu/gl_texture_buffer_pool.h" -#endif - -namespace mediapipe { - -// This class implements the GlCalculatorHelper for iOS and Android. -// See GlCalculatorHelper for details on these methods. -class GlCalculatorHelperImpl { - public: - explicit GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources); - ~GlCalculatorHelperImpl(); - - absl::Status RunInGlContext(std::function gl_func, - CalculatorContext* calculator_context); - - GlTexture CreateSourceTexture(const ImageFrame& image_frame); - GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer); - - // Note: multi-plane support is currently only available on iOS. 
- GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer, int plane); - - // Creates a framebuffer and returns the texture that it is bound to. - GlTexture CreateDestinationTexture(int output_width, int output_height, - GpuBufferFormat format); - - GpuBuffer GpuBufferWithImageFrame(std::shared_ptr image_frame); - GpuBuffer GpuBufferCopyingImageFrame(const ImageFrame& image_frame); - - GLuint framebuffer() const { return framebuffer_; } - void BindFramebuffer(const GlTexture& dst); - - GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } - - GlContext& GetGlContext() const; - - // For internal use. - static void ReadTexture(const GlTextureView& view, void* output, size_t size); - - private: - // Makes a GpuBuffer accessible as a texture in the GL context. - GlTexture MapGpuBuffer(const GpuBuffer& gpu_buffer, GlTextureView view); - - // Create the framebuffer for rendering. - void CreateFramebuffer(); - - std::shared_ptr gl_context_; - - GLuint framebuffer_ = 0; - - GpuResources& gpu_resources_; -}; - -} // namespace mediapipe - -#endif // MEDIAPIPE_GPU_GL_CALCULATOR_HELPER_IMPL_H_ diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc deleted file mode 100644 index 6311d8905..000000000 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/gpu/gl_calculator_helper_impl.h" -#include "mediapipe/gpu/gpu_buffer_format.h" -#include "mediapipe/gpu/gpu_shared_data_internal.h" -#include "mediapipe/gpu/image_frame_view.h" - -namespace mediapipe { - -GlCalculatorHelperImpl::GlCalculatorHelperImpl(CalculatorContext* cc, - GpuResources* gpu_resources) - : gpu_resources_(*gpu_resources) { - gl_context_ = gpu_resources_.gl_context(cc); -} - -GlCalculatorHelperImpl::~GlCalculatorHelperImpl() { - RunInGlContext( - [this] { - if (framebuffer_) { - glDeleteFramebuffers(1, &framebuffer_); - framebuffer_ = 0; - } - return absl::OkStatus(); - }, - /*calculator_context=*/nullptr) - .IgnoreError(); -} - -GlContext& GlCalculatorHelperImpl::GetGlContext() const { return *gl_context_; } - -absl::Status GlCalculatorHelperImpl::RunInGlContext( - std::function gl_func, - CalculatorContext* calculator_context) { - if (calculator_context) { - return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), - calculator_context->InputTimestamp()); - } else { - return gl_context_->Run(std::move(gl_func)); - } -} - -void GlCalculatorHelperImpl::CreateFramebuffer() { - // Our framebuffer will have a color attachment but no depth attachment, - // so it's important that the depth test be off. It is disabled by default, - // but we wanted to be explicit. - // TODO: move this to glBindFramebuffer? 
- glDisable(GL_DEPTH_TEST); - glGenFramebuffers(1, &framebuffer_); -} - -void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) { -#ifdef __ANDROID__ - // On (some?) Android devices, attaching a new texture to the frame buffer - // does not seem to detach the old one. As a result, using that texture - // for texturing can produce incorrect output. See b/32091368 for details. - // To fix this, we have to call either glBindFramebuffer with a FBO id of 0 - // or glFramebufferTexture2D with a texture ID of 0. - glBindFramebuffer(GL_FRAMEBUFFER, 0); -#endif - if (!framebuffer_) { - CreateFramebuffer(); - } - glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - glViewport(0, 0, dst.width(), dst.height()); - glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, dst.target(), - dst.name(), 0); - -#ifndef NDEBUG - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); - if (status != GL_FRAMEBUFFER_COMPLETE) { - VLOG(2) << "incomplete framebuffer: " << status; - } -#endif -} - -GlTexture GlCalculatorHelperImpl::MapGpuBuffer(const GpuBuffer& gpu_buffer, - GlTextureView view) { - if (gpu_buffer.format() != GpuBufferFormat::kUnknown) { - // TODO: do the params need to be reset here?? - glBindTexture(view.target(), view.name()); - GlTextureInfo info = GlTextureInfoForGpuBufferFormat( - gpu_buffer.format(), view.plane(), GetGlVersion()); - gl_context_->SetStandardTextureParams(view.target(), - info.gl_internal_format); - glBindTexture(view.target(), 0); - } - - return GlTexture(std::move(view), gpu_buffer); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer) { - return CreateSourceTexture(gpu_buffer, 0); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const GpuBuffer& gpu_buffer, int plane) { - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(plane)); -} - -GlTexture GlCalculatorHelperImpl::CreateSourceTexture( - const ImageFrame& image_frame) { - auto gpu_buffer = GpuBufferCopyingImageFrame(image_frame); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetReadView(0)); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferWithImageFrame( - std::shared_ptr image_frame) { - return GpuBuffer( - std::make_shared(std::move(image_frame))); -} - -GpuBuffer GlCalculatorHelperImpl::GpuBufferCopyingImageFrame( - const ImageFrame& image_frame) { -#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - auto maybe_buffer = CreateCVPixelBufferCopyingImageFrame(image_frame); - // Converts absl::StatusOr to absl::Status since CHECK_OK() currently only - // deals with absl::Status in MediaPipe OSS. 
- CHECK_OK(maybe_buffer.status()); - return GpuBuffer(std::move(maybe_buffer).value()); -#else - return GpuBuffer(GlTextureBuffer::Create(image_frame)); -#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER -} - -template <> -std::unique_ptr GlTexture::GetFrame() const { - view_->DoneWriting(); - std::shared_ptr view = - gpu_buffer_.GetReadView(); - auto copy = absl::make_unique(); - copy->CopyFrom(*view, ImageFrame::kDefaultAlignmentBoundary); - return copy; -} - -template <> -std::unique_ptr GlTexture::GetFrame() const { - view_->DoneWriting(); - return absl::make_unique(gpu_buffer_); -} - -GlTexture GlCalculatorHelperImpl::CreateDestinationTexture( - int width, int height, GpuBufferFormat format) { - if (!framebuffer_) { - CreateFramebuffer(); - } - - GpuBuffer gpu_buffer = - gpu_resources_.gpu_buffer_pool().GetBuffer(width, height, format); - return MapGpuBuffer(gpu_buffer, gpu_buffer.GetWriteView(0)); -} - -} // namespace mediapipe From 63e20896391dda07baa25733cc023db233945f8b Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:35:04 -0800 Subject: [PATCH 039/137] Deprecate a bunch of old stuff in GlCalculatorHelper PiperOrigin-RevId: 488813296 --- mediapipe/gpu/BUILD | 1 + mediapipe/gpu/gl_calculator_helper.h | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index b0c1c22b2..4fb59f1b5 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -784,6 +784,7 @@ cc_library( ":shader_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_cc_proto", + "@com_google_absl//absl/base:core_headers", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_contract", diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index 727be7826..af897bbe9 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -17,6 +17,7 @@ #include +#include "absl/base/attributes.h" #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" @@ -61,6 +62,7 @@ class GlCalculatorHelper { // Can be used to initialize the helper outside of a calculator. Useful for // testing. void InitializeForTest(GpuResources* gpu_resources); + ABSL_DEPRECATED("Use InitializeForTest(GpuResources)") void InitializeForTest(GpuSharedData* gpu_shared); // This method can be called from GetContract to set up the needed GPU @@ -69,6 +71,7 @@ class GlCalculatorHelper { // This method can be called from FillExpectations to set the correct types // for the shared GL input side packet(s). + ABSL_DEPRECATED("Use UpdateContract") static absl::Status SetupInputSidePackets(PacketTypeSet* input_side_packets); // Execute the provided function within the helper's GL context. On some @@ -235,12 +238,14 @@ class GlTexture { // it is better to keep const-safety and accept having two versions of the // same thing. template +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(const T& collection, const std::string& tag, int index) -> decltype(collection.Tag(tag)) { return collection.UsesTags() ? collection.Tag(tag) : collection.Index(index); } template +ABSL_DEPRECATED("Only for legacy calculators") auto TagOrIndex(T* collection, const std::string& tag, int index) -> decltype(collection->Tag(tag)) { return collection->UsesTags() ? 
collection->Tag(tag) @@ -248,12 +253,14 @@ auto TagOrIndex(T* collection, const std::string& tag, int index) } template +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(const T& collection, const std::string& tag, int index) { return collection.UsesTags() ? collection.HasTag(tag) : index < collection.NumEntries(); } template +ABSL_DEPRECATED("Only for legacy calculators") bool HasTagOrIndex(T* collection, const std::string& tag, int index) { return collection->UsesTags() ? collection->HasTag(tag) : index < collection->NumEntries(); From febfc2029b38411a1835175d0bf3a647684475d9 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 15 Nov 2022 18:35:32 -0800 Subject: [PATCH 040/137] Annotate plane argument PiperOrigin-RevId: 488813363 --- mediapipe/gpu/gl_texture_buffer.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index e57195a46..09703d89d 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -360,7 +360,7 @@ static std::shared_ptr ConvertToImageFrame( absl::make_unique(image_format, buf->width(), buf->height(), ImageFrame::kGlDefaultAlignmentBoundary); buf->GetProducerContext()->Run([buf, &output] { - auto view = buf->GetReadView(internal::types{}, 0); + auto view = buf->GetReadView(internal::types{}, /*plane=*/0); ReadTexture(view, buf->format(), output->MutablePixelData(), output->PixelDataSize()); }); @@ -389,8 +389,9 @@ static std::shared_ptr ConvertToCvPixelBuffer( buf->width(), buf->height(), buf->format()); buf->GetProducerContext()->Run([buf, &output] { TempGlFramebuffer framebuffer; - auto src = buf->GetReadView(internal::types{}, 0); - auto dst = output->GetWriteView(internal::types{}, 0); + auto src = buf->GetReadView(internal::types{}, /*plane=*/0); + auto dst = + output->GetWriteView(internal::types{}, /*plane=*/0); CopyGlTexture(src, dst); glFlush(); }); From f7aef677fc1830af167a4ae989b8ca5abcac485a Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 15 Nov 2022 18:59:06 -0800 Subject: [PATCH 041/137] Add running mode to all vision tasks PiperOrigin-RevId: 488816785 --- mediapipe/tasks/web/vision/core/BUILD | 25 +++++-- ...nning_mode.ts => vision_task_options.d.ts} | 27 ++++---- .../web/vision/core/vision_task_runner.ts | 66 +++++++++++++++++++ .../tasks/web/vision/gesture_recognizer/BUILD | 5 +- .../gesture_recognizer/gesture_recognizer.ts | 48 +++++++++----- .../gesture_recognizer_options.d.ts | 7 +- .../tasks/web/vision/hand_landmarker/BUILD | 5 +- .../vision/hand_landmarker/hand_landmarker.ts | 47 ++++++++----- .../hand_landmarker_options.d.ts | 7 +- .../tasks/web/vision/image_classifier/BUILD | 5 +- .../image_classifier/image_classifier.ts | 51 +++++++++----- .../image_classifier_options.d.ts | 7 +- .../tasks/web/vision/image_embedder/BUILD | 8 +-- .../vision/image_embedder/image_embedder.ts | 49 ++++++-------- .../image_embedder_options.d.ts | 15 +---- .../tasks/web/vision/object_detector/BUILD | 5 +- .../vision/object_detector/object_detector.ts | 45 +++++++++---- .../object_detector_options.d.ts | 7 +- 18 files changed, 281 insertions(+), 148 deletions(-) rename mediapipe/tasks/web/vision/core/{running_mode.ts => vision_task_options.d.ts} (58%) create mode 100644 mediapipe/tasks/web/vision/core/vision_task_runner.ts diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 7ab822b7c..8c405ae6e 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ 
b/mediapipe/tasks/web/vision/core/BUILD @@ -1,11 +1,26 @@ # This package contains options shared by all MediaPipe Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_library( - name = "running_mode", - srcs = ["running_mode.ts"], - deps = ["//mediapipe/tasks/cc/core/proto:base_options_jspb_proto"], +mediapipe_ts_declaration( + name = "vision_task_options", + srcs = ["vision_task_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + ], +) + +mediapipe_ts_library( + name = "vision_task_runner", + srcs = ["vision_task_runner.ts"], + deps = [ + ":vision_task_options", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + ], ) diff --git a/mediapipe/tasks/web/vision/core/running_mode.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts similarity index 58% rename from mediapipe/tasks/web/vision/core/running_mode.ts rename to mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 1e9b1b9a7..8b9562e46 100644 --- a/mediapipe/tasks/web/vision/core/running_mode.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -14,23 +14,26 @@ * limitations under the License. */ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {BaseOptions} from '../../../../tasks/web/core/base_options'; /** - * The running mode of a task. + * The two running modes of a video task. * 1) The image mode for processing single image inputs. * 2) The video mode for processing decoded frames of a video. */ export type RunningMode = 'image'|'video'; -/** Configues the `useStreamMode` option . */ -export function configureRunningMode( - options: {runningMode?: RunningMode}, - proto?: BaseOptionsProto): BaseOptionsProto { - proto = proto ?? new BaseOptionsProto(); - if ('runningMode' in options) { - const useStreamMode = options.runningMode === 'video'; - proto.setUseStreamMode(useStreamMode); - } - return proto; + +/** The options for configuring a MediaPipe vision task. */ +export declare interface VisionTaskOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; + + /** + * The running mode of the task. Default to the image mode. + * Vision tasks have two running modes: + * 1) The image mode for processing single image inputs. + * 2) The video mode for processing decoded frames of a video. + */ + runningMode?: RunningMode; } diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts new file mode 100644 index 000000000..372ce9ba7 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -0,0 +1,66 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
+import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
+import {TaskRunner} from '../../../../tasks/web/core/task_runner';
+import {ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
+
+import {VisionTaskOptions} from './vision_task_options';
+
+/** Base class for all MediaPipe Vision Tasks. */
+export abstract class VisionTaskRunner<T> extends TaskRunner {
+  protected abstract baseOptions?: BaseOptionsProto|undefined;
+
+  /** Configures the shared options of a vision task. */
+  async setOptions(options: VisionTaskOptions): Promise<void> {
+    this.baseOptions = this.baseOptions ?? new BaseOptionsProto();
+    if (options.baseOptions) {
+      this.baseOptions = await convertBaseOptionsToProto(
+          options.baseOptions, this.baseOptions);
+    }
+    if ('runningMode' in options) {
+      const useStreamMode =
+          !!options.runningMode && options.runningMode !== 'image';
+      this.baseOptions.setUseStreamMode(useStreamMode);
+    }
+  }
+
+  /** Sends an image packet to the graph and awaits results. */
+  protected abstract process(input: ImageSource, timestamp: number): T;
+
+  /** Sends a single image to the graph and awaits results. */
+  protected processImageData(image: ImageSource): T {
+    if (!!this.baseOptions?.getUseStreamMode()) {
+      throw new Error(
+          'Task is not initialized with image mode. ' +
+          '\'runningMode\' must be set to \'image\'.');
+    }
+    return this.process(image, performance.now());
+  }
+
+  /** Sends a single video frame to the graph and awaits results. */
+  protected processVideoData(imageFrame: ImageSource, timestamp: number): T {
+    if (!this.baseOptions?.getUseStreamMode()) {
+      throw new Error(
+          'Task is not initialized with video mode. 
' + + '\'runningMode\' must be set to \'video\'.'); + } + return this.process(imageFrame, timestamp); + } +} + + diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index d67974a16..f2b668239 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -19,6 +19,7 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_jspb_proto", @@ -27,11 +28,10 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -47,5 +47,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 6c8072ff5..8e745534e 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -19,6 +19,7 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {GestureClassifierGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options_pb'; import {GestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options_pb'; import {HandGestureRecognizerGraphOptions} from '../../../../tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options_pb'; @@ -27,10 +28,9 @@ import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landm import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} 
from '../../../../tasks/web/components/processors/classifier_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -64,7 +64,8 @@ FULL_IMAGE_RECT.setWidth(1); FULL_IMAGE_RECT.setHeight(1); /** Performs hand gesture recognition on images. */ -export class GestureRecognizer extends TaskRunner { +export class GestureRecognizer extends + VisionTaskRunner { private gestures: Category[][] = []; private landmarks: Landmark[][] = []; private worldLandmarks: Landmark[][] = []; @@ -156,10 +157,14 @@ export class GestureRecognizer extends TaskRunner { this.handGestureRecognizerGraphOptions); this.initDefaults(); + } - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. - this.setAutoRenderToScreen(false); + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); } /** @@ -171,12 +176,8 @@ export class GestureRecognizer extends TaskRunner { * * @param options The options for the gesture recognizer. */ - async setOptions(options: GestureRecognizerOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: GestureRecognizerOptions): Promise { + await super.setOptions(options); if ('numHands' in options) { this.handDetectorGraphOptions.setNumHands( @@ -233,12 +234,27 @@ export class GestureRecognizer extends TaskRunner { /** * Performs gesture recognition on the provided single image and waits * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image A single image to process. * @return The detected gestures. */ - recognize(imageSource: ImageSource, timestamp: number = performance.now()): + recognize(image: ImageSource): GestureRecognizerResult { + return this.processImageData(image); + } + + /** + * Performs gesture recognition on the provided video frame and waits + * synchronously for the response. + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The detected gestures. + */ + recognizeForVideo(videoFrame: ImageSource, timestamp: number): + GestureRecognizerResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the gesture recognition and blocks on the response. 
*/ + protected override process(imageSource: ImageSource, timestamp: number): GestureRecognizerResult { this.gestures = []; this.landmarks = []; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts index 45601a74c..dd8fc9548 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts @@ -14,14 +14,11 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Gesture Recognizer Task */ -export declare interface GestureRecognizerOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface GestureRecognizerOptions extends VisionTaskOptions { /** * The maximum number of hands can be detected by the GestureRecognizer. * Defaults to 1. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 25c70e0a5..36f1d7eb7 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -19,14 +19,14 @@ mediapipe_ts_library( "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:rect_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -41,5 +41,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index af10305b2..0aba5c82c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -19,14 +19,14 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationList} from '../../../../framework/formats/classification_pb'; import {LandmarkList, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; import {NormalizedRect} from '../../../../framework/formats/rect_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detector/proto/hand_detector_graph_options_pb'; import {HandLandmarkerGraphOptions} from 
'../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -58,7 +58,7 @@ FULL_IMAGE_RECT.setWidth(1); FULL_IMAGE_RECT.setHeight(1); /** Performs hand landmarks detection on images. */ -export class HandLandmarker extends TaskRunner { +export class HandLandmarker extends VisionTaskRunner { private landmarks: Landmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -138,10 +138,14 @@ export class HandLandmarker extends TaskRunner { this.options.setHandDetectorGraphOptions(this.handDetectorGraphOptions); this.initDefaults(); + } - // Disables the automatic render-to-screen code, which allows for pure - // CPU processing. - this.setAutoRenderToScreen(false); + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); } /** @@ -153,12 +157,8 @@ export class HandLandmarker extends TaskRunner { * * @param options The options for the hand landmarker. */ - async setOptions(options: HandLandmarkerOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: HandLandmarkerOptions): Promise { + await super.setOptions(options); // Configure hand detector options. if ('numHands' in options) { @@ -186,12 +186,27 @@ export class HandLandmarker extends TaskRunner { /** * Performs hand landmarks detection on the provided single image and waits * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. * @return The detected hand landmarks. */ - detect(imageSource: ImageSource, timestamp: number = performance.now()): + detect(image: ImageSource): HandLandmarkerResult { + return this.processImageData(image); + } + + /** + * Performs hand landmarks detection on the provided video frame and waits + * synchronously for the response. + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The detected hand landmarks. + */ + detectForVideo(videoFrame: ImageSource, timestamp: number): + HandLandmarkerResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the hand landmarker graph and blocks on the response. 
*/ + protected override process(imageSource: ImageSource, timestamp: number): HandLandmarkerResult { this.landmarks = []; this.worldLandmarks = []; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts index 53ad9440a..fe79b7089 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_options.d.ts @@ -14,13 +14,10 @@ * limitations under the License. */ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe HandLandmarker Task */ -export declare interface HandLandmarkerOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export declare interface HandLandmarkerOptions extends VisionTaskOptions { /** * The maximum number of hands can be detected by the HandLandmarker. * Defaults to 1. diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index 8506f3574..e7e830332 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -16,15 +16,15 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -39,5 +39,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 5d60e4a21..0011e9c55 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -17,12 +17,12 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from 
'../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -42,7 +42,7 @@ export {ImageSource}; // Used in the public API // tslint:disable:jspb-use-builder-pattern /** Performs classification on images. */ -export class ImageClassifier extends TaskRunner { +export class ImageClassifier extends VisionTaskRunner { private classificationResult: ImageClassifierResult = {classifications: []}; private readonly options = new ImageClassifierGraphOptions(); @@ -105,6 +105,14 @@ export class ImageClassifier extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the image classifier. * @@ -114,28 +122,39 @@ export class ImageClassifier extends TaskRunner { * * @param options The options for the image classifier. */ - async setOptions(options: ImageClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: ImageClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } /** - * Performs image classification on the provided image and waits synchronously - * for the response. + * Performs image classification on the provided single image and waits + * synchronously for the response. * - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. * @return The classification result of the image */ - classify(imageSource: ImageSource, timestamp?: number): + classify(image: ImageSource): ImageClassifierResult { + return this.processImageData(image); + } + + /** + * Performs image classification on the provided video frame and waits + * synchronously for the response. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The classification result of the image + */ + classifyForVideo(videoFrame: ImageSource, timestamp: number): + ImageClassifierResult { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the image classification graph and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): ImageClassifierResult { // Get classification result by running our MediaPipe graph. 
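    // The synchronous graph run below resets and repopulates
    // `classificationResult` before process() returns it to the caller.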
     this.classificationResult = {classifications: []};
diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts
index a5f5c2386..c1141d28f 100644
--- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts
+++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts
@@ -14,4 +14,9 @@
  * limitations under the License.
  */

-export {ClassifierOptions as ImageClassifierOptions} from '../../../../tasks/web/core/classifier_options';
+import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options';
+import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
+
+/** Options to configure the image classifier task. */
+export declare interface ImageClassifierOptions extends ClassifierOptions,
+                                                        VisionTaskOptions {}
diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index 13ff2e4d6..ce1c25700 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -16,15 +16,15 @@ mediapipe_ts_library(
         "//mediapipe/framework:calculator_jspb_proto",
         "//mediapipe/framework:calculator_options_jspb_proto",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto",
+        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_jspb_proto",
         "//mediapipe/tasks/web/components/containers:embedding_result",
-        "//mediapipe/tasks/web/components/processors:base_options",
         "//mediapipe/tasks/web/components/processors:embedder_options",
         "//mediapipe/tasks/web/components/processors:embedder_result",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
-        "//mediapipe/tasks/web/core:task_runner",
-        "//mediapipe/tasks/web/vision/core:running_mode",
+        "//mediapipe/tasks/web/vision/core:vision_task_options",
+        "//mediapipe/tasks/web/vision/core:vision_task_runner",
         "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts",
     ],
 )
@@ -39,6 +39,6 @@ mediapipe_ts_declaration(
         "//mediapipe/tasks/web/components/containers:embedding_result",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
-        "//mediapipe/tasks/web/vision/core:running_mode",
+        "//mediapipe/tasks/web/vision/core:vision_task_options",
     ],
 )
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index 91d9b5119..d17bc72fa 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -17,13 +17,12 @@
 import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
+import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb';
-import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
 import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
 import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
-import {TaskRunner} from '../../../../tasks/web/core/task_runner';
 import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
-import {configureRunningMode} from '../../../../tasks/web/vision/core/running_mode';
+import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib';
 // Placeholder for internal dependency on trusted resource url
@@ -43,7 +42,7 @@ export * from './image_embedder_result';
 export {ImageSource};  // Used in the public API

 /** Performs embedding extraction on images. */
-export class ImageEmbedder extends TaskRunner {
+export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
   private readonly options = new ImageEmbedderGraphOptions();
   private embeddings: ImageEmbedderResult = {embeddings: []};
@@ -105,6 +104,14 @@ export class ImageEmbedder extends TaskRunner {
         wasmLoaderOptions, new Uint8Array(graphData));
   }

+  protected override get baseOptions(): BaseOptionsProto|undefined {
+    return this.options.getBaseOptions();
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto|undefined) {
+    this.options.setBaseOptions(proto);
+  }
+
   /**
    * Sets new options for the image embedder.
    *
@@ -114,24 +121,16 @@ export class ImageEmbedder extends TaskRunner {
    *
    * @param options The options for the image embedder.
    */
-  async setOptions(options: ImageEmbedderOptions): Promise<void> {
-    let baseOptionsProto = this.options.getBaseOptions();
-    if (options.baseOptions) {
-      baseOptionsProto = await convertBaseOptionsToProto(
-          options.baseOptions, baseOptionsProto);
-    }
-    baseOptionsProto = configureRunningMode(options, baseOptionsProto);
-    this.options.setBaseOptions(baseOptionsProto);
-
+  override async setOptions(options: ImageEmbedderOptions): Promise<void> {
+    await super.setOptions(options);
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
-
     this.refreshGraph();
   }

   /**
-   * Performs embedding extraction on the provided image and waits synchronously
-   * for the response.
+   * Performs embedding extraction on the provided single image and waits
+   * synchronously for the response.
    *
    * Only use this method when the `useStreamMode` option is not set or
    * explicitly set to `false`.
@@ -140,12 +139,7 @@
    * @return The classification result of the image
    */
   embed(image: ImageSource): ImageEmbedderResult {
-    if (!!this.options.getBaseOptions()?.getUseStreamMode()) {
-      throw new Error(
-          'Task is not initialized with image mode. ' +
-          '\'runningMode\' must be set to \'image\'.');
-    }
-    return this.performEmbeddingExtraction(image, performance.now());
+    return this.processImageData(image);
   }

   /**
@@ -160,16 +154,11 @@
    */
   embedForVideo(imageFrame: ImageSource, timestamp: number):
       ImageEmbedderResult {
-    if (!this.options.getBaseOptions()?.getUseStreamMode()) {
-      throw new Error(
-          'Task is not initialized with video mode. ' +
-          '\'runningMode\' must be set to \'video\' or \'live_stream\'.');
-    }
-    return this.performEmbeddingExtraction(imageFrame, timestamp);
+    return this.processVideoData(imageFrame, timestamp);
   }

-  /** Runs the embedding extractio and blocks on the response. */
-  private performEmbeddingExtraction(image: ImageSource, timestamp: number):
+  /** Runs the embedding extraction and blocks on the response. 
*/ + protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { // Get embeddings by running our MediaPipe graph. this.addGpuBufferAsImageToStream( diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts index 4d795d0d8..10000825c 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts @@ -15,17 +15,8 @@ */ import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; -import {RunningMode} from '../../../../tasks/web/vision/core/running_mode'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** The options for configuring a MediaPipe image embedder task. */ -export declare interface ImageEmbedderOptions extends EmbedderOptions { - /** - * The running mode of the task. Default to the image mode. - * Image embedder has three running modes: - * 1) The image mode for embedding image on single image inputs. - * 2) The video mode for embedding image on the decoded frames of a video. - * 3) The live stream mode for embedding image on the live stream of input - * data, such as from camera. - */ - runningMode?: RunningMode; -} +export declare interface ImageEmbedderOptions extends EmbedderOptions, + VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index a74dc9211..0975a9fd4 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -17,11 +17,11 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:detection_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) @@ -35,5 +35,6 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e17a42020..e6cbd8627 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -17,10 +17,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from 
'../../../../tasks/web/core/wasm_loader_options'; +import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; // Placeholder for internal dependency on trusted resource url @@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API // tslint:disable:jspb-use-builder-pattern /** Performs object detection on images. */ -export class ObjectDetector extends TaskRunner { +export class ObjectDetector extends VisionTaskRunner { private detections: Detection[] = []; private readonly options = new ObjectDetectorOptionsProto(); @@ -103,6 +103,14 @@ export class ObjectDetector extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the object detector. * @@ -112,12 +120,8 @@ export class ObjectDetector extends TaskRunner { * * @param options The options for the object detector. */ - async setOptions(options: ObjectDetectorOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } + override async setOptions(options: ObjectDetectorOptions): Promise { + await super.setOptions(options); // Note that we have to support both JSPB and ProtobufJS, hence we // have to expliclity clear the values instead of setting them to @@ -158,12 +162,27 @@ export class ObjectDetector extends TaskRunner { /** * Performs object detection on the provided single image and waits * synchronously for the response. - * @param imageSource An image source to process. - * @param timestamp The timestamp of the current frame, in ms. If not - * provided, defaults to `performance.now()`. + * @param image An image to process. * @return The list of detected objects */ - detect(imageSource: ImageSource, timestamp?: number): Detection[] { + detect(image: ImageSource): Detection[] { + return this.processImageData(image); + } + + /** + * Performs object detection on the provided video frame and waits + * synchronously for the response. + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @return The list of detected objects + */ + detectForVideo(videoFrame: ImageSource, timestamp: number): Detection[] { + return this.processVideoData(videoFrame, timestamp); + } + + /** Runs the object detector graph and blocks on the response. */ + protected override process(imageSource: ImageSource, timestamp: number): Detection[] { // Get detections by running our MediaPipe graph. this.detections = []; this.addGpuBufferAsImageToStream( diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts index eec12cf17..1d20ce1e2 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -14,13 +14,10 @@ * limitations under the License.
*/ -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; - +export interface ObjectDetectorOptions extends VisionTaskOptions { /** * The locale to use for display names specified through the TFLite Model * Metadata, if any. Defaults to English.
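The patch above replaces the single timestamp-overloaded `detect()` call with a dedicated image call plus `detectForVideo()`, both routed through the shared `VisionTaskRunner`. A minimal usage sketch of the new calling convention follows; it assumes a detector instance has already been constructed (the factory method is outside this diff), that the package name matches the `@mediapipe/tasks-vision` template used later in this series, and that running modes use the lowercase strings quoted in the error messages above.

```ts
// Usage sketch under the assumptions stated above; not part of the patch.
import {ObjectDetector} from '@mediapipe/tasks-vision';

async function detectInImageThenVideo(
    detector: ObjectDetector, photo: HTMLImageElement,
    video: HTMLVideoElement): Promise<void> {
  // Image mode (the default): one synchronous call per still image.
  const detections = detector.detect(photo);
  console.log(`Detected ${detections.length} objects in the photo.`);

  // Video mode: the caller now passes an explicit per-frame timestamp
  // instead of the old optional `timestamp` argument on detect().
  await detector.setOptions({runningMode: 'video'});
  const processFrame = () => {
    const frameDetections = detector.detectForVideo(video, performance.now());
    console.log(`Detected ${frameDetections.length} objects in this frame.`);
    requestAnimationFrame(processFrame);
  };
  requestAnimationFrame(processFrame);
}
```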
- sha256 = "2fc7e279966a7a9e15fc869223793e390791fc61fdc0062f9bc7d0eef6be98a2", - urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668124189331326"], + sha256 = "ee121d85979de1b86126faabb0a0f4d2e4039c3e33e2cd687db50571001b24d0", + urls = ["https://storage.googleapis.com/mediapipe-assets/canned_gesture_classifier.tflite?generation=1668550473107417"], ) http_file( @@ -294,8 +294,8 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_tflite", - sha256 = "54abe78de1d1cd5e3cdaa0dab01db18e3ec7e09a76e7c3b5fa278572f7a60977", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668124192126494"], + sha256 = "927e4f6cbe6451da6b4fd1485e2576a6f8dbd95062666661cbd9dea893c41d01", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder.tflite?generation=1668550476472972"], ) http_file( @@ -990,14 +990,14 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_keras_metadata_pb", - sha256 = "24268b69429be4e307f9ab099ba20d1de7c40e4191a53f6a92dcbbd97a7047d3", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668124196996131"], + sha256 = "c76b856101e2284293a5e5963b7c445e407a0b3e56ec63eb78f64d883e51e3aa", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/keras_metadata.pb?generation=1668550482128410"], ) http_file( name = "com_google_mediapipe_gesture_embedder_saved_model_pb", - sha256 = "f3a2870ba3ef537a4f6a5889ffc5b7061ad98f9fd96ec431a62116892f100659", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668124199460071"], + sha256 = "0082d37c5b85487fbf553e00a63f640945faf3da2d561a5f5a24c3194fecda6a", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/saved_model.pb?generation=1668550484904822"], ) http_file( @@ -1038,12 +1038,12 @@ def external_files(): http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_data-00000-of-00001", - sha256 = "9fdb750c4bac67afb9c0f61916510930b496cc47e7f89449aee2bec6b6ed0af8", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668124201918980"], + sha256 = "c156c9654c9ffb1091bb9f06c71080bd1e428586276d3f39c33fbab27fe0522d", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.data-00000-of-00001?generation=1668550487965052"], ) http_file( name = "com_google_mediapipe_gesture_embedder_variables_variables_index", - sha256 = "3ccbcee9488fec4627d496abd9837997276b32b839a4d0ae434bd806fe380b86", - urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668124204353848"], + sha256 = "76ea482b8da6bdb3d65d3b2ea989c1699c9fa0d6df0cb6d80863d1dc6fe7c4bd", + urls = ["https://storage.googleapis.com/mediapipe-assets/gesture_embedder/variables/variables.index?generation=1668550490691823"], ) From 74474d859e0891fc97b4038b7b8ecb9420c4b522 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 13:58:21 -0800 Subject: [PATCH 045/137] Update image_classifier demo with new ImageClassifierOption changes PiperOrigin-RevId: 489031381 --- .../vision/image_classifier/image_classifier_demo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py 
b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py index 5832ea53a..f382e28aa 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -61,12 +61,14 @@ def run(data_dir: str, export_dir: str, data = image_classifier.Dataset.from_folder(data_dir) train_data, rest_data = data.split(0.8) validation_data, test_data = rest_data.split(0.5) - + model_options = image_classifier.ImageClassifierOptions( + supported_model=model_spec, + hparams=image_classifier.HParams(export_dir=export_dir), + ) model = image_classifier.ImageClassifier.create( - model_spec=model_spec, train_data=train_data, validation_data=validation_data, - hparams=image_classifier.HParams(model_dir=export_dir)) + options=model_options) _, acc = model.evaluate(test_data) print('Test accuracy: %f' % acc) @@ -83,7 +85,6 @@ def run(data_dir: str, export_dir: str, raise ValueError(f'Quantization: {quantization} is not recognized') model.export_model(quantization_config=quantization_config) - model.export_labels(export_dir) def main(_) -> None: From 3cdf0f65365c5f13673034e9abf9ebbbef90c0b2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 16 Nov 2022 14:36:14 -0800 Subject: [PATCH 046/137] Fix a crash that occurred when a model returns fewer vector elements than before PiperOrigin-RevId: 489041814 --- mediapipe/web/graph_runner/wasm_mediapipe_lib.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts index 9ecf094ca..5f8040a33 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts @@ -406,7 +406,7 @@ export class WasmMediaPipeLib { */ setVectorListener( outputStreamName: string, callbackFcn: (data: T[]) => void) { - const buffer: T[] = []; + let buffer: T[] = []; this.wasmModule.vectorListeners = this.wasmModule.vectorListeners || {}; this.wasmModule.vectorListeners[outputStreamName] = (data: unknown, index: number, length: number) => { @@ -419,6 +419,7 @@ export class WasmMediaPipeLib { // the underlying data elements once we leave the scope of the // listener. callbackFcn(buffer); + buffer = []; } }; }
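The one-line `buffer = []` reset above is the actual crash fix: the accumulator used to be a single `const` array shared across vectors, so when a model emitted a vector shorter than its predecessor, stale trailing elements from the previous vector survived into the next callback. A self-contained sketch of the pattern (illustrative only, not MediaPipe code):

```ts
// Elements of a vector arrive one at a time as (data, index, length);
// once the last element lands, the completed vector is handed to the
// callback and the buffer is replaced rather than reused.
function makeVectorListener<T>(callbackFcn: (data: T[]) => void) {
  let buffer: T[] = [];
  return (data: T, index: number, length: number) => {
    buffer[index] = data;
    if (index === length - 1) {
      callbackFcn(buffer);
      buffer = [];  // Fresh array: a shorter next vector has no stale tail.
    }
  };
}

const listener = makeVectorListener<number>(v => console.log(v));
[10, 20, 30].forEach((x, i) => listener(x, i, 3));  // logs [10, 20, 30]
[7, 8].forEach((x, i) => listener(x, i, 2));  // logs [7, 8], not [7, 8, 30]
```

Reassigning instead of truncating also matches the comment in the patch: the listener must not keep references to the delivered elements once control leaves its scope.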
From b6b72d5e4e9b8a3b176331489cae78cc3e9c77df Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 15:55:06 -0800 Subject: [PATCH 047/137] Add MuxCalculator test case where graph is being closed while SELECT has not been received. PiperOrigin-RevId: 489061902 --- .../calculators/core/mux_calculator_test.cc | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index 86d2fab42..a3ac8a27a 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -398,6 +398,99 @@ TEST(MuxCalculatorTest, HandleTimestampBoundUpdates) { MP_ASSERT_OK(graph.WaitUntilDone()); } +TEST(MuxCalculatorTest, HandlesCloseGracefully) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + R"pb( + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add a single packet, wait for completion, and close. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + EXPECT_TRUE(output_packets.empty()); +} + +TEST(MuxCalculatorTest, CrashesOnCloseWithDefaultInputStreamHandler) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie( + R"pb( + # This is required in order for EXPECT_DEATH to work everywhere + executor { name: "" type: "ApplicationThreadExecutor" } + + input_stream: "select" + input_stream: "value_0" + input_stream: "value_1" + node { + calculator: "MuxCalculator" + input_stream: "SELECT:select" + input_stream: "INPUT:0:value_0" + input_stream: "INPUT:1:value_1" + output_stream: "OUTPUT:output" + input_stream_handler { + input_stream_handler: "DefaultInputStreamHandler" + } + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + // Observe packets. + std::vector output_packets; + MP_ASSERT_OK(graph.ObserveOutputStream( + "output", + [&output_packets](const Packet& p) -> absl::Status { + output_packets.push_back(p); + return absl::OkStatus(); + }, + /*observe_timestamp_bounds=*/true)); + + // Start graph. + MP_ASSERT_OK(graph.StartRun({})); + + // Add a single packet, wait for completion, and close. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "value_0", MakePacket(0).At(Timestamp(1000)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + // Currently MuxCalculator crashes with a correct packet set from + // DefaultInputStreamHandler. The SELECT packet is missing at Timestamp 1000, + // and an empty packet is the correct representation of that.
+ EXPECT_DEATH( + { + (void)graph.CloseAllInputStreams(); + (void)graph.WaitUntilDone(); + }, + "Check failed: payload_"); +} + } // namespace } // namespace mediapipe From 90eb4a19d8593d366ddf7aed894d8bb1161da39c Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 16 Nov 2022 18:11:00 -0800 Subject: [PATCH 048/137] Internal change PiperOrigin-RevId: 489088227 --- mediapipe/framework/deps/status_builder.cc | 23 ++--------- mediapipe/framework/deps/status_builder.h | 19 +-------- .../framework/deps/status_builder_test.cc | 39 ------------------- mediapipe/framework/deps/status_macros.h | 29 +++++++------- 4 files changed, 20 insertions(+), 90 deletions(-) diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index 70775949d..0202b8689 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -97,39 +97,24 @@ absl::Status StatusBuilder::Impl::JoinMessageToStatus() { }()); } -StatusBuilder::Impl::Impl(const absl::Status& status, const char* file, - int line) - : status(status), line(line), file(file), stream() {} - -StatusBuilder::Impl::Impl(absl::Status&& status, const char* file, int line) - : status(std::move(status)), line(line), file(file), stream() {} - StatusBuilder::Impl::Impl(const absl::Status& status, mediapipe::source_location location) - : status(status), - line(location.line()), - file(location.file_name()), - stream() {} + : status(status), location(location), stream() {} StatusBuilder::Impl::Impl(absl::Status&& status, mediapipe::source_location location) - : status(std::move(status)), - line(location.line()), - file(location.file_name()), - stream() {} + : status(std::move(status)), location(location), stream() {} StatusBuilder::Impl::Impl(const Impl& other) : status(other.status), - line(other.line), - file(other.file), + location(other.location), no_logging(other.no_logging), stream(other.stream.str()), join_style(other.join_style) {} StatusBuilder::Impl& StatusBuilder::Impl::operator=(const Impl& other) { status = other.status; - line = other.line; - file = other.file; + location = other.location; no_logging = other.no_logging; stream = std::ostringstream(other.stream.str()); join_style = other.join_style; diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index d2e40d575..ae11699d2 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -60,17 +60,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { ? nullptr : std::make_unique(absl::Status(code, ""), location)) {} - StatusBuilder(const absl::Status& original_status, const char* file, int line) - : impl_(original_status.ok() - ? nullptr - : std::make_unique(original_status, file, line)) {} - - StatusBuilder(absl::Status&& original_status, const char* file, int line) - : impl_(original_status.ok() - ? nullptr - : std::make_unique(std::move(original_status), file, - line)) {} - bool ok() const { return !impl_; } StatusBuilder& SetAppend() &; @@ -109,8 +98,6 @@ class ABSL_MUST_USE_RESULT StatusBuilder { kPrepend, }; - Impl(const absl::Status& status, const char* file, int line); - Impl(absl::Status&& status, const char* file, int line); Impl(const absl::Status& status, mediapipe::source_location location); Impl(absl::Status&& status, mediapipe::source_location location); Impl(const Impl&); @@ -120,10 +107,8 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // The status that the result will be based on. 
absl::Status status; - // The line to record if this file is logged. - int line; - // Not-owned: The file to record if this status is logged. - const char* file; + // The source location to record if this file is logged. + mediapipe::source_location location; // Logging disabled if true. bool no_logging = false; // The additional messages added with `<<`. This is nullptr when status_ is diff --git a/mediapipe/framework/deps/status_builder_test.cc b/mediapipe/framework/deps/status_builder_test.cc index 560acd3c6..f517bb909 100644 --- a/mediapipe/framework/deps/status_builder_test.cc +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -33,21 +33,6 @@ TEST(StatusBuilder, OkStatusRvalue) { ASSERT_EQ(status, absl::OkStatus()); } -TEST(StatusBuilder, OkStatusFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::OkStatus(), "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - -TEST(StatusBuilder, OkStatusFileAndLineLvalueStatus) { - const auto original_status = absl::OkStatus(); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_EQ(status, absl::OkStatus()); -} - TEST(StatusBuilder, AnnotateMode) { absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, "original message"), @@ -60,30 +45,6 @@ TEST(StatusBuilder, AnnotateMode) { "original message; annotated message1 annotated message2"); } -TEST(StatusBuilder, AnnotateModeFileAndLineRvalueStatus) { - absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, - "original message"), - "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - -TEST(StatusBuilder, AnnotateModeFileAndLineLvalueStatus) { - const auto original_status = - absl::Status(absl::StatusCode::kNotFound, "original message"); - absl::Status status = StatusBuilder(original_status, "hello.cc", 1234) - << "annotated message1 " - << "annotated message2"; - ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_EQ(status.message(), - "original message; annotated message1 annotated message2"); -} - TEST(StatusBuilder, PrependModeLvalue) { StatusBuilder builder( absl::Status(absl::StatusCode::kInvalidArgument, "original message"), diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h index 757d99392..92bbf0b84 100644 --- a/mediapipe/framework/deps/status_macros.h +++ b/mediapipe/framework/deps/status_macros.h @@ -81,11 +81,11 @@ // MP_RETURN_IF_ERROR(foo.Method(args...)); // return absl::OkStatus(); // } -#define MP_RETURN_IF_ERROR(expr) \ - STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ - if (mediapipe::status_macro_internal::StatusAdaptorForMacros \ - status_macro_internal_adaptor = {(expr), __FILE__, __LINE__}) { \ - } else /* NOLINT */ \ +#define MP_RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (mediapipe::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr), MEDIAPIPE_LOC}) { \ + } else /* NOLINT */ \ return status_macro_internal_adaptor.Consume() // Executes an expression `rexpr` that returns a `absl::StatusOr`. 
On @@ -156,14 +156,14 @@ return mediapipe::StatusBuilder( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__)) + MEDIAPIPE_LOC)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ mediapipe::StatusBuilder _( \ std::move(STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__)) \ .status(), \ - __FILE__, __LINE__); \ + MEDIAPIPE_LOC); \ (void)_; /* error_expression is allowed to not use this variable */ \ return (error_expression)) #define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ @@ -201,18 +201,17 @@ namespace status_macro_internal { // that declares a variable. class StatusAdaptorForMacros { public: - StatusAdaptorForMacros(const absl::Status& status, const char* file, int line) - : builder_(status, file, line) {} + StatusAdaptorForMacros(const absl::Status& status, source_location location) + : builder_(status, location) {} - StatusAdaptorForMacros(absl::Status&& status, const char* file, int line) - : builder_(std::move(status), file, line) {} + StatusAdaptorForMacros(absl::Status&& status, source_location location) + : builder_(std::move(status), location) {} - StatusAdaptorForMacros(const StatusBuilder& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(const StatusBuilder& builder, + source_location /*location*/) : builder_(builder) {} - StatusAdaptorForMacros(StatusBuilder&& builder, const char* /* file */, - int /* line */) + StatusAdaptorForMacros(StatusBuilder&& builder, source_location /*location*/) : builder_(std::move(builder)) {} StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; From e66e88802c42610441dd9acfd193a9ff8e022231 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 16 Nov 2022 18:32:59 -0800 Subject: [PATCH 049/137] Change NPM Bundle to ESM PiperOrigin-RevId: 489091370 --- mediapipe/tasks/web/BUILD | 80 ++++++------------- mediapipe/tasks/web/audio.ts | 8 +- mediapipe/tasks/web/audio/BUILD | 12 --- mediapipe/tasks/web/audio/index.ts | 17 ---- mediapipe/tasks/web/package.json | 12 +-- mediapipe/tasks/web/rollup.config.iife.mjs | 21 ----- ...ollup.config.cjs.mjs => rollup.config.mjs} | 4 +- mediapipe/tasks/web/text.ts | 10 ++- mediapipe/tasks/web/text/BUILD | 13 --- mediapipe/tasks/web/text/index.ts | 18 ----- mediapipe/tasks/web/vision.ts | 22 ++++- mediapipe/tasks/web/vision/BUILD | 16 ---- mediapipe/tasks/web/vision/index.ts | 21 ----- 13 files changed, 67 insertions(+), 187 deletions(-) delete mode 100644 mediapipe/tasks/web/audio/index.ts delete mode 100644 mediapipe/tasks/web/rollup.config.iife.mjs rename mediapipe/tasks/web/{rollup.config.cjs.mjs => rollup.config.mjs} (86%) delete mode 100644 mediapipe/tasks/web/text/index.ts delete mode 100644 mediapipe/tasks/web/vision/index.ts diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index b8777e785..e9703e37a 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -24,35 +24,25 @@ mediapipe_files(srcs = [ mediapipe_ts_library( name = "audio_lib", srcs = ["audio.ts"], - deps = ["//mediapipe/tasks/web/audio:audio_lib"], -) - -rollup_bundle( - name = "audio_cjs_bundle", - config_file = "rollup.config.cjs.mjs", - entry_point = "audio.ts", - format = "cjs", - output_dir = False, deps = [ - ":audio_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", + 
"//mediapipe/tasks/web/audio/audio_classifier", ], ) rollup_bundle( - name = "audio_iife_bundle", - config_file = "rollup.config.iife.mjs", + name = "audio_bundle", + config_file = "rollup.config.mjs", entry_point = "audio.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":audio_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -69,8 +59,7 @@ pkg_npm( deps = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", - ":audio_cjs_bundle", - ":audio_iife_bundle", + ":audio_bundle", ], ) @@ -79,35 +68,26 @@ pkg_npm( mediapipe_ts_library( name = "text_lib", srcs = ["text.ts"], - deps = ["//mediapipe/tasks/web/text:text_lib"], -) - -rollup_bundle( - name = "text_cjs_bundle", - config_file = "rollup.config.cjs.mjs", - entry_point = "text.ts", - format = "cjs", - output_dir = False, deps = [ - ":text_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", + "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", ], ) rollup_bundle( - name = "text_iife_bundle", - config_file = "rollup.config.iife.mjs", + name = "text_bundle", + config_file = "rollup.config.mjs", entry_point = "text.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":text_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -124,8 +104,7 @@ pkg_npm( deps = [ "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", - ":text_cjs_bundle", - ":text_iife_bundle", + ":text_bundle", ], ) @@ -134,35 +113,29 @@ pkg_npm( mediapipe_ts_library( name = "vision_lib", srcs = ["vision.ts"], - deps = ["//mediapipe/tasks/web/vision:vision_lib"], -) - -rollup_bundle( - name = "vision_cjs_bundle", - config_file = "rollup.config.cjs.mjs", - entry_point = "vision.ts", - format = "cjs", - output_dir = False, deps = [ - ":vision_lib", - "@npm//@rollup/plugin-commonjs", - "@npm//@rollup/plugin-node-resolve", - "@npm//@rollup/plugin-replace", + "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", + "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/object_detector", ], ) rollup_bundle( - name = "vision_iife_bundle", - config_file = "rollup.config.iife.mjs", + name = "vision_bundle", + config_file = "rollup.config.mjs", entry_point = "vision.ts", - format = "iife", + format = "esm", output_dir = False, + sourcemap = "false", deps = [ ":vision_lib", "@npm//@rollup/plugin-commonjs", "@npm//@rollup/plugin-node-resolve", "@npm//@rollup/plugin-replace", "@npm//@rollup/plugin-terser", + "@npm//google-protobuf", ], ) @@ -179,7 +152,6 @@ pkg_npm( deps = [ "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", - ":vision_cjs_bundle", - ":vision_iife_bundle", + ":vision_bundle", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 4a3b80594..764fd8393 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,4 +14,10 @@ * limitations under the License. 
*/ -export * from '../../tasks/web/audio/index'; +import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const AudioClassifier = AudioClassifierImpl; + +export {AudioClassifier}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 4f6e48b28..69b0408e9 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -1,13 +1 @@ # This contains the MediaPipe Audio Tasks. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_library( - name = "audio_lib", - srcs = ["index.ts"], - deps = [ - "//mediapipe/tasks/web/audio/audio_classifier", - ], -) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts deleted file mode 100644 index a5083b326..000000000 --- a/mediapipe/tasks/web/audio/index.ts +++ /dev/null @@ -1,17 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; diff --git a/mediapipe/tasks/web/package.json b/mediapipe/tasks/web/package.json index 1870f18a6..89c9a599e 100644 --- a/mediapipe/tasks/web/package.json +++ b/mediapipe/tasks/web/package.json @@ -2,20 +2,10 @@ "name": "@mediapipe/tasks-__NAME__", "version": "__VERSION__", "description": "__DESCRIPTION__", - "main": "__NAME___cjs_bundle.js", - "module": "__NAME___cjs_bundle.js", - "jsdeliver": "__NAME___iife_bundle.js", - "exports": { - ".": "./__NAME___cjs_bundle.js", - "./loader": "./wasm/__NAME___wasm_internal.js", - "./wasm": "./wasm/__NAME___wasm_internal.wasm" - }, + "main": "__NAME___bundle.js", "author": "mediapipe@google.com", "license": "Apache-2.0", "types": "__TYPES__", - "dependencies": { - "google-protobuf": "^3.21.2" - }, "homepage": "http://mediapipe.dev", "keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ] } diff --git a/mediapipe/tasks/web/rollup.config.iife.mjs b/mediapipe/tasks/web/rollup.config.iife.mjs deleted file mode 100644 index 1320927aa..000000000 --- a/mediapipe/tasks/web/rollup.config.iife.mjs +++ /dev/null @@ -1,21 +0,0 @@ -import resolve from '@rollup/plugin-node-resolve'; -import commonjs from '@rollup/plugin-commonjs'; -import terser from '@rollup/plugin-terser'; -import replace from '@rollup/plugin-replace'; - -export default { - output: { - name: 'bundle', - sourcemap: false - }, - plugins: [ - // Workaround for https://github.com/protocolbuffers/protobuf-javascript/issues/151 - replace({ - 'var calculator_options_pb = {};': 'var calculator_options_pb = {}; var mediapipe_framework_calculator_options_pb = calculator_options_pb;', - delimiters: ['', ''] - }), - resolve({browser: true}), - commonjs(), - terser() - ] -} diff --git a/mediapipe/tasks/web/rollup.config.cjs.mjs
b/mediapipe/tasks/web/rollup.config.mjs similarity index 86% rename from mediapipe/tasks/web/rollup.config.cjs.mjs rename to mediapipe/tasks/web/rollup.config.mjs index 5f8ca1848..e633bf702 100644 --- a/mediapipe/tasks/web/rollup.config.cjs.mjs +++ b/mediapipe/tasks/web/rollup.config.mjs @@ -1,6 +1,7 @@ import resolve from '@rollup/plugin-node-resolve'; import commonjs from '@rollup/plugin-commonjs'; import replace from '@rollup/plugin-replace'; +import terser from '@rollup/plugin-terser'; export default { plugins: [ @@ -10,6 +11,7 @@ export default { delimiters: ['', ''] }), resolve(), - commonjs() + commonjs(), + terser() ] } diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts index f8a0b6457..39d101237 100644 --- a/mediapipe/tasks/web/text.ts +++ b/mediapipe/tasks/web/text.ts @@ -14,4 +14,12 @@ * limitations under the License. */ -export * from '../../tasks/web/text/index'; +import {TextClassifier as TextClassifierImpl} from '../../tasks/web/text/text_classifier/text_classifier'; +import {TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/text_embedder/text_embedder'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const TextClassifier = TextClassifierImpl; +const TextEmbedder = TextEmbedderImpl; + +export {TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 4b465b0f5..edd23c7d4 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -1,14 +1 @@ # This contains the MediaPipe Text Tasks. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_library( - name = "text_lib", - srcs = ["index.ts"], - deps = [ - "//mediapipe/tasks/web/text/text_classifier", - "//mediapipe/tasks/web/text/text_embedder", - ], -) diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts deleted file mode 100644 index d50db209c..000000000 --- a/mediapipe/tasks/web/text/index.ts +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export * from '../../../tasks/web/text/text_classifier/text_classifier'; -export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts index 6ff8f725b..4e4fab43f 100644 --- a/mediapipe/tasks/web/vision.ts +++ b/mediapipe/tasks/web/vision.ts @@ -14,4 +14,24 @@ * limitations under the License.
*/ -export * from '../../tasks/web/vision/index'; +import {GestureRecognizer as GestureRecognizerImpl} from '../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +import {HandLandmarker as HandLandmarkerImpl} from '../../tasks/web/vision/hand_landmarker/hand_landmarker'; +import {ImageClassifier as ImageClassifierImpl} from '../../tasks/web/vision/image_classifier/image_classifier'; +import {ImageEmbedder as ImageEmbedderImpl} from '../../tasks/web/vision/image_embedder/image_embedder'; +import {ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/object_detector/object_detector'; + +// Declare the variables locally so that Rollup in OSS includes them explicitly +// as exports. +const GestureRecognizer = GestureRecognizerImpl; +const HandLandmarker = HandLandmarkerImpl; +const ImageClassifier = ImageClassifierImpl; +const ImageEmbedder = ImageEmbedderImpl; +const ObjectDetector = ObjectDetectorImpl; + +export { + GestureRecognizer, + HandLandmarker, + ImageClassifier, + ImageEmbedder, + ObjectDetector +}; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 3c45fbfa6..7267744e2 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -1,17 +1 @@ # This contains the MediaPipe Vision Tasks. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_library( - name = "vision_lib", - srcs = ["index.ts"], - deps = [ - "//mediapipe/tasks/web/vision/gesture_recognizer", - "//mediapipe/tasks/web/vision/hand_landmarker", - "//mediapipe/tasks/web/vision/image_classifier", - "//mediapipe/tasks/web/vision/image_embedder", - "//mediapipe/tasks/web/vision/object_detector", - ], -) diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts deleted file mode 100644 index d68c00cc7..000000000 --- a/mediapipe/tasks/web/vision/index.ts +++ /dev/null @@ -1,21 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -export * from '../../../tasks/web/vision/image_classifier/image_classifier'; -export * from '../../../tasks/web/vision/image_embedder/image_embedder'; -export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; -export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; -export * from '../../../tasks/web/vision/object_detector/object_detector';
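With the CJS and IIFE variants gone, each package ships a single minified ESM bundle referenced by `main`, so the same artifact serves bundlers and `<script type="module">` pages. A hypothetical consumer-side sketch follows; the package name comes from the `@mediapipe/tasks-__NAME__` template in `package.json` and the class names from the `text.ts` exports above, but none of this code is from the patch itself.

```ts
// Static ESM import of the named task classes.
import {TextClassifier, TextEmbedder} from '@mediapipe/tasks-text';

// Equivalent dynamic form, e.g. for lazy loading.
async function loadTextTasks() {
  const text = await import('@mediapipe/tasks-text');
  return {classifier: text.TextClassifier, embedder: text.TextEmbedder};
}
```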
From 6fc277ee1c34eeba9fda1e7fde90b705a4ee5824 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 16 Nov 2022 18:34:14 -0800 Subject: [PATCH 050/137] Internal change PiperOrigin-RevId: 489091534 --- mediapipe/gpu/gl_context.cc | 8 ++++++-- mediapipe/gpu/gl_context.h | 4 ++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 7f7ba0e23..91d2837c5 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -826,10 +826,14 @@ std::shared_ptr GlContext::CreateSyncToken() { return token; } -bool GlContext::IsAnyContextCurrent() { +PlatformGlContext GlContext::GetCurrentNativeContext() { ContextBinding ctx; GetCurrentContextBinding(&ctx); - return ctx.context != kPlatformGlContextNone; + return ctx.context; +} + +bool GlContext::IsAnyContextCurrent() { + return GetCurrentNativeContext() != kPlatformGlContextNone; } std::shared_ptr diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 957cb510f..7f5168d8b 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -307,6 +307,10 @@ class GlContext : public std::enable_shared_from_this { // the GlContext class, is current. static bool IsAnyContextCurrent(); + // Returns the current native context, whether managed by this class or not. + // Useful as a cross-platform way to get the current PlatformGlContext. + static PlatformGlContext GetCurrentNativeContext(); + // Creates a synchronization token for the current, non-GlContext-owned // context. This can be passed to MediaPipe so it can synchronize with the // commands issued in the external context up to this point. From 899c87466ec7cc62b5b60f10564c997c49bc9395 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 16 Nov 2022 20:55:18 -0800 Subject: [PATCH 051/137] Add MP Tasks entrypoints PiperOrigin-RevId: 489110875 --- mediapipe/tasks/web/audio/BUILD | 12 ++++++++++++ mediapipe/tasks/web/audio/index.ts | 17 +++++++++++++++++ mediapipe/tasks/web/text/BUILD | 13 +++++++++++++ mediapipe/tasks/web/text/index.ts | 18 ++++++++++++++++++ mediapipe/tasks/web/vision/BUILD | 16 ++++++++++++++++ mediapipe/tasks/web/vision/index.ts | 21 +++++++++++++++++++++ 6 files changed, 97 insertions(+) create mode 100644 mediapipe/tasks/web/audio/index.ts create mode 100644 mediapipe/tasks/web/text/index.ts create mode 100644 mediapipe/tasks/web/vision/index.ts diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 69b0408e9..4f6e48b28 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -1 +1,13 @@ # This contains the MediaPipe Audio Tasks. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "audio_lib", + srcs = ["index.ts"], + deps = [ + "//mediapipe/tasks/web/audio/audio_classifier", + ], +) diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts new file mode 100644 index 000000000..a5083b326 --- /dev/null +++ b/mediapipe/tasks/web/audio/index.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index edd23c7d4..4b465b0f5 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -1 +1,14 @@ # This contains the MediaPipe Text Tasks. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "text_lib", + srcs = ["index.ts"], + deps = [ + "//mediapipe/tasks/web/text/text_classifier", + "//mediapipe/tasks/web/text/text_embedder", + ], +) diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts new file mode 100644 index 000000000..d50db209c --- /dev/null +++ b/mediapipe/tasks/web/text/index.ts @@ -0,0 +1,18 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/text/text_classifier/text_classifier'; +export * from '../../../tasks/web/text/text_embedder/text_embedder'; diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 7267744e2..3c45fbfa6 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -1 +1,17 @@ # This contains the MediaPipe Vision Tasks.
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_library( + name = "vision_lib", + srcs = ["index.ts"], + deps = [ + "//mediapipe/tasks/web/vision/gesture_recognizer", + "//mediapipe/tasks/web/vision/hand_landmarker", + "//mediapipe/tasks/web/vision/image_classifier", + "//mediapipe/tasks/web/vision/image_embedder", + "//mediapipe/tasks/web/vision/object_detector", + ], +) diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts new file mode 100644 index 000000000..d68c00cc7 --- /dev/null +++ b/mediapipe/tasks/web/vision/index.ts @@ -0,0 +1,21 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export * from '../../../tasks/web/vision/image_classifier/image_classifier'; +export * from '../../../tasks/web/vision/image_embedder/image_embedder'; +export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; +export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; +export * from '../../../tasks/web/vision/object_detector/object_detector'; From 5a6837d034f9583e2f43659c388638ac14ad0b7e Mon Sep 17 00:00:00 2001 From: kinaryml Date: Wed, 16 Nov 2022 22:08:52 -0800 Subject: [PATCH 052/137] Fix errors that will occur in python 3.11 --- mediapipe/tasks/python/audio/audio_classifier.py | 3 ++- mediapipe/tasks/python/audio/audio_embedder.py | 3 ++- mediapipe/tasks/python/text/text_classifier.py | 4 +++- mediapipe/tasks/python/text/text_embedder.py | 4 +++- mediapipe/tasks/python/vision/gesture_recognizer.py | 6 ++++-- mediapipe/tasks/python/vision/image_classifier.py | 3 ++- mediapipe/tasks/python/vision/image_embedder.py | 3 ++- 7 files changed, 18 insertions(+), 8 deletions(-) diff --git a/mediapipe/tasks/python/audio/audio_classifier.py b/mediapipe/tasks/python/audio/audio_classifier.py index 7955cc4dc..2dd1cc4a3 100644 --- a/mediapipe/tasks/python/audio/audio_classifier.py +++ b/mediapipe/tasks/python/audio/audio_classifier.py @@ -70,7 +70,8 @@ class AudioClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[[AudioClassifierResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/audio/audio_embedder.py b/mediapipe/tasks/python/audio/audio_embedder.py index a774d71e9..4484064ee 100644 --- a/mediapipe/tasks/python/audio/audio_embedder.py +++ b/mediapipe/tasks/python/audio/audio_embedder.py @@ -71,7 +71,8 @@ class AudioEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.AUDIO_CLIPS - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = 
dataclasses.field( + default_factory=lambda: _EmbedderOptions()) result_callback: Optional[Callable[[AudioEmbedderResult, int], None]] = None @doc_controls.do_not_generate_docs diff --git a/mediapipe/tasks/python/text/text_classifier.py b/mediapipe/tasks/python/text/text_classifier.py index 92d547f20..c6095e1c3 100644 --- a/mediapipe/tasks/python/text/text_classifier.py +++ b/mediapipe/tasks/python/text/text_classifier.py @@ -14,6 +14,7 @@ """MediaPipe text classifier task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -48,7 +49,8 @@ class TextClassifierOptions: classifier_options: Options for the text classification task. """ base_options: _BaseOptions - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextClassifierGraphOptionsProto: diff --git a/mediapipe/tasks/python/text/text_embedder.py b/mediapipe/tasks/python/text/text_embedder.py index f3e5eecbe..1a32796a3 100644 --- a/mediapipe/tasks/python/text/text_embedder.py +++ b/mediapipe/tasks/python/text/text_embedder.py @@ -14,6 +14,7 @@ """MediaPipe text embedder task.""" import dataclasses +from typing import Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter @@ -49,7 +50,8 @@ class TextEmbedderOptions: embedder_options: Options for the text embedder task. """ base_options: _BaseOptions - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=lambda: _EmbedderOptions()) @doc_controls.do_not_generate_docs def to_pb2(self) -> _TextEmbedderGraphOptionsProto: diff --git a/mediapipe/tasks/python/vision/gesture_recognizer.py b/mediapipe/tasks/python/vision/gesture_recognizer.py index 9b6fd8cab..8addebe4c 100644 --- a/mediapipe/tasks/python/vision/gesture_recognizer.py +++ b/mediapipe/tasks/python/vision/gesture_recognizer.py @@ -181,9 +181,11 @@ class GestureRecognizerOptions: min_hand_presence_confidence: Optional[float] = 0.5 min_tracking_confidence: Optional[float] = 0.5 canned_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) custom_gesture_classifier_options: Optional[ - _ClassifierOptions] = _ClassifierOptions() + _ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[ [GestureRecognizerResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_classifier.py b/mediapipe/tasks/python/vision/image_classifier.py index 763160e1e..d3c2965ba 100644 --- a/mediapipe/tasks/python/vision/image_classifier.py +++ b/mediapipe/tasks/python/vision/image_classifier.py @@ -70,7 +70,8 @@ class ImageClassifierOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - classifier_options: _ClassifierOptions = _ClassifierOptions() + classifier_options: Optional[_ClassifierOptions] = dataclasses.field( + default_factory=lambda: _ClassifierOptions()) result_callback: Optional[Callable[ [ImageClassifierResult, image_module.Image, int], None]] = None diff --git a/mediapipe/tasks/python/vision/image_embedder.py b/mediapipe/tasks/python/vision/image_embedder.py index f299fa590..06624d16e 100644 --- 
a/mediapipe/tasks/python/vision/image_embedder.py +++ b/mediapipe/tasks/python/vision/image_embedder.py @@ -69,7 +69,8 @@ class ImageEmbedderOptions: """ base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - embedder_options: _EmbedderOptions = _EmbedderOptions() + embedder_options: Optional[_EmbedderOptions] = dataclasses.field( + default_factory=lambda: _EmbedderOptions()) result_callback: Optional[Callable[ [ImageEmbedderResult, image_module.Image, int], None]] = None From ea4989b6f146b9589fdd048ec4702a7c5384fe52 Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Thu, 17 Nov 2022 00:06:17 -0800 Subject: [PATCH 053/137] Internal change PiperOrigin-RevId: 489135553 --- .../core/flow_limiter_calculator_test.cc | 96 ++------- mediapipe/framework/BUILD | 1 + mediapipe/framework/calculator_graph.cc | 26 ++- mediapipe/framework/calculator_graph.h | 6 + .../framework/calculator_graph_bounds_test.cc | 194 +++++++++++++++++- mediapipe/util/packet_test_util.h | 80 +++++++- 6 files changed, 302 insertions(+), 101 deletions(-) diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 45bace271..5d0594de9 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -85,75 +85,6 @@ std::string SourceString(Timestamp t) { : absl::StrCat("Timestamp(", t.DebugString(), ")"); } -template -std::string SourceString(Packet packet) { - std::ostringstream oss; - if (packet.IsEmpty()) { - oss << "Packet()"; - } else { - oss << "MakePacket<" << MediaPipeTypeStringOrDemangled() << ">(" - << packet.Get() << ")"; - } - oss << ".At(" << SourceString(packet.Timestamp()) << ")"; - return oss.str(); -} - -template -class PacketsEqMatcher - : public ::testing::MatcherInterface { - public: - PacketsEqMatcher(PacketContainer packets) : packets_(packets) {} - void DescribeTo(::std::ostream* os) const override { - *os << "The expected packet contents: \n"; - Print(packets_, os); - } - bool MatchAndExplain( - const PacketContainer& value, - ::testing::MatchResultListener* listener) const override { - if (!Equals(packets_, value)) { - if (listener->IsInterested()) { - *listener << "The actual packet contents: \n"; - Print(value, listener->stream()); - } - return false; - } - return true; - } - - private: - bool Equals(const PacketContainer& c1, const PacketContainer& c2) const { - if (c1.size() != c2.size()) { - return false; - } - for (auto i1 = c1.begin(), i2 = c2.begin(); i1 != c1.end(); ++i1, ++i2) { - Packet p1 = *i1, p2 = *i2; - if (p1.Timestamp() != p2.Timestamp() || p1.IsEmpty() != p2.IsEmpty() || - (!p1.IsEmpty() && - p1.Get() != p2.Get())) { - return false; - } - } - return true; - } - void Print(const PacketContainer& packets, ::std::ostream* os) const { - for (auto it = packets.begin(); it != packets.end(); ++it) { - const Packet& packet = *it; - *os << (it == packets.begin() ? "{" : ""); - *os << SourceString(packet); - *os << (std::next(it) == packets.end() ? "}" : ", "); - } - } - - const PacketContainer packets_; -}; - -template -::testing::Matcher PacketsEq( - const PacketContainer& packets) { - return MakeMatcher( - new PacketsEqMatcher(packets)); -} - // A Calculator::Process callback function. typedef std::function @@ -743,9 +674,6 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { // The processing time "sleep_time" is reduced from 22ms to 12ms to create // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams. 
@@ -743,9 +674,6 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) {
 // The processing time "sleep_time" is reduced from 22ms to 12ms to create
 // the same frame rate as FlowLimiterCalculatorTest::TwoInputStreams.
 TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
-  auto BoolPacketsEq = PacketsEq<std::vector<Packet>, bool>;
-  auto IntPacketsEq = PacketsEq<std::vector<Packet>, int>;
-
   // Configure the test.
   SetUpInputData();
   SetUpSimulationClock();
@@ -839,13 +767,16 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
       input_packets_[0], input_packets_[2], input_packets_[15],
       input_packets_[17], input_packets_[19],
   };
-  EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output));
+  EXPECT_THAT(out_1_packets_,
+              ElementsAreArray(PacketMatchers<int>(expected_output)));
+
+  // Exactly the timestamps released by FlowLimiterCalculator for in_1_sampled.
   std::vector<Packet> expected_output_2 = {
       input_packets_[0], input_packets_[2], input_packets_[4],
       input_packets_[15], input_packets_[17], input_packets_[19],
   };
-  EXPECT_THAT(out_2_packets, IntPacketsEq(expected_output_2));
+  EXPECT_THAT(out_2_packets,
+              ElementsAreArray(PacketMatchers<int>(expected_output_2)));
 
   // Validate the ALLOW stream output.
   std::vector<Packet> expected_allow = {
@@ -871,7 +802,8 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) {
       MakePacket<bool>(true).At(Timestamp(190000)),
       MakePacket<bool>(false).At(Timestamp(200000)),
   };
-  EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow));
+  EXPECT_THAT(allow_packets_,
+              ElementsAreArray(PacketMatchers<bool>(expected_allow)));
 }
 
 std::vector<Packet> StripBoundsUpdates(const std::vector<Packet>& packets,
@@ -891,9 +823,6 @@ std::vector<Packet> StripBoundsUpdates(const std::vector<Packet>& packets,
 // Shows how FlowLimiterCalculator releases auxiliary input packets.
 // In this test, auxiliary input packets arrive at twice the primary rate.
 TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
-  auto BoolPacketsEq = PacketsEq<std::vector<Packet>, bool>;
-  auto IntPacketsEq = PacketsEq<std::vector<Packet>, int>;
-
   // Configure the test.
   SetUpInputData();
   SetUpSimulationClock();
@@ -1011,7 +940,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
       MakePacket<int>(6).At(Timestamp(60000)),
       Packet().At(Timestamp(80000)),
   };
-  EXPECT_THAT(out_1_packets_, IntPacketsEq(expected_output));
+  EXPECT_THAT(out_1_packets_,
+              ElementsAreArray(PacketMatchers<int>(expected_output)));
 
   // Packets following input packets 2 and 6, and not input packets 4 and 8.
   std::vector<Packet> expected_auxiliary_output = {
@@ -1031,12 +961,13 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
   };
   std::vector<Packet> actual_2 =
       StripBoundsUpdates(out_2_packets, Timestamp(90000));
-  EXPECT_THAT(actual_2, IntPacketsEq(expected_auxiliary_output));
+  EXPECT_THAT(actual_2,
+              ElementsAreArray(PacketMatchers<int>(expected_auxiliary_output)));
 
   std::vector<Packet> expected_3 =
       StripBoundsUpdates(expected_auxiliary_output, Timestamp(39999));
   std::vector<Packet> actual_3 =
       StripBoundsUpdates(out_3_packets, Timestamp(39999));
-  EXPECT_THAT(actual_3, IntPacketsEq(expected_3));
+  EXPECT_THAT(actual_3, ElementsAreArray(PacketMatchers<int>(expected_3)));
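The ALLOW-stream assertions that follow encode FlowLimiterCalculator's admission decisions: a true packet means the frame at that timestamp was admitted, false means it was dropped. A sketch of consuming that control stream from application code, assuming a running CalculatorGraph named graph with an output stream named "allow":

    // Sketch: count dropped frames by observing the limiter's ALLOW output.
    int num_dropped = 0;
    MP_RETURN_IF_ERROR(graph.ObserveOutputStream("allow", [&](const Packet& p) {
      if (!p.IsEmpty() && !p.Get<bool>()) {
        ++num_dropped;  // false: the frame at p.Timestamp() was dropped
      }
      return absl::OkStatus();
    }));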
   // Validate the ALLOW stream output.
   std::vector<Packet> expected_allow = {
@@ -1045,7 +976,8 @@ TEST_F(FlowLimiterCalculatorTest, AuxiliaryInputs) {
       MakePacket<bool>(true).At(Timestamp(60000)),
       MakePacket<bool>(false).At(Timestamp(80000)),
   };
-  EXPECT_THAT(allow_packets_, BoolPacketsEq(expected_allow));
+  EXPECT_THAT(allow_packets_,
+              ElementsAreArray(PacketMatchers<bool>(expected_allow)));
 }
 
 }  // anonymous namespace

diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD
index 19c51853c..8ccdac3b9 100644
--- a/mediapipe/framework/BUILD
+++ b/mediapipe/framework/BUILD
@@ -1469,6 +1469,7 @@ cc_test(
         "//mediapipe/framework/stream_handler:mux_input_stream_handler",
         "//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
         "//mediapipe/framework/tool:sink",
+        "//mediapipe/util:packet_test_util",
         "@com_google_absl//absl/strings",
     ],
 )

diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc
index c17a2e1e2..526a74835 100644
--- a/mediapipe/framework/calculator_graph.cc
+++ b/mediapipe/framework/calculator_graph.cc
@@ -98,14 +98,13 @@ void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) {
   manager_->LockIntroData();
 }
 
+void CalculatorGraph::GraphInputStream::SetNextTimestampBound(
+    Timestamp timestamp) {
+  shard_.SetNextTimestampBound(timestamp);
+}
+
 void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() {
-  // Since GraphInputStream doesn't allow SetOffset() and
-  // SetNextTimestampBound(), the timestamp bound to propagate is only
-  // determined by the timestamp of the output packets.
-  CHECK(!shard_.IsEmpty()) << "Shard with name \"" << manager_->Name()
-                           << "\" failed";
-  manager_->PropagateUpdatesToMirrors(
-      shard_.LastAddedPacketTimestamp().NextAllowedInStream(), &shard_);
+  manager_->PropagateUpdatesToMirrors(shard_.NextTimestampBound(), &shard_);
 }
 
 void CalculatorGraph::GraphInputStream::Close() {
@@ -868,6 +867,19 @@ absl::Status CalculatorGraph::AddPacketToInputStream(
   return AddPacketToInputStreamInternal(stream_name, std::move(packet));
 }
 
+absl::Status CalculatorGraph::SetInputStreamTimestampBound(
+    const std::string& stream_name, Timestamp timestamp) {
+  std::unique_ptr<GraphInputStream>* stream =
+      mediapipe::FindOrNull(graph_input_streams_, stream_name);
+  RET_CHECK(stream).SetNoLogging() << absl::Substitute(
+      "SetInputStreamTimestampBound called on input stream \"$0\" which is not "
+      "a graph input stream.",
+      stream_name);
+  (*stream)->SetNextTimestampBound(timestamp);
+  (*stream)->PropagateUpdatesToMirrors();
+  return absl::OkStatus();
+}
+
 // We avoid having two copies of this code for AddPacketToInputStream(
 // const Packet&) and AddPacketToInputStream(Packet &&) by having this
 // internal-only templated version. T&& is a forwarding reference here, so

diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h
index c51476102..04f9de45f 100644
--- a/mediapipe/framework/calculator_graph.h
+++ b/mediapipe/framework/calculator_graph.h
@@ -257,6 +257,10 @@ class CalculatorGraph {
   absl::Status AddPacketToInputStream(const std::string& stream_name,
                                       Packet&& packet);
 
+  // Indicates that input will arrive no earlier than a certain timestamp.
+  absl::Status SetInputStreamTimestampBound(const std::string& stream_name,
+                                            Timestamp timestamp);
+
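A usage sketch for the entry point declared above; the stream name is an assumption. A bound of T promises that any future packet on that stream arrives at Timestamp T or later, which lets downstream calculators settle their own bounds without receiving data:

    // Sketch: announce that "input_0" will carry nothing before t = 20000.
    // This conveys the same information as an empty Packet().At(Timestamp(19999)),
    // so calculators waiting on "input_0" can run for timestamps up to 19999.
    MP_RETURN_IF_ERROR(
        graph.SetInputStreamTimestampBound("input_0", Timestamp(20000)));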
   // Sets the queue size of a graph input stream, overriding the graph default.
   absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name,
                                           int max_queue_size);
@@ -425,6 +429,8 @@ class CalculatorGraph {
 
     void AddPacket(Packet&& packet) { shard_.AddPacket(std::move(packet)); }
 
+    void SetNextTimestampBound(Timestamp timestamp);
+
     void PropagateUpdatesToMirrors();
 
     void Close();

diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc
index b55f9459d..d149337cc 100644
--- a/mediapipe/framework/calculator_graph_bounds_test.cc
+++ b/mediapipe/framework/calculator_graph_bounds_test.cc
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <vector>
+
 #include "absl/strings/str_replace.h"
 #include "mediapipe/framework/calculator_context.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -24,6 +26,7 @@
 #include "mediapipe/framework/port/status_matchers.h"
 #include "mediapipe/framework/thread_pool_executor.h"
 #include "mediapipe/framework/timestamp.h"
+#include "mediapipe/util/packet_test_util.h"
 
 namespace mediapipe {
 namespace {
@@ -1536,7 +1539,7 @@ class EmptyPacketCalculator : public CalculatorBase {
 };
 REGISTER_CALCULATOR(EmptyPacketCalculator);
 
-// This test shows that an output timestamp bound can be specified by outputing
+// This test shows that an output timestamp bound can be specified by outputting
 // an empty packet with a settled timestamp.
 TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
   // OffsetAndBoundCalculator runs on parallel threads and sends ts
@@ -1580,6 +1583,195 @@ TEST(CalculatorGraphBoundsTest, EmptyPacketOutput) {
     EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10));
   }
 
+  // Shut down the graph.
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
+// This test shows that input timestamp bounds can be specified using
+// CalculatorGraph::SetInputStreamTimestampBound.
+TEST(CalculatorGraphBoundsTest, SetInputStreamTimestampBound) {
+  std::string config_str = R"(
+    input_stream: "input_0"
+    node {
+      calculator: "ProcessBoundToPacketCalculator"
+      input_stream: "input_0"
+      output_stream: "output_0"
+    }
+  )";
+  CalculatorGraphConfig config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
+  CalculatorGraph graph;
+  std::vector<Packet> output_0_packets;
+  MP_ASSERT_OK(graph.Initialize(config));
+  MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) {
+    output_0_packets.push_back(p);
+    return absl::OkStatus();
+  }));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Send in timestamp bounds.
+  for (int i = 0; i < 9; ++i) {
+    const int ts = 10 + i * 10;
+    MP_ASSERT_OK(graph.SetInputStreamTimestampBound(
+        "input_0", Timestamp(ts).NextAllowedInStream()));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+  }
+
+  // 9 timestamp bounds are converted to packets.
+  EXPECT_EQ(output_0_packets.size(), 9);
+  for (int i = 0; i < 9; ++i) {
+    EXPECT_EQ(output_0_packets[i].Timestamp(), Timestamp(10 + i * 10));
+  }
+
+  // Shut down the graph.
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
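The loop above advances the bound with Timestamp::NextAllowedInStream(), which for a regular timestamp is the timestamp plus one. That extra unit is what settles the timestamp itself, as this small sketch of the arithmetic spells out:

    Timestamp t(10);                            // a regular (non-special) value
    Timestamp bound = t.NextAllowedInStream();  // Timestamp(11)
    // Declaring `bound` on a stream promises that every future packet has a
    // timestamp of at least 11. Timestamp 10 is therefore settled, which is
    // why ProcessBoundToPacketCalculator above can emit its packet at 10.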
+// This test shows how an input stream with infrequent packets, such as
+// configuration protobufs, can be consumed while processing more frequent
+// packets, such as video frames.
+TEST(CalculatorGraphBoundsTest, TimestampBoundsForInfrequentInput) {
+  // PassThroughCalculator consuming two input streams, with default ISH.
+  std::string config_str = R"pb(
+    input_stream: "INFREQUENT:config"
+    input_stream: "FREQUENT:frame"
+    node {
+      calculator: "PassThroughCalculator"
+      input_stream: "CONFIG:config"
+      input_stream: "VIDEO:frame"
+      output_stream: "VIDEO:output_frame"
+      output_stream: "CONFIG:output_config"
+    }
+  )pb";
+
+  CalculatorGraphConfig config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(config_str);
+  CalculatorGraph graph;
+  std::vector<Packet> frame_packets;
+  MP_ASSERT_OK(graph.Initialize(config));
+  MP_ASSERT_OK(graph.ObserveOutputStream(
+      "output_frame",
+      [&](const Packet& p) {
+        frame_packets.push_back(p);
+        return absl::OkStatus();
+      },
+      /*observe_bound_updates=*/true));
+  std::vector<Packet> config_packets;
+  MP_ASSERT_OK(graph.ObserveOutputStream(
+      "output_config",
+      [&](const Packet& p) {
+        config_packets.push_back(p);
+        return absl::OkStatus();
+      },
+      /*observe_bound_updates=*/true));
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+
+  // Utility functions to send packets or timestamp bounds.
+  auto send_fn = [&](std::string stream, std::string value, int ts) {
+    MP_ASSERT_OK(graph.AddPacketToInputStream(
+        stream,
+        MakePacket<std::string>(absl::StrCat(value)).At(Timestamp(ts))));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+  };
+  auto bound_fn = [&](std::string stream, int ts) {
+    MP_ASSERT_OK(graph.SetInputStreamTimestampBound(stream, Timestamp(ts)));
+    MP_ASSERT_OK(graph.WaitUntilIdle());
+  };
+
+  // Send in a frame packet.
+  send_fn("frame", "frame_0", 0);
+  // The frame is not processed yet.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({})));
+  bound_fn("config", 10000);
+  // The frame is processed after a fresh config timestamp bound arrives.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  MakePacket<std::string>("frame_0").At(Timestamp(0)),
+              })));
+
+  // Send in a frame packet.
+  send_fn("frame", "frame_1", 20000);
+  // The frame is not processed yet.
+  // The PassThroughCalculator with TimestampOffset 0 now propagates
+  // Timestamp bound 10000 to both "output_frame" and "output_config",
+  // which appears here as Packet().At(Timestamp(9999)). The timestamp
+  // bounds at 29999 and 50000 are propagated similarly.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  MakePacket<std::string>("frame_0").At(Timestamp(0)),
+                  Packet().At(Timestamp(9999)),
+              })));
+  bound_fn("config", 30000);
+  // The frame is processed after a fresh config timestamp bound arrives.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  MakePacket<std::string>("frame_0").At(Timestamp(0)),
+                  Packet().At(Timestamp(9999)),
+                  MakePacket<std::string>("frame_1").At(Timestamp(20000)),
+              })));
+
+  // Send in a frame packet.
+  send_fn("frame", "frame_2", 40000);
+  // The frame is not processed yet.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  MakePacket<std::string>("frame_0").At(Timestamp(0)),
+                  Packet().At(Timestamp(9999)),
+                  MakePacket<std::string>("frame_1").At(Timestamp(20000)),
+                  Packet().At(Timestamp(29999)),
+              })));
+  send_fn("config", "config_1", 50000);
+  // The frame is processed after a fresh config arrives.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  MakePacket<std::string>("frame_0").At(Timestamp(0)),
+                  Packet().At(Timestamp(9999)),
+                  MakePacket<std::string>("frame_1").At(Timestamp(20000)),
+                  Packet().At(Timestamp(29999)),
+                  MakePacket<std::string>("frame_2").At(Timestamp(40000)),
+              })));
+
+  // Send in a frame packet.
+  send_fn("frame", "frame_3", 60000);
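The pattern this test builds up to generalizes to real pipelines: whenever the frequent stream advances, refresh the infrequent stream's timestamp bound so the default input stream handler never stalls waiting on a config packet that may be a long time coming. A hedged sketch reusing the test's stream names; the config_changed flag is illustrative:

    // Sketch: keep a rarely-updated "config" stream fresh relative to "frame".
    // From the scheduler's point of view, sending a real config packet and
    // advancing the bound are interchangeable ways to unblock the frame.
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
        "frame", MakePacket<std::string>("frame_4").At(Timestamp(80000))));
    if (!config_changed) {
      // No new config: promise that none will arrive before the next frame.
      MP_RETURN_IF_ERROR(
          graph.SetInputStreamTimestampBound("config", Timestamp(90000)));
    }

The assertions below pick up with frame_3, just sent, still pending until the next config bound.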
+  // The frame is not processed yet.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  MakePacket<std::string>("frame_0").At(Timestamp(0)),
+                  Packet().At(Timestamp(9999)),
+                  MakePacket<std::string>("frame_1").At(Timestamp(20000)),
+                  Packet().At(Timestamp(29999)),
+                  MakePacket<std::string>("frame_2").At(Timestamp(40000)),
+                  Packet().At(Timestamp(50000)),
+              })));
+  bound_fn("config", 70000);
+  // The frame is processed after a fresh config timestamp bound arrives.
+  EXPECT_THAT(frame_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  MakePacket<std::string>("frame_0").At(Timestamp(0)),
+                  Packet().At(Timestamp(9999)),
+                  MakePacket<std::string>("frame_1").At(Timestamp(20000)),
+                  Packet().At(Timestamp(29999)),
+                  MakePacket<std::string>("frame_2").At(Timestamp(40000)),
+                  Packet().At(Timestamp(50000)),
+                  MakePacket<std::string>("frame_3").At(Timestamp(60000)),
+              })));
+
+  // One config packet is delivered.
+  EXPECT_THAT(config_packets,
+              ElementsAreArray(PacketMatchers<std::string>({
+                  Packet().At(Timestamp(0)),
+                  Packet().At(Timestamp(9999)),
+                  Packet().At(Timestamp(20000)),
+                  Packet().At(Timestamp(29999)),
+                  Packet().At(Timestamp(40000)),
+                  MakePacket<std::string>("config_1").At(Timestamp(50000)),
+                  Packet().At(Timestamp(60000)),
+              })));
+
+  // Shut down the graph.
   MP_ASSERT_OK(graph.CloseAllPacketSources());
   MP_ASSERT_OK(graph.WaitUntilDone());

diff --git a/mediapipe/util/packet_test_util.h b/mediapipe/util/packet_test_util.h
index 106d7f8d4..61e9322e1 100644
--- a/mediapipe/util/packet_test_util.h
+++ b/mediapipe/util/packet_test_util.h
@@ -32,30 +32,29 @@ namespace mediapipe {
 namespace internal {
 
 template <typename PayloadType>
-class PacketMatcher : public ::testing::MatcherInterface<const Packet&> {
+class PacketMatcher : public testing::MatcherInterface<const Packet&> {
  public:
   template <typename InnerMatcher>
   explicit PacketMatcher(InnerMatcher inner_matcher)
       : inner_matcher_(
-            ::testing::SafeMatcherCast<const PayloadType&>(inner_matcher)) {}
+            testing::SafeMatcherCast<const PayloadType&>(inner_matcher)) {}
 
   // Returns true iff the packet contains value of PayloadType satisfying
   // the inner matcher.
-  bool MatchAndExplain(
-      const Packet& packet,
-      ::testing::MatchResultListener* listener) const override {
+  bool MatchAndExplain(const Packet& packet,
+                       testing::MatchResultListener* listener) const override {
     if (!packet.ValidateAsType<PayloadType>().ok()) {
       *listener << packet.DebugString() << " does not contain expected type "
                 << ExpectedTypeName();
       return false;
     }
-    ::testing::StringMatchResultListener match_listener;
+    testing::StringMatchResultListener match_listener;
     const PayloadType& payload = packet.Get<PayloadType>();
     const bool matches =
         inner_matcher_.MatchAndExplain(payload, &match_listener);
     const std::string explanation = match_listener.str();
     *listener << packet.DebugString() << " containing value "
-              << ::testing::PrintToString(payload);
+              << testing::PrintToString(payload);
     if (!explanation.empty()) {
       *listener << ", which " << explanation;
     }
@@ -78,9 +77,28 @@ class PacketMatcher : public ::testing::MatcherInterface<const Packet&> {
     return ::mediapipe::Demangle(typeid(PayloadType).name());
   }
 
-  const ::testing::Matcher<const PayloadType&> inner_matcher_;
+  const testing::Matcher<const PayloadType&> inner_matcher_;
 };
 
+inline std::string SourceString(Timestamp t) {
+  return (t.IsSpecialValue())
+             ? t.DebugString()
+             : absl::StrCat("Timestamp(", t.DebugString(), ")");
+}
+
+template <typename T>
+std::string SourceString(Packet packet) {
+  std::ostringstream oss;
+  if (packet.IsEmpty()) {
+    oss << "Packet()";
+  } else {
+    oss << "MakePacket<" << MediaPipeTypeStringOrDemangled<T>() << ">("
+        << packet.Get<T>() << ")";
+  }
+  oss << ".At(" << SourceString(packet.Timestamp()) << ")";
+  return oss.str();
+}
+
 }  // namespace internal
 
 // Creates matcher validating that the packet contains value of expected type
 //
 // EXPECT_THAT(MakePacket<int>(42), PacketContains<int>(Eq(42)))
 template <typename PayloadType, typename InnerMatcher>
-inline ::testing::Matcher<const Packet&> PacketContains(
+inline testing::Matcher<const Packet&> PacketContains(
     InnerMatcher inner_matcher) {
-  return ::testing::MakeMatcher(
+  return testing::MakeMatcher(
       new internal::PacketMatcher<PayloadType>(inner_matcher));
 }
 
@@ -110,7 +128,7 @@ inline ::testing::Matcher<const Packet&> PacketContains(
 //                 Eq(42)))
 template <typename PayloadType, typename TimestampMatcher,
           typename ContentMatcher>
-inline ::testing::Matcher<const Packet&> PacketContainsTimestampAndPayload(
+inline testing::Matcher<const Packet&> PacketContainsTimestampAndPayload(
     TimestampMatcher timestamp_matcher, ContentMatcher content_matcher) {
   return testing::AllOf(
       testing::Property("Packet::Timestamp", &Packet::Timestamp,
                         timestamp_matcher),
       PacketContains<PayloadType>(content_matcher));
 }
 
+template <typename PayloadType>
+class PacketEqMatcher : public testing::MatcherInterface<Packet> {
+ public:
+  PacketEqMatcher(Packet packet) : packet_(packet) {}
+  void DescribeTo(::std::ostream* os) const override {
+    *os << "The expected packet: "
+        << internal::SourceString<PayloadType>(packet_);
+  }
+  bool MatchAndExplain(Packet value,
+                       testing::MatchResultListener* listener) const override {
+    bool unequal = (value.Timestamp() != packet_.Timestamp() ||
+                    value.IsEmpty() != packet_.IsEmpty() ||
+                    (!value.IsEmpty() &&
+                     value.Get<PayloadType>() != packet_.Get<PayloadType>()));
+    if (unequal && listener->IsInterested()) {
+      *listener << "The actual packet: "
+                << internal::SourceString<PayloadType>(value);
+    }
+    return !unequal;
+  }
+  const Packet packet_;
+};
+
+template <typename PayloadType>
+testing::Matcher<Packet> PacketEq(Packet packet) {
+  return MakeMatcher(new PacketEqMatcher<PayloadType>(packet));
+}
+
+template <typename PayloadType>
+std::vector<testing::Matcher<Packet>> PacketMatchers(
+    std::vector<Packet> packets) {
+  std::vector<testing::Matcher<Packet>> result;
+  for (const auto& packet : packets) {
+    result.push_back(PacketEq<PayloadType>(packet));
+  }
+  return result;
+}
+
+}  // namespace mediapipe
+
+namespace mediapipe {
+using mediapipe::PacketContains;
+using mediapipe::PacketContainsTimestampAndPayload;
 }  // namespace mediapipe
 
 #endif  // MEDIAPIPE_UTIL_PACKET_TEST_UTIL_H_

From 3ccf7308e03933ceb6285e7f347d2865c7a4d540 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Thu, 17 Nov 2022 05:26:56 -0800
Subject: [PATCH 054/137] Add shared options for Text and Audio Tasks

PiperOrigin-RevId: 489186644
---
 .../audioembedder/AudioEmbedderResult.java | 4 +-
 .../tasks/audio/core/RunningMode.java | 2 +-
 .../tasks/web/audio/audio_classifier/BUILD | 1 +
 .../audio_classifier_options.d.ts | 7 ++-
 mediapipe/tasks/web/audio/core/BUILD | 13 ++++++
 .../web/audio/core/audio_task_options.d.ts | 44 +++++++++++++++++++
 .../tasks/web/core/classifier_options.d.ts | 5 +--
 .../tasks/web/core/embedder_options.d.ts | 5 +--
 mediapipe/tasks/web/text/core/BUILD | 11 +++++
 .../web/text/core/text_task_options.d.ts | 23 ++++++++++
 .../tasks/web/text/text_classifier/BUILD | 1 +
 .../text_classifier_options.d.ts | 7 ++-
 mediapipe/tasks/web/text/text_embedder/BUILD | 1 +
 .../text_embedder/text_embedder_options.d.ts | 7 ++-
 mediapipe/tasks/web/vision/core/BUILD | 2 +-
.../web/vision/core/vision_task_options.d.ts | 2 +- .../image_classifier_options.d.ts | 2 +- .../image_embedder_options.d.ts | 2 +- 18 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 mediapipe/tasks/web/audio/core/BUILD create mode 100644 mediapipe/tasks/web/audio/core/audio_task_options.d.ts create mode 100644 mediapipe/tasks/web/text/core/BUILD create mode 100644 mediapipe/tasks/web/text/core/text_task_options.d.ts diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java index ee4df0198..a986048f0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedderResult.java @@ -65,8 +65,8 @@ public abstract class AudioEmbedderResult implements TaskResult { /** * Contains one set of results per classifier head. A {@link EmbeddingResult} usually represents - * one audio embedding result in an audio stream, and s only available when running with the audio - * stream mode. + * one audio embedding result in an audio stream, and is only available when running with the + * audio stream mode. */ public abstract Optional embeddingResult(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java index f0a123810..a778eae46 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java @@ -20,7 +20,7 @@ package com.google.mediapipe.tasks.audio.core; *
    *
  • AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips. *
  • AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from - * microphone. + * a microphone. *
*/ public enum RunningMode { diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 6a78116c3..412af3bea 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -36,6 +36,7 @@ mediapipe_ts_declaration( "audio_classifier_result.d.ts", ], deps = [ + "//mediapipe/tasks/web/audio/core:audio_task_options", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts index 93bd9927e..975b1e315 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {ClassifierOptions as AudioClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; + +/** Options to configure the MediaPipe Audio Classifier Task */ +export declare interface AudioClassifierOptions extends ClassifierOptions, + AudioTaskOptions {} diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD new file mode 100644 index 000000000..ed60f2435 --- /dev/null +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -0,0 +1,13 @@ +# This package contains options shared by all MediaPipe Audio Tasks for Web. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +mediapipe_ts_declaration( + name = "audio_task_options", + srcs = ["audio_task_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + ], +) diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts new file mode 100644 index 000000000..58a6e55d8 --- /dev/null +++ b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts @@ -0,0 +1,44 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions} from '../../../../tasks/web/core/base_options'; + +/** + * MediaPipe audio task running mode. A MediaPipe audio task can be run with + * two different modes: + * - audio_clips: The mode for running a mediapipe audio task on independent + * audio clips. + * - audio_stream: The mode for running a mediapipe audio task on an audio + * stream, such as from a microphone. + * + */ +export type RunningMode = 'audio_clips'|'audio_stream'; + +/** The options for configuring a MediaPipe Audio Task. */ +export declare interface AudioTaskOptions { + /** Options to configure the loading of the model assets. 
 */
+  baseOptions?: BaseOptions;
+
+  /**
+   * The running mode of the task. Defaults to the audio_clips mode.
+   * Audio tasks have two running modes:
+   * 1) The mode for running a mediapipe audio task on independent
+   *    audio clips.
+   * 2) The mode for running a mediapipe audio task on an audio
+   *    stream, such as from a microphone.
+   */
+  runningMode?: RunningMode;
+}

diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts
index 3dec8d27e..1d804d629 100644
--- a/mediapipe/tasks/web/core/classifier_options.d.ts
+++ b/mediapipe/tasks/web/core/classifier_options.d.ts
@@ -16,11 +16,8 @@
 
 import {BaseOptions} from '../../../tasks/web/core/base_options';
 
-/** Options to configure the Mediapipe Classifier Task. */
+/** Options to configure a MediaPipe Classifier Task. */
 export declare interface ClassifierOptions {
-  /** Options to configure the loading of the model assets. */
-  baseOptions?: BaseOptions;
-
   /**
    * The locale to use for display names specified through the TFLite Model
    * Metadata, if any. Defaults to English.

diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts
index 78ddad1ae..3ec2a170c 100644
--- a/mediapipe/tasks/web/core/embedder_options.d.ts
+++ b/mediapipe/tasks/web/core/embedder_options.d.ts
@@ -16,11 +16,8 @@
 
 import {BaseOptions} from '../../../tasks/web/core/base_options';
 
-/** Options to configure the MediaPipe Embedder Task */
+/** Options to configure a MediaPipe Embedder Task */
 export declare interface EmbedderOptions {
-  /** Options to configure the loading of the model assets. */
-  baseOptions?: BaseOptions;
-
   /**
    * Whether to normalize the returned feature vector with L2 norm. Use this
    * option only if the model does not already contain a native L2_NORMALIZATION

diff --git a/mediapipe/tasks/web/text/core/BUILD b/mediapipe/tasks/web/text/core/BUILD
new file mode 100644
index 000000000..3e7faec93
--- /dev/null
+++ b/mediapipe/tasks/web/text/core/BUILD
@@ -0,0 +1,11 @@
+# This package contains options shared by all MediaPipe Text Tasks for Web.
+
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration")
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+mediapipe_ts_declaration(
+    name = "text_task_options",
+    srcs = ["text_task_options.d.ts"],
+    deps = ["//mediapipe/tasks/web/core"],
+)

diff --git a/mediapipe/tasks/web/text/core/text_task_options.d.ts b/mediapipe/tasks/web/text/core/text_task_options.d.ts
new file mode 100644
index 000000000..4874e35bf
--- /dev/null
+++ b/mediapipe/tasks/web/text/core/text_task_options.d.ts
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {BaseOptions} from '../../../../tasks/web/core/base_options';
+
+/** The options for configuring a MediaPipe Text task. */
+export declare interface TextTaskOptions {
+  /** Options to configure the loading of the model assets.
*/ + baseOptions?: BaseOptions; +} diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 7dbbb18ca..8c3b8e226 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -40,5 +40,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts index 51b2b3947..b50767e1a 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {ClassifierOptions as TextClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; + +/** Options to configure the MediaPipe Text Classifier Task */ +export declare interface TextClassifierOptions extends ClassifierOptions, + TextTaskOptions {} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index bebd612dd..17b5eac06 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -39,5 +39,6 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts index 9af263765..9ea570304 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts @@ -14,4 +14,9 @@ * limitations under the License. */ -export {EmbedderOptions as TextEmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; + +/** Options to configure the MediaPipe Text Embedder Task */ +export declare interface TextEmbedderOptions extends EmbedderOptions, + TextTaskOptions {} diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 8c405ae6e..e3a5edf33 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -1,4 +1,4 @@ -# This package contains options shared by all MediaPipe Tasks for Web. +# This package contains options shared by all MediaPipe Vision Tasks for Web. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts index 8b9562e46..e04eb6596 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts @@ -17,7 +17,7 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; /** - * The two running modes of a video task. + * The two running modes of a vision task. * 1) The image mode for processing single image inputs. * 2) The video mode for processing decoded frames of a video. */ diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts index c1141d28f..e99dd2b69 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_options.d.ts @@ -17,6 +17,6 @@ import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; -/** Ooptions to configure the image classifier task. */ +/** Options to configure the MediaPipe Image Classifier Task. */ export declare interface ImageClassifierOptions extends ClassifierOptions, VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts index 10000825c..8a04be5e1 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder_options.d.ts @@ -17,6 +17,6 @@ import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; -/** The options for configuring a MediaPipe image embedder task. */ +/** Options for configuring a MediaPipe Image Embedder task. 
*/ export declare interface ImageEmbedderOptions extends EmbedderOptions, VisionTaskOptions {} From 1fb0902aa06d45ebc73f5337d9f65f06c418c24b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 14:01:14 -0800 Subject: [PATCH 055/137] Update gesture_recognizer test PiperOrigin-RevId: 489301508 --- .../vision/gesture_recognizer/gesture_recognizer_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 8a6e474d7..39272cbbc 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,6 +14,7 @@ import io import os +import random import tempfile from unittest import mock as unittest_mock import zipfile @@ -41,6 +42,7 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() + random.seed(1234) all_data = self._load_data() # Splits data, 90% data for training, 10% for validation self._train_data, self._validation_data = all_data.split(0.9) @@ -93,11 +95,11 @@ class GestureRecognizerTest(tf.test.TestCase): tflite_file=gesture_classifier_tflite_file, size=[1, model.embedding_size]) - def _test_accuracy(self, model, threshold=0.25): + def _test_accuracy(self, model, threshold=0.0): # Test on _train_data because of our limited dataset size _, accuracy = model.evaluate(self._train_data) tf.compat.v1.logging.info(f'train accuracy: {accuracy}') - self.assertGreaterEqual(accuracy, threshold) + self.assertGreater(accuracy, threshold) @unittest_mock.patch.object( gesture_recognizer.hyperparameters, From a7bd725e65e34ea416b15ceeffed972a2b205071 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 16:06:04 -0800 Subject: [PATCH 056/137] Internal change PiperOrigin-RevId: 489331826 --- mediapipe/gpu/gl_context.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 91d2837c5..53e3ff8b7 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -290,8 +290,15 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { // some Emscripten cases), there might be some existing tripped error. ForceClearExistingGlErrors(); - absl::string_view version_string( - reinterpret_cast(glGetString(GL_VERSION))); + absl::string_view version_string; + const GLubyte* version_string_ptr = glGetString(GL_VERSION); + if (version_string_ptr != nullptr) { + version_string = reinterpret_cast(version_string_ptr); + } else { + // This may happen when using SwiftShader, but the numeric versions are + // available and will be used instead. + LOG(WARNING) << "failed to get GL_VERSION string"; + } // We will decide later whether we want to use the version numbers we query // for, or instead derive that information from the context creation result, @@ -333,7 +340,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) { } LOG(INFO) << "GL version: " << gl_major_version_ << "." 
               << gl_minor_version_
-            << " (" << glGetString(GL_VERSION) << ")";
+            << " (" << version_string << ")";
   {
     auto status = GetGlExtensions();
     if (!status.ok()) {

From ab3a5f0fbf1883c4d1dfe1df2db80a7045a390c4 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Thu, 17 Nov 2022 16:28:08 -0800
Subject: [PATCH 057/137] Make MuxCalculator with DefaultInputStreamHandler
 handle graph closure gracefully

PiperOrigin-RevId: 489336722
---
 mediapipe/calculators/core/mux_calculator.cc | 4 ++++
 .../calculators/core/mux_calculator_test.cc | 16 ++++++----------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc
index a0ce2ae34..88b04a32b 100644
--- a/mediapipe/calculators/core/mux_calculator.cc
+++ b/mediapipe/calculators/core/mux_calculator.cc
@@ -41,6 +41,10 @@ class MuxCalculator : public Node {
       StreamHandler("MuxInputStreamHandler"));
 
   absl::Status Process(CalculatorContext* cc) final {
+    if (kSelect(cc).IsStream() && kSelect(cc).IsEmpty()) {
+      return absl::OkStatus();
+    }
+
     int select = *kSelect(cc);
     RET_CHECK(0 <= select && select < kIn(cc).Count());
     if (!kIn(cc)[select].IsEmpty()) {

diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc
index a3ac8a27a..6b9434be9 100644
--- a/mediapipe/calculators/core/mux_calculator_test.cc
+++ b/mediapipe/calculators/core/mux_calculator_test.cc
@@ -439,7 +439,7 @@ TEST(MuxCalculatorTest, HandlesCloseGracefully) {
   EXPECT_TRUE(output_packets.empty());
 }
 
-TEST(MuxCalculatorTest, CrashesOnCloseWithDeafultInputStreamHandler) {
+TEST(MuxCalculatorTest, HandlesCloseGracefullyWithDefaultInputStreamHandler) {
   CalculatorGraphConfig config =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
           R"pb(
@@ -480,15 +480,11 @@ TEST(MuxCalculatorTest, CrashesOnCloseWithDeafultInputStreamHandler) {
   MP_ASSERT_OK(graph.AddPacketToInputStream(
       "value_0", MakePacket<int>(0).At(Timestamp(1000))));
   MP_ASSERT_OK(graph.WaitUntilIdle());
-  // Currently MuxCalculator crashes with a correct packet set from
-  // DefaultInputStreamHandler. The SELECT packet is missing at Timestamp 1000,
-  // and an empty packet is the correct representation of that.
- EXPECT_DEATH( - { - (void)graph.CloseAllInputStreams(); - (void)graph.WaitUntilDone(); - }, - "Check failed: payload_"); + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + ASSERT_EQ(output_packets.size(), 1); + EXPECT_TRUE(output_packets[0].IsEmpty()); } } // namespace From 6f3cb340e153af68c31462a337ee0bf1c113f7cd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 17 Nov 2022 17:14:56 -0800 Subject: [PATCH 058/137] Internal change PiperOrigin-RevId: 489345940 --- .../tasks/web/audio/audio_classifier/BUILD | 2 +- .../audio_classifier/audio_classifier.ts | 2 +- mediapipe/tasks/web/core/BUILD | 4 ++-- mediapipe/tasks/web/core/task_runner.ts | 6 +++--- .../tasks/web/text/text_classifier/BUILD | 2 +- .../text/text_classifier/text_classifier.ts | 2 +- mediapipe/tasks/web/text/text_embedder/BUILD | 2 +- .../web/text/text_embedder/text_embedder.ts | 2 +- mediapipe/tasks/web/vision/core/BUILD | 2 +- .../web/vision/core/vision_task_runner.ts | 2 +- .../tasks/web/vision/gesture_recognizer/BUILD | 2 +- .../gesture_recognizer/gesture_recognizer.ts | 2 +- .../tasks/web/vision/hand_landmarker/BUILD | 2 +- .../vision/hand_landmarker/hand_landmarker.ts | 2 +- .../tasks/web/vision/image_classifier/BUILD | 2 +- .../image_classifier/image_classifier.ts | 2 +- .../tasks/web/vision/image_embedder/BUILD | 2 +- .../vision/image_embedder/image_embedder.ts | 2 +- .../tasks/web/vision/object_detector/BUILD | 2 +- .../vision/object_detector/object_detector.ts | 2 +- mediapipe/web/graph_runner/BUILD | 20 ++++++------------- ...{wasm_mediapipe_lib.ts => graph_runner.ts} | 14 ++++++------- ...image_lib.ts => graph_runner_image_lib.ts} | 10 +++++----- .../register_model_resources_graph_service.ts | 10 +++++----- 24 files changed, 46 insertions(+), 54 deletions(-) rename mediapipe/web/graph_runner/{wasm_mediapipe_lib.ts => graph_runner.ts} (99%) rename mediapipe/web/graph_runner/{wasm_mediapipe_image_lib.ts => graph_runner_image_lib.ts} (83%) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 412af3bea..9e1fcbc51 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 76b926723..5533b0eaa 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; diff --git 
a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index e9ef85d46..6eca8bb4a 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -18,9 +18,9 @@ mediapipe_ts_library( "task_runner.ts", ], deps = [ + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_image_lib_ts", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", ], ) diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c948930fc..67aa4e4df 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -15,12 +15,12 @@ */ import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; -import {SupportImage} from '../../../web/graph_runner/wasm_mediapipe_image_lib'; -import {WasmMediaPipeLib, WasmModule} from '../../../web/graph_runner/wasm_mediapipe_lib'; +import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; +import {GraphRunner, WasmModule} from '../../../web/graph_runner/graph_runner'; // tslint:disable-next-line:enforce-name-casing const WasmMediaPipeImageLib = - SupportModelResourcesGraphService(SupportImage(WasmMediaPipeLib)); + SupportModelResourcesGraphService(SupportImage(GraphRunner)); /** Base class for all MediaPipe Tasks. */ export abstract class TaskRunner extends WasmMediaPipeImageLib { diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 8c3b8e226..71ef02c92 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -26,7 +26,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index d4f413efa..04789f5e1 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 17b5eac06..c555f8d33 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + 
"//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 7c631683d..57b91d575 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -23,7 +23,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index e3a5edf33..1d8944f14 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -21,6 +21,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index 372ce9ba7..79ff45156 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -17,7 +17,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index f2b668239..ddfd1a327 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -32,7 +32,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8e745534e..dd050d0f1 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -31,7 +31,7 @@ import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import 
{createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 36f1d7eb7..1849687c5 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -27,7 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 0aba5c82c..32b1eed4b 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -27,7 +27,7 @@ import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {HandLandmarkerOptions} from './hand_landmarker_options'; diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index e7e830332..ebe64ecf4 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 0011e9c55..b59cb6fb1 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; diff --git 
a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD index ce1c25700..feb3ae054 100644 --- a/mediapipe/tasks/web/vision/image_embedder/BUILD +++ b/mediapipe/tasks/web/vision/image_embedder/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/vision/core:vision_task_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index d17bc72fa..c60665052 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -23,7 +23,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index 0975a9fd4..b6bef6bfa 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -22,7 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:wasm_mediapipe_lib_ts", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index e6cbd8627..44046cd1e 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/wasm_mediapipe_lib'; +import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; diff --git a/mediapipe/web/graph_runner/BUILD b/mediapipe/web/graph_runner/BUILD index dab6be50f..5c12947af 100644 --- a/mediapipe/web/graph_runner/BUILD +++ b/mediapipe/web/graph_runner/BUILD @@ -3,32 +3,24 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = [ - ":internal", "//mediapipe/tasks:internal", 
]) -package_group( - name = "internal", - packages = [ - "//mediapipe/app/pursuit/wasm/web_ml_cpu/typescript/...", - ], -) - mediapipe_ts_library( - name = "wasm_mediapipe_lib_ts", + name = "graph_runner_ts", srcs = [ - ":wasm_mediapipe_lib.ts", + ":graph_runner.ts", ], allow_unoptimized_namespaces = True, ) mediapipe_ts_library( - name = "wasm_mediapipe_image_lib_ts", + name = "graph_runner_image_lib_ts", srcs = [ - ":wasm_mediapipe_image_lib.ts", + ":graph_runner_image_lib.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) mediapipe_ts_library( @@ -37,5 +29,5 @@ mediapipe_ts_library( ":register_model_resources_graph_service.ts", ], allow_unoptimized_namespaces = True, - deps = [":wasm_mediapipe_lib_ts"], + deps = [":graph_runner_ts"], ) diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts b/mediapipe/web/graph_runner/graph_runner.ts similarity index 99% rename from mediapipe/web/graph_runner/wasm_mediapipe_lib.ts rename to mediapipe/web/graph_runner/graph_runner.ts index 5f8040a33..7de5aa33b 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -129,7 +129,7 @@ declare global { declare function importScripts(...urls: Array): void; /** - * Valid types of image sources which we can run our WasmMediaPipeLib over. + * Valid types of image sources which we can run our GraphRunner over. */ export type ImageSource = HTMLCanvasElement|HTMLVideoElement|HTMLImageElement|ImageData|ImageBitmap; @@ -138,7 +138,7 @@ export type ImageSource = /** A listener that will be invoked with an absl::StatusCode and message. */ export type ErrorListener = (code: number, message: string) => void; -// Internal type of constructors used for initializing WasmMediaPipeLib and +// Internal type of constructors used for initializing GraphRunner and // subclasses. type WasmMediaPipeConstructor = (new ( @@ -151,7 +151,7 @@ type WasmMediaPipeConstructor = * into canvas, or else return the output WebGLTexture. Takes a WebAssembly * Module (must be instantiated to self.Module). */ -export class WasmMediaPipeLib { +export class GraphRunner { // TODO: These should be protected/private, but are left exposed for // now so that we can use proper TS mixins with this class as a base. This // should be somewhat fixed when we create our .d.ts files. @@ -989,7 +989,7 @@ async function runScript(scriptUrl: string) { /** * Global function to initialize Wasm blob and load runtime assets for a * specialized MediaPipe library. This allows us to create a requested - * subclass inheriting from WasmMediaPipeLib. + * subclass inheriting from GraphRunner. * @param constructorFcn The name of the class to instantiate via "new". * @param wasmLoaderScript Url for the wasm-runner script; produced by the build * process. @@ -1043,12 +1043,12 @@ export async function createMediaPipeLib( * @return promise A promise which will resolve when initialization has * completed successfully. 
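 *
 * A minimal usage sketch (illustration only; the loader script path and the
 * graph bytes are placeholders, not part of this change):
 *
 *   const runner = await createGraphRunner('/wasm/graph_runner_wasm.js');
 *   runner.setGraph(binaryGraphBytes, true);  // true: binary graph config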
*/ -export async function createWasmMediaPipeLib( +export async function createGraphRunner( wasmLoaderScript?: string, assetLoaderScript?: string, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, - fileLocator?: FileLocator): Promise { + fileLocator?: FileLocator): Promise { return createMediaPipeLib( - WasmMediaPipeLib, wasmLoaderScript, assetLoaderScript, glCanvas, + GraphRunner, wasmLoaderScript, assetLoaderScript, glCanvas, fileLocator); } diff --git a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts similarity index 83% rename from mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts rename to mediapipe/web/graph_runner/graph_runner_image_lib.ts index 3b45e8230..e886999cb 100644 --- a/mediapipe/web/graph_runner/wasm_mediapipe_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -1,12 +1,12 @@ -import {ImageSource, WasmMediaPipeLib} from './wasm_mediapipe_lib'; +import {ImageSource, GraphRunner} from './graph_runner'; /** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * We extend from a GraphRunner constructor. This ensures our mixin has * access to the wasmModule, among other things. The `any` type is required for * mixin constructors. */ // tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; +type LibConstructor = new (...args: any[]) => GraphRunner; /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler @@ -19,10 +19,10 @@ export declare interface WasmImageModule { } /** - * An implementation of WasmMediaPipeLib that supports binding GPU image data as + * An implementation of GraphRunner that supports binding GPU image data as * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for * effective multiple inheritance. Example usage: - * `const WasmMediaPipeImageLib = SupportImage(WasmMediaPipeLib);` + * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage(Base: TBase) { diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index e85d63b06..bc9c93e8a 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -1,12 +1,12 @@ -import {WasmMediaPipeLib} from './wasm_mediapipe_lib'; +import {GraphRunner} from './graph_runner'; /** - * We extend from a WasmMediaPipeLib constructor. This ensures our mixin has + * We extend from a GraphRunner constructor. This ensures our mixin has * access to the wasmModule, among other things. The `any` type is required for * mixin constructors. */ // tslint:disable-next-line:no-any -type LibConstructor = new (...args: any[]) => WasmMediaPipeLib; +type LibConstructor = new (...args: any[]) => GraphRunner; /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler @@ -17,11 +17,11 @@ export declare interface WasmModuleRegisterModelResources { } /** - * An implementation of WasmMediaPipeLib that supports registering model + * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance. 
Sample usage: * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * WasmMediaPipeLib);` + * GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService( From efcdedbd59a135d757a49b0ff27b656e793386ad Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Thu, 17 Nov 2022 18:14:58 -0800 Subject: [PATCH 059/137] Remove redundant _ios targets PiperOrigin-RevId: 489355333 --- mediapipe/gpu/BUILD | 14 -------------- mediapipe/objc/BUILD | 4 ++-- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 4fb59f1b5..27d91f21a 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -810,20 +810,6 @@ cc_library( }), ) -# TODO: remove -objc_library( - name = "gl_calculator_helper_ios", - copts = [ - "-Wno-shorten-64-to-32", - ], - visibility = ["//visibility:public"], - deps = [ - ":gl_calculator_helper", - "//mediapipe/objc:mediapipe_framework_ios", - "//mediapipe/objc:util", - ], -) - objc_library( name = "MPPMetalHelper", srcs = ["MPPMetalHelper.mm"], diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 48c9b181a..d77692164 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -147,7 +147,7 @@ objc_library( visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ ":mediapipe_framework_ios", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_simple_shaders", ], @@ -173,7 +173,7 @@ objc_library( deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", - "//mediapipe/gpu:gl_calculator_helper_ios", + "//mediapipe/gpu:gl_calculator_helper", ], ) From ae44012c0c5a53916f9ee01b3c745868836c784b Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Fri, 18 Nov 2022 08:39:37 -0800 Subject: [PATCH 060/137] Allowing BypassCalculator to accept InputSidePackets. 
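BypassCalculator's GetContract previously declared only its input and output
streams, so a graph that connected input side packets to a BypassCalculator
node failed contract validation. GetContract now marks every input side packet
as AnyType, so such connections are accepted.

A sketch of a node this enables (the stream and side packet names are
hypothetical, chosen only for illustration):

  node {
    calculator: "BypassCalculator"
    input_stream: "VIDEO:frames"
    output_stream: "VIDEO:frames_out"
    input_side_packet: "session_options"
  }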
PiperOrigin-RevId: 489483992
---
 mediapipe/calculators/core/bypass_calculator.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mediapipe/calculators/core/bypass_calculator.cc b/mediapipe/calculators/core/bypass_calculator.cc
index efc0612ec..4e007329b 100644
--- a/mediapipe/calculators/core/bypass_calculator.cc
+++ b/mediapipe/calculators/core/bypass_calculator.cc
@@ -111,6 +111,10 @@ class BypassCalculator : public Node {
         cc->Outputs().Get(id).SetAny();
       }
     }
+    for (auto id = cc->InputSidePackets().BeginId();
+         id != cc->InputSidePackets().EndId(); ++id) {
+      cc->InputSidePackets().Get(id).SetAny();
+    }
     return absl::OkStatus();
   }

From e046982a3c6706625c997df50e51e19157624ac7 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Fri, 18 Nov 2022 08:44:02 -0800
Subject: [PATCH 061/137] Internal change

PiperOrigin-RevId: 489484898
---
 .../tensor/audio_to_tensor_calculator.cc    | 49 ++++++++++++++++---
 .../tensor/audio_to_tensor_calculator.proto | 13 +++++
 2 files changed, 54 insertions(+), 8 deletions(-)

diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc
index d0513518a..9cb23a393 100644
--- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc
+++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.cc
@@ -43,6 +43,7 @@ namespace api2 {
 namespace {
 using Options = ::mediapipe::AudioToTensorCalculatorOptions;
+using DftTensorFormat = Options::DftTensorFormat;
 using FlushMode = Options::FlushMode;

 std::vector<float> HannWindow(int window_size, bool sqrt_hann) {
@@ -188,6 +189,8 @@ class AudioToTensorCalculator : public Node {
   int padding_samples_before_;
   int padding_samples_after_;
   FlushMode flush_mode_;
+  DftTensorFormat dft_tensor_format_;
+
   Timestamp initial_timestamp_ = Timestamp::Unstarted();
   int64 cumulative_input_samples_ = 0;
   Timestamp next_output_timestamp_ = Timestamp::Unstarted();
@@ -273,6 +276,7 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
   }
   padding_samples_before_ = options.padding_samples_before();
   padding_samples_after_ = options.padding_samples_after();
+  dft_tensor_format_ = options.dft_tensor_format();
   flush_mode_ = options.flush_mode();

   RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
@@ -492,14 +496,43 @@ absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block,
       kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]),
                                 timestamp);
     }
-    Matrix fft_output_matrix =
-        Eigen::Map<const Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2);
-    fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_);
-    // The last two elements are the DFT Nyquist values.
-    fft_output_matrix(fft_size_ - 2) = fft_output_[1];  // Nyquist real part
-    fft_output_matrix(fft_size_ - 1) = 0.0f;            // Nyquist imagery part
-    ASSIGN_OR_RETURN(output_tensor,
-                     ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2}));
+    switch (dft_tensor_format_) {
+      case Options::WITH_NYQUIST: {
+        Matrix fft_output_matrix =
+            Eigen::Map<const Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2);
+        fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_);
+        // The last two elements are the Nyquist component.
+        fft_output_matrix(fft_size_ - 2) = fft_output_[1];  // Nyquist real part
+        fft_output_matrix(fft_size_ - 1) = 0.0f;  // Nyquist imaginary part
+        ASSIGN_OR_RETURN(output_tensor, ConvertToTensor(fft_output_matrix,
+                                                        {2, fft_size_ / 2}));
+        break;
+      }
+      case Options::WITH_DC_AND_NYQUIST: {
+        Matrix fft_output_matrix =
+            Eigen::Map<const Matrix>(fft_output_.data(), 1, fft_size_);
+        fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_ + 2);
+        fft_output_matrix(1) = 0.0f;  // DC imaginary part.
+        // The last two elements are the Nyquist component.
+        fft_output_matrix(fft_size_) = fft_output_[1];  // Nyquist real part
+        fft_output_matrix(fft_size_ + 1) = 0.0f;  // Nyquist imaginary part
+        ASSIGN_OR_RETURN(
+            output_tensor,
+            ConvertToTensor(fft_output_matrix, {2, (fft_size_ + 2) / 2}));
+        break;
+      }
+      case Options::WITHOUT_DC_AND_NYQUIST: {
+        Matrix fft_output_matrix =
+            Eigen::Map<const Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2);
+        ASSIGN_OR_RETURN(
+            output_tensor,
+            ConvertToTensor(fft_output_matrix, {2, (fft_size_ - 2) / 2}));
+        break;
+      }
+      default:
+        return absl::InvalidArgumentError("Unsupported dft tensor format.");
+    }
+
   } else {
     ASSIGN_OR_RETURN(output_tensor,
                      ConvertToTensor(block, {num_channels_, num_samples_}));

diff --git a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto
index cff6b2878..aa3c1229c 100644
--- a/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto
+++ b/mediapipe/calculators/tensor/audio_to_tensor_calculator.proto
@@ -68,4 +68,17 @@ message AudioToTensorCalculatorOptions {
   }
   optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX];
+
+  enum DftTensorFormat {
+    DFT_TENSOR_FORMAT_UNKNOWN = 0;
+    // The output DFT tensor without the DC and Nyquist components.
+    WITHOUT_DC_AND_NYQUIST = 1;
+    // The output DFT tensor contains the Nyquist component as the last
+    // two values.
+    WITH_NYQUIST = 2;
+    // The output DFT tensor contains the DC component as the first two values
+    // and the Nyquist component as the last two values.
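+    // (For reference, the packed layout implemented in the calculator above
+    // is [dc_real, 0, real_1, imag_1, ..., nyquist_real, 0], emitted as a
+    // {2, (fft_size + 2) / 2} tensor; the DC and Nyquist imaginary parts
+    // are always zero.)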
+ WITH_DC_AND_NYQUIST = 3; + } + optional DftTensorFormat dft_tensor_format = 11 [default = WITH_NYQUIST]; } From 2f361e2f4791fa774db5cb20dbc888f89c234447 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 08:51:30 -0800 Subject: [PATCH 062/137] Internal change PiperOrigin-RevId: 489486417 --- mediapipe/util/tracking/BUILD | 3 +-- mediapipe/util/tracking/motion_analysis.cc | 2 +- .../util/tracking/region_flow_computation.cc | 16 ++++++---------- .../tracking/region_flow_computation_test.cc | 2 +- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 319e99d5b..3f1ebb353 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -458,7 +458,6 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", ], ) @@ -739,7 +738,7 @@ cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", - "//mediapipe/framework/port:opencv_highgui", + "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", diff --git a/mediapipe/util/tracking/motion_analysis.cc b/mediapipe/util/tracking/motion_analysis.cc index 0b7678889..5b6a970cf 100644 --- a/mediapipe/util/tracking/motion_analysis.cc +++ b/mediapipe/util/tracking/motion_analysis.cc @@ -791,7 +791,7 @@ void MotionAnalysis::VisualizeBlurAnalysisRegions(cv::Mat* input_view) { region_flow_computation_->ComputeBlurMask(*input_view, &corner_values, &mask); cv::Mat mask_3c; - cv::cvtColor(mask, mask_3c, CV_GRAY2RGB); + cv::cvtColor(mask, mask_3c, cv::COLOR_GRAY2RGB); cv::addWeighted(*input_view, 0.5, mask_3c, 0.5, -128, *input_view); } diff --git a/mediapipe/util/tracking/region_flow_computation.cc b/mediapipe/util/tracking/region_flow_computation.cc index cfd5c23c2..708c868b5 100644 --- a/mediapipe/util/tracking/region_flow_computation.cc +++ b/mediapipe/util/tracking/region_flow_computation.cc @@ -30,6 +30,7 @@ #include "absl/container/node_hash_set.h" #include "absl/memory/memory.h" #include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_features2d_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_video_inc.h" @@ -935,12 +936,13 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, // Area based method best for downsampling. // For color images to temporary buffer. cv::Mat& resized = source.channels() == 1 ? dest_frame : *curr_color_image_; - cv::resize(source, resized, resized.size(), 0, 0, CV_INTER_AREA); + cv::resize(source, resized, resized.size(), 0, 0, cv::INTER_AREA); source_ptr = &resized; // Resize feature extraction mask if needed. 
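  // Nearest-neighbor resampling is used for the mask below: mask values are
  // categorical labels, and averaging interpolators (area, linear) would
  // blend distinct labels into meaningless intermediate values.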
if (!source_mask.empty()) { dest_mask.create(resized.rows, resized.cols, CV_8UC1); - cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, CV_INTER_NN); + cv::resize(source_mask, dest_mask, dest_mask.size(), 0, 0, + cv::INTER_NEAREST); } } else if (!source_mask.empty()) { source_mask.copyTo(dest_mask); @@ -954,7 +956,7 @@ bool RegionFlowComputation::InitFrame(const cv::Mat& source, const int dimension = visual_options.tiny_image_dimension(); data->tiny_image.create(dimension, dimension, type); cv::resize(*source_ptr, data->tiny_image, data->tiny_image.size(), 0, 0, - CV_INTER_AREA); + cv::INTER_AREA); } if (source_ptr->channels() == 1 && @@ -2286,7 +2288,7 @@ void RegionFlowComputation::ExtractFeatures( // Initialize mask from frame's feature extraction mask, by downsampling and // negating the latter mask. if (!data->mask.empty()) { - cv::resize(data->mask, mask, mask.size(), 0, 0, CV_INTER_NN); + cv::resize(data->mask, mask, mask.size(), 0, 0, cv::INTER_NEAREST); for (int y = 0; y < mask.rows; ++y) { uint8* mask_ptr = mask.ptr(y); for (int x = 0; x < mask.cols; ++x) { @@ -2590,12 +2592,6 @@ void RegionFlowComputation::TrackFeatures(FrameTrackingData* from_data_ptr, cv::_InputArray input_frame2(data2.pyramid); #endif - // Using old c-interface for OpenCV's 2.2 tracker. - CvTermCriteria criteria; - criteria.type = CV_TERMCRIT_EPS | CV_TERMCRIT_ITER; - criteria.max_iter = options_.tracking_options().tracking_iterations(); - criteria.epsilon = 0.02f; - feature_track_error_.resize(num_features); feature_status_.resize(num_features); if (use_cv_tracking_) { diff --git a/mediapipe/util/tracking/region_flow_computation_test.cc b/mediapipe/util/tracking/region_flow_computation_test.cc index 0ac6dc2a5..435a8e200 100644 --- a/mediapipe/util/tracking/region_flow_computation_test.cc +++ b/mediapipe/util/tracking/region_flow_computation_test.cc @@ -28,7 +28,7 @@ #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" -#include "mediapipe/framework/port/opencv_highgui_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" From 03d388fecffe3734d8f6878f6f0def404065076b Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 09:49:23 -0800 Subject: [PATCH 063/137] Add hand landmark named index constants PiperOrigin-RevId: 489498248 --- .../tasks/cc/components/containers/BUILD | 5 ++ .../tasks/cc/components/containers/landmark.h | 48 +++++++++++++ .../tasks/components/containers/BUILD | 12 ++++ .../components/containers/HandLandmark.java | 72 +++++++++++++++++++ .../python/components/containers/landmark.py | 26 +++++++ .../web/components/containers/landmark.d.ts | 25 +++++++ 6 files changed, 188 insertions(+) create mode 100644 mediapipe/tasks/cc/components/containers/landmark.h create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index bd66a0f28..2f5f8be5b 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -49,3 +49,8 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "landmark", + hdrs = ["landmark.h"], +) diff --git 
a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/components/containers/landmark.h new file mode 100644 index 000000000..6fdd294ae --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/landmark.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ + +namespace mediapipe::tasks::components::containers { + +// The 21 hand landmarks. +enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +}; + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..869157295 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -74,6 +74,18 @@ android_library( ], ) +android_library( + name = "handlandmark", + srcs = ["HandLandmark.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + android_library( name = "landmark", srcs = ["Landmark.java"], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java new file mode 100644 index 000000000..da7c4e0ca --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java @@ -0,0 +1,72 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
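+
+// A hypothetical usage sketch (the landmark list and its accessors are
+// illustrative, not part of this change):
+//
+//   @HandLandmark.HandLandmarkType int tip = HandLandmark.INDEX_FINGER_TIP;
+//   Landmark tipLandmark = handLandmarks.get(tip);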
+ +package com.google.mediapipe.tasks.components.containers; + +import androidx.annotation.IntDef; + +/** The 21 hand landmarks. */ +public final class HandLandmark { + public static final int NUM_LANDMARKS = 21; + + public static final int WRIST = 0; + public static final int THUMB_CMC = 1; + public static final int THUMB_MCP = 2; + public static final int THUMB_IP = 3; + public static final int THUMB_TIP = 4; + public static final int INDEX_FINGER_MCP = 5; + public static final int INDEX_FINGER_PIP = 6; + public static final int INDEX_FINGER_DIP = 7; + public static final int INDEX_FINGER_TIP = 8; + public static final int MIDDLE_FINGER_MCP = 9; + public static final int MIDDLE_FINGER_PIP = 10; + public static final int MIDDLE_FINGER_DIP = 11; + public static final int MIDDLE_FINGER_TIP = 12; + public static final int RING_FINGER_MCP = 13; + public static final int RING_FINGER_PIP = 14; + public static final int RING_FINGER_DIP = 15; + public static final int RING_FINGER_TIP = 16; + public static final int PINKY_MCP = 17; + public static final int PINKY_PIP = 18; + public static final int PINKY_DIP = 19; + public static final int PINKY_TIP = 20; + + /** Represents a hand landmark type. */ + @IntDef({ + WRIST, + THUMB_CMC, + THUMB_MCP, + THUMB_IP, + THUMB_TIP, + INDEX_FINGER_MCP, + INDEX_FINGER_PIP, + INDEX_FINGER_DIP, + INDEX_FINGER_TIP, + MIDDLE_FINGER_MCP, + MIDDLE_FINGER_PIP, + MIDDLE_FINGER_DIP, + MIDDLE_FINGER_TIP, + RING_FINGER_MCP, + RING_FINGER_PIP, + RING_FINGER_DIP, + RING_FINGER_TIP, + PINKY_MCP, + PINKY_PIP, + PINKY_DIP, + PINKY_TIP, + }) + public @interface HandLandmarkType {} + + private HandLandmark() {} +} diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index dee2a16ad..81b2943dc 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -14,6 +14,7 @@ """Landmark data class.""" import dataclasses +import enum from typing import Optional from mediapipe.framework.formats import landmark_pb2 @@ -120,3 +121,28 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) + + +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index c887303d0..352717a2f 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -33,3 +33,28 @@ export declare interface Landmark { /** Whether this landmark is normalized with respect to the image size. */ normalized: boolean; } + +/** The 21 hand landmarks. 
*/ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} From ac212c15070854b407812148739f6e1b72089a75 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Fri, 18 Nov 2022 10:06:47 -0800 Subject: [PATCH 064/137] Internal change PiperOrigin-RevId: 489502255 --- mediapipe/calculators/audio/BUILD | 1 - mediapipe/calculators/core/BUILD | 6 ++---- mediapipe/calculators/image/BUILD | 10 +++++----- mediapipe/calculators/tensor/BUILD | 6 +++--- mediapipe/calculators/tensorflow/BUILD | 14 ++++++++------ mediapipe/calculators/tflite/BUILD | 6 +++--- mediapipe/calculators/util/BUILD | 9 ++++----- mediapipe/calculators/video/BUILD | 4 ++-- mediapipe/framework/BUILD | 4 ---- mediapipe/framework/formats/BUILD | 8 +++++--- mediapipe/framework/formats/motion/BUILD | 4 ++-- mediapipe/framework/profiler/BUILD | 4 ++++ mediapipe/framework/stream_handler/BUILD | 4 ++-- mediapipe/framework/tool/BUILD | 7 ++----- mediapipe/gpu/BUILD | 1 - 15 files changed, 42 insertions(+), 46 deletions(-) diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index ba461e4a7..555f7543f 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -197,7 +197,6 @@ cc_library( ":spectrogram_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index ecd878115..39837fadb 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -341,7 +341,6 @@ cc_test( srcs = ["concatenate_proto_list_calculator_test.cc"], deps = [ ":concatenate_proto_list_calculator", - ":concatenate_vector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -403,7 +402,6 @@ cc_test( srcs = ["clip_vector_size_calculator_test.cc"], deps = [ ":clip_vector_size_calculator", - "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", @@ -956,10 +954,10 @@ cc_library( deps = [ ":split_vector_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 89e2d371c..c78bc5cf7 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -159,8 +159,8 @@ cc_library( visibility = 
["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", @@ -186,8 +186,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -290,10 +290,10 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:ret_check", @@ -361,12 +361,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", + "//mediapipe/util:color_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - "//mediapipe/util:color_cc_proto", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", ] + select({ @@ -630,8 +630,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 3f1278397..4c06df0ff 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -433,6 +433,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":inference_calculator_cc_proto", ":inference_calculator_options_lib", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -794,12 +795,12 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:tensor", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", ] + selects.with_or({ ":compute_shader_unavailable": [], @@ -1279,7 +1280,6 @@ cc_library( "//mediapipe/gpu:MPPMetalHelper", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1378,9 +1378,9 
@@ cc_library( "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:port", + "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/util:resource_util", "@org_tensorflow//tensorflow/lite:framework", - "//mediapipe/gpu:gpu_origin_cc_proto", "//mediapipe/framework/port:statusor", ] + selects.with_or({ "//mediapipe/gpu:disable_gpu": [], diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index d0dfc12ab..45f64f4f7 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -346,8 +346,8 @@ cc_library( srcs = ["matrix_to_tensor_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":matrix_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -414,7 +414,7 @@ cc_library( "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:detection_cc_proto", # build_cleaner: keep + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:opencv_imgcodecs", @@ -451,8 +451,8 @@ cc_library( srcs = ["tensorflow_inference_calculator.cc"], visibility = ["//visibility:public"], deps = [ - ":tensorflow_session", ":tensorflow_inference_calculator_cc_proto", + ":tensorflow_session", "@com_google_absl//absl/log:check", "//mediapipe/framework:timestamp", "@com_google_absl//absl/base:core_headers", @@ -515,6 +515,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -546,6 +547,7 @@ cc_library( "//mediapipe/framework/deps:clock", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", + "@org_tensorflow//tensorflow/core:protos_all_cc", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -666,8 +668,8 @@ cc_library( srcs = ["tensor_to_matrix_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:time_series_header_cc_proto", ":tensor_to_matrix_calculator_cc_proto", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:status", @@ -704,10 +706,10 @@ cc_library( srcs = ["tensor_to_vector_float_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", "//mediapipe/framework/port:ret_check", - ":tensor_to_vector_float_calculator_options_cc_proto", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:framework", @@ -1083,7 +1085,6 @@ cc_test( linkstatic = 1, deps = [ ":tensor_to_image_frame_calculator", - ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", 
"//mediapipe/framework/formats:image_frame", @@ -1236,6 +1237,7 @@ cc_test( data = [":test_frozen_graph"], linkstatic = 1, deps = [ + ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", ":tensorflow_inference_calculator", ":tensorflow_session_from_frozen_graph_generator", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 2007a4fe1..8edaeee02 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -289,8 +289,8 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_converter_calculator_cc_proto", + "//mediapipe/util/tflite:config", "//mediapipe/util:resource_util", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -410,15 +410,15 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - "//mediapipe/util/tflite:config", ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/util/tflite:config", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/deps:file_path", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:location", - "//mediapipe/framework/formats/object_detection:anchor_cc_proto", "//mediapipe/framework/port:ret_check", "@org_tensorflow//tensorflow/lite:framework", ] + selects.with_or({ diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 3a9ddc36f..24e976a73 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -23,8 +23,8 @@ cc_library( srcs = ["alignment_points_to_rects_calculator.cc"], visibility = ["//visibility:public"], deps = [ + ":detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator", - "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -266,8 +266,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/util:color_cc_proto", "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", @@ -755,7 +755,6 @@ cc_library( deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/port:ret_check", @@ -1313,8 +1312,8 @@ cc_library( srcs = ["to_image_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:image_frame", @@ -1336,8 +1335,8 @@ cc_library( srcs = ["from_image_calculator.cc"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:calculator_options_cc_proto", 
"//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image", diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 53d968151..2db3ed252 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -342,12 +342,12 @@ cc_library( "//mediapipe/framework/port:opencv_features2d", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", + "//mediapipe/util/tracking:box_tracker_cc_proto", + "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util:resource_util", "//mediapipe/util/tracking", "//mediapipe/util/tracking:box_detector", "//mediapipe/util/tracking:box_tracker", - "//mediapipe/util/tracking:box_tracker_cc_proto", - "//mediapipe/util/tracking:flow_packager_cc_proto", "//mediapipe/util/tracking:tracking_visualization_utilities", ] + select({ "//mediapipe:android": [ diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 8ccdac3b9..e3429f1e9 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1039,7 +1039,6 @@ cc_library( ":graph_service_manager", ":port", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -1660,9 +1659,6 @@ cc_test( "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:default_side_packet_calculator", "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:packet_generator_cc_proto", - "//mediapipe/framework:status_handler_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:template_parser", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index c3241d911..e13bb2704 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -133,9 +133,9 @@ cc_library( "//visibility:public", ], deps = [ + ":affine_transform_data_cc_proto", "//mediapipe/framework:port", "//mediapipe/framework:type_map", - "//mediapipe/framework/formats:affine_transform_data_cc_proto", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:point", @@ -209,8 +209,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", - "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats/annotation:locus_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -241,6 +241,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":location", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:opencv_imgproc", ], alwayslink = 1, @@ -251,6 +252,7 @@ cc_test( srcs = ["location_opencv_test.cc"], deps = [ ":location_opencv", + "//mediapipe/framework/formats/annotation:rasterization_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:rectangle", ], @@ -346,8 +348,8 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ - 
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/formats:image_frame", "@com_google_absl//absl/synchronization", "//mediapipe/framework:port", "//mediapipe/framework:type_map", diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 28e0bfc6a..9819d262c 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -16,10 +16,10 @@ # Description: # Working with dense optical flow in mediapipe. -licenses(["notice"]) - load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +licenses(["notice"]) + package(default_visibility = ["//visibility:private"]) proto_library( diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 237aa825f..b53a1ac39 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -334,6 +334,10 @@ cc_library( "graph_profiler_stub.h", ], visibility = ["//mediapipe/framework:__pkg__"], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_profile_cc_proto", + ], ) cc_test( diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 8771a8773..866a5120e 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -13,6 +13,8 @@ # limitations under the License. # +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") + licenses(["notice"]) package( @@ -20,8 +22,6 @@ package( features = ["-layering_check"], ) -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") - proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index e54fb2177..52d04b4b1 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -299,12 +299,12 @@ mediapipe_cc_test( data = [":node_chain_subgraph.proto"], requires_full_emulation = False, deps = [ + ":node_chain_subgraph_cc_proto", ":options_field_util", ":options_registry", ":options_syntax_util", ":options_util", "//mediapipe/calculators/core:flow_limiter_calculator", - "//mediapipe/calculators/core:flow_limiter_calculator_cc_proto", "//mediapipe/framework:basic_types_registration", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", @@ -312,6 +312,7 @@ mediapipe_cc_test( "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", + "//mediapipe/framework/testdata:night_light_calculator_cc_proto", "//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/tool:node_chain_subgraph_options_lib", "//mediapipe/util:header_util", @@ -486,7 +487,6 @@ cc_library( deps = [ ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework/deps:proto_descriptor_cc_proto", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:numbers", "//mediapipe/framework/port:ret_check", @@ -738,9 +738,7 @@ cc_test( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:graph_service_manager", - "//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:packet", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", 
"//mediapipe/framework:packet_type", "//mediapipe/framework:status_handler", @@ -923,7 +921,6 @@ cc_test( "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:stream_handler_cc_proto", "//mediapipe/framework:subgraph", "//mediapipe/framework:test_calculators", "//mediapipe/framework/port:gtest_main", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 27d91f21a..10a8d7fff 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -783,7 +783,6 @@ cc_library( ":image_frame_view", ":shader_util", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework:calculator_cc_proto", "@com_google_absl//absl/base:core_headers", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", From e2052a6a517fe1d8ce487f46a9856a225644d3f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 11:11:22 -0800 Subject: [PATCH 065/137] Rename embedding postprocessor "configure" method for consistency with classification postprocessor. PiperOrigin-RevId: 489518257 --- .../audio/audio_embedder/audio_embedder_graph.cc | 10 ++++++---- .../processors/embedding_postprocessing_graph.cc | 6 +++--- .../processors/embedding_postprocessing_graph.h | 2 +- .../embedding_postprocessing_graph_test.cc | 14 +++++++------- .../cc/text/text_embedder/text_embedder_graph.cc | 10 ++++++---- .../vision/image_embedder/image_embedder_graph.cc | 10 ++++++---- 6 files changed, 29 insertions(+), 23 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc index 7667feaa3..f093b4d25 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc @@ -158,10 +158,12 @@ class AudioEmbedderGraph : public core::ModelTaskGraph { // inference results. auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing( - model_resources, task_options.embedder_options(), - &postprocessing.GetOptions())); + MP_RETURN_IF_ERROR( + components::processors::ConfigureEmbeddingPostprocessingGraph( + model_resources, task_options.embedder_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio embedding on // audio files. Disables timestamp aggregation by not connecting the diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc index 880aec5d7..ad4881e12 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.cc @@ -150,7 +150,7 @@ absl::StatusOr> GetHeadNames( } // namespace -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options) { @@ -193,8 +193,8 @@ absl::Status ConfigureEmbeddingPostprocessing( // timestamp aggregation is required. // // The recommended way of using this graph is through the GraphBuilder API using -// the 'ConfigureEmbeddingPostprocessing()' function. 
See header file for more -// details. +// the 'ConfigureEmbeddingPostprocessingGraph()' function. See header file for +// more details. class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h index 58606ed80..889992463 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h @@ -58,7 +58,7 @@ namespace processors { // The embedding result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -absl::Status ConfigureEmbeddingPostprocessing( +absl::Status ConfigureEmbeddingPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::EmbedderOptions& embedder_options, proto::EmbeddingPostprocessingGraphOptions* options); diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 84d84d648..163e46ee8 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -95,8 +95,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -117,8 +117,8 @@ TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) { options_in.set_quantize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -138,8 +138,8 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { options_in.set_l2_normalize(true); proto::EmbeddingPostprocessingGraphOptions options_out; - MP_ASSERT_OK(ConfigureEmbeddingPostprocessing(*model_resources, options_in, - &options_out)); + MP_ASSERT_OK(ConfigureEmbeddingPostprocessingGraph(*model_resources, + options_in, &options_out)); EXPECT_THAT( options_out, @@ -164,7 +164,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { auto& postprocessing = graph.AddNode( "mediapipe.tasks.components.processors." "EmbeddingPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing( + MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessingGraph( *model_resources, options, &postprocessing .GetOptions())); diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index 79eedb6b5..c54636ee2 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -128,10 +128,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { // inference results. 
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
-        model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions()));
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureEmbeddingPostprocessingGraph(
+            model_resources, task_options.embedder_options(),
+            &postprocessing.GetOptions()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

     // Outputs the embedding result.

diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
index 11e25144c..bf0dcf3c7 100644
--- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc
@@ -151,10 +151,12 @@ class ImageEmbedderGraph : public core::ModelTaskGraph {
     // inference results.
     auto& postprocessing = graph.AddNode(
         "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
-    MP_RETURN_IF_ERROR(components::processors::ConfigureEmbeddingPostprocessing(
-        model_resources, task_options.embedder_options(),
-        &postprocessing.GetOptions()));
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureEmbeddingPostprocessingGraph(
+            model_resources, task_options.embedder_options(),
+            &postprocessing.GetOptions()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

     // Outputs the embedding results.

From 71ae496a2001d1206b792bedd45d4027d7f043c7 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 18 Nov 2022 12:10:47 -0800
Subject: [PATCH 066/137] Add AudioEmbedder documentation

PiperOrigin-RevId: 489532283
---
 .../audio_embedder/audio_embedder_graph.cc | 40 +++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
index f093b4d25..187f11f7f 100644
--- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
+++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_graph.cc
@@ -100,6 +100,46 @@ void ConfigureAudioToTensorCalculator(
   }
 }
 }  // namespace

+// An "AudioEmbedderGraph" performs embedding extraction.
+// - Accepts CPU audio buffer and outputs embedding results on CPU.
+//
+// Inputs:
+//   AUDIO - Matrix
+//     Audio buffer to perform embedding extraction on.
+//   SAMPLE_RATE - double @Optional
+//     The sample rate of the corresponding audio data in the "AUDIO" stream.
+//     If sample rate is not provided, the "AUDIO" stream must carry a time
+//     series stream header with sample rate info.
+//
+// Outputs:
+//   EMBEDDINGS - EmbeddingResult @Optional
+//     The embedding results aggregated by head. Only produces results if
+//     the graph's 'use_stream_mode' option is true.
+//   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
+//     The embedding result aggregated by timestamp, then by head. Only
+//     produces results if the graph's 'use_stream_mode' option is false.
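+//
+// Embeddings produced by this graph are typically compared across audio
+// clips with cosine similarity; see the cosine similarity utility under
+// mediapipe/tasks/cc/components/utils (assumed location) for a ready-made
+// helper.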
+// +// Example: +// node { +// calculator: "mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph" +// input_stream: "AUDIO:audio_in" +// input_stream: "SAMPLE_RATE:sample_rate_in" +// output_stream: "EMBEDDINGS:embeddings_out" +// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out" +// options { +// [mediapipe.tasks.audio.audio_embedder.proto.AudioEmbedderGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// embedder_options { +// l2_normalize: true +// } +// } +// } +// } class AudioEmbedderGraph : public core::ModelTaskGraph { public: absl::StatusOr<CalculatorGraphConfig> GetConfig( From 1b594a0310f9c1bc3ece2562455bba0f812efd3a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Fri, 18 Nov 2022 12:42:58 -0800 Subject: [PATCH 067/137] Return an error status when any TFLite input or output tensor doesn't have the valid dimensionality information needed to allocate a GL/Metal buffer before calling ModifyGraphWithDelegate. PiperOrigin-RevId: 489539740 --- mediapipe/calculators/tensor/BUILD | 2 ++ mediapipe/calculators/tensor/inference_calculator_gl.cc | 8 ++++++++ .../calculators/tensor/inference_calculator_metal.cc | 7 +++++++ 3 files changed, 17 insertions(+) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 4c06df0ff..2a573fc44 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -464,6 +464,7 @@ cc_library( "//mediapipe/gpu:gl_calculator_helper", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", ], alwayslink = 1, @@ -513,6 +514,7 @@ cc_library( "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/util/tflite:config", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index bd8eb3eed..27b8bc23a 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -20,6 +20,7 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.pb.h" #include "mediapipe/framework/calculator_context.h" @@ -154,6 +155,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate( const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Input tensor at index [%d] doesn't specify dimensions.", + input_indices[i]); + gpu_buffers_in_.emplace_back(absl::make_unique<Tensor>( Tensor::ElementType::kFloat32, Tensor::Shape{std::vector<int>{ @@ -171,6 +176,9 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::LoadDelegate( // Create and bind output buffers.
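// Note (added commentary, not part of the original patch): the GL and Metal
// runners size their GPU buffers from tensor->dims up front, before
// ModifyGraphWithDelegate runs, so a tensor with empty dims would otherwise
// become a zero-sized buffer and fail much later with a far less actionable
// error. The RET_CHECKs below surface that condition at load time instead.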
for (int i = 0; i < output_size_; ++i) { const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Output tensor at index [%d] doesn't specify dimensions.", + output_indices[i]); gpu_buffers_out_.emplace_back(absl::make_unique<Tensor>( Tensor::ElementType::kFloat32, Tensor::Shape{std::vector<int>{ diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index a85071f3e..750f0456e 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -22,6 +22,7 @@ #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "mediapipe/calculators/tensor/inference_calculator.h" #import "mediapipe/gpu/MPPMetalHelper.h" #include "mediapipe/gpu/MPPMetalUtil.h" @@ -245,6 +246,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( const auto& input_indices = interpreter_->inputs(); for (int i = 0; i < input_indices.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Input tensor at index [%d] doesn't specify dimensions.", + input_indices[i]); // Create and bind input buffer. std::vector<int> dims{tensor->dims->data, tensor->dims->data + tensor->dims->size}; @@ -266,6 +270,9 @@ absl::Status InferenceCalculatorMetalImpl::CreateConverters( output_shapes_.resize(output_indices.size()); for (int i = 0; i < output_shapes_.size(); ++i) { const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size > 0) << absl::StrFormat( + "Output tensor at index [%d] doesn't specify dimensions.", + output_indices[i]); RET_CHECK(tensor->dims->size <= 4); // Create and bind output buffers. // Channels are always padded to multiple of 4. From 524ac3ca61dc165f23a8d6ce29a9ff36d2fa7e98 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 12:45:56 -0800 Subject: [PATCH 068/137] Internal change for Model Maker PiperOrigin-RevId: 489540387 --- mediapipe/model_maker/python/core/tasks/classifier.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 200726864..f376edffa 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -91,6 +91,10 @@ class Classifier(custom_model.CustomModel): self._history = self._model.fit( x=train_dataset, epochs=self._hparams.epochs, + # `steps_per_epoch` is intentionally set to None in case the dataset + # is not repeated. Otherwise, the training process will stop when the + # dataset is exhausted even if there are epochs remaining. + steps_per_epoch=None, validation_data=validation_dataset, callbacks=self._callbacks) From bbd5da7971aa0d39bbeba638de34ded860bd30b3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 18 Nov 2022 17:10:54 -0800 Subject: [PATCH 069/137] Added grayscale image support for the ImageToTensorCalculator on CPU.
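Only the CPU (OpenCV) converter gains this support here; the utils change below notes that the GPU paths still expect 3 channels. A rough sketch of producing a GRAY8 input in the style of the test helpers that follow (MakeGrayInput is a hypothetical name; error handling is omitted):

#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"

// Converts a BGR cv::Mat into a single-channel GRAY8 ImageFrame, which the
// updated converter maps to a 1-channel output tensor instead of rejecting.
mediapipe::ImageFrame MakeGrayInput(const cv::Mat& bgr) {
  cv::Mat gray;
  cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
  mediapipe::ImageFrame frame(mediapipe::ImageFormat::GRAY8, gray.cols,
                              gray.rows);
  // MatView wraps the frame's pixel buffer, so this copy fills the frame.
  gray.copyTo(mediapipe::formats::MatView(&frame));
  return frame;
}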
PiperOrigin-RevId: 489593917 --- .../tensor/image_to_tensor_calculator_test.cc | 79 ++++++++++++++++--- .../image_to_tensor_converter_opencv.cc | 29 ++++--- .../tensor/image_to_tensor_utils.cc | 7 +- 3 files changed, 93 insertions(+), 22 deletions(-) diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 07a5f9fe1..7ea60d98e 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -54,6 +54,13 @@ cv::Mat GetRgba(absl::string_view path) { return rgb; } +cv::Mat GetGray(absl::string_view path) { + cv::Mat bgr = cv::imread(file::JoinPath("./", path)); + cv::Mat gray; + cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY); + return gray; +} + // Image to tensor test template. // No processing/assertions should be done after the function is invoked. void RunTestWithInputImagePacket(const Packet& input_image_packet, @@ -147,29 +154,34 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, ASSERT_THAT(tensor_vec, testing::SizeIs(1)); const Tensor& tensor = tensor_vec[0]; + const int channels = tensor.shape().dims[3]; + ASSERT_TRUE(channels == 1 || channels == 3); auto view = tensor.GetCpuReadView(); cv::Mat tensor_mat; if (output_int_tensor) { if (range_min < 0) { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_8SC1 : CV_8SC3, const_cast<int8*>(view.buffer<int8>())); } else { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_8UC1 : CV_8UC3, const_cast<uint8*>(view.buffer<uint8>())); } } else { EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); - tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3, + tensor_mat = cv::Mat(tensor_height, tensor_width, + channels == 1 ? CV_32FC1 : CV_32FC3, const_cast<float*>(view.buffer<float>())); } cv::Mat result_rgb; auto transformation = GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value(); - tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale, - transformation.offset); + tensor_mat.convertTo(result_rgb, channels == 1 ? CV_8UC1 : CV_8UC3, + transformation.scale, transformation.offset); cv::Mat diff; cv::absdiff(result_rgb, expected_result, diff); @@ -185,17 +197,27 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet, MP_ASSERT_OK(graph.WaitUntilDone()); } +mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { + if (image_channels == 4) { + return ImageFormat::SRGBA; + } else if (image_channels == 3) { + return ImageFormat::SRGB; + } else if (image_channels == 1) { + return ImageFormat::GRAY8; + } + CHECK(false) << "Unsupported input image channels: " << image_channels; +} + Packet MakeImageFramePacket(cv::Mat input) { - ImageFrame input_image( - input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {}); + ImageFrame input_image(GetImageFormat(input.channels()), input.cols, + input.rows, input.step, input.data, [](uint8*) {}); return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0)); } Packet MakeImagePacket(cv::Mat input) { mediapipe::Image input_image(std::make_shared<ImageFrame>( input.channels() == 4 ?
ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {})); + GetImageFormat(input.channels()), input.cols, input.rows, input.step, + input.data, [](uint8*) {})); return MakePacket(std::move(input_image)).At(Timestamp(0)); } @@ -429,6 +451,24 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { /*border_mode=*/{}, roi); } +TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation.png"), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/{}, roi); +} + TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { mediapipe::NormalizedRect roi; @@ -448,6 +488,25 @@ TEST(ImageToTensorCalculatorTest, /*border_mode=*/BorderMode::kZero, roi); } +TEST(ImageToTensorCalculatorTest, + LargeSubRectKeepAspectWithRotationBorderZeroGray) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.5f); + roi.set_y_center(0.5f); + roi.set_width(1.5f); + roi.set_height(1.1f); + roi.set_rotation(M_PI * -15.0f / 180.0f); + RunTest(GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"), + GetGray("/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + /*border_mode=*/BorderMode::kZero, roi); +} + TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { mediapipe::NormalizedRect roi; roi.set_x_center(0.5f); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index f910b59f3..76e46f99d 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -48,15 +48,19 @@ class OpenCvProcessor : public ImageToTensorConverter { switch (tensor_type_) { case Tensor::ElementType::kInt8: mat_type_ = CV_8SC3; + mat_gray_type_ = CV_8SC1; break; case Tensor::ElementType::kFloat32: mat_type_ = CV_32FC3; + mat_gray_type_ = CV_32FC1; break; case Tensor::ElementType::kUInt8: mat_type_ = CV_8UC3; + mat_gray_type_ = CV_8UC1; break; default: mat_type_ = -1; + mat_gray_type_ = -1; } } @@ -64,11 +68,13 @@ class OpenCvProcessor : public ImageToTensorConverter { float range_min, float range_max, int tensor_buffer_offset, Tensor& output_tensor) override { - if (input.image_format() != mediapipe::ImageFormat::SRGB && - input.image_format() != mediapipe::ImageFormat::SRGBA) { - return InvalidArgumentError( - absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", - static_cast(input.image_format()))); + const bool is_supported_format = + input.image_format() == mediapipe::ImageFormat::SRGB || + input.image_format() == mediapipe::ImageFormat::SRGBA || + input.image_format() == mediapipe::ImageFormat::GRAY8; + if (!is_supported_format) { + return InvalidArgumentError(absl::StrCat( + "Unsupported format: ", 
static_cast<uint32>(input.image_format()))); } // TODO: Remove the check once tensor_buffer_offset > 0 is // supported. @@ -82,17 +88,18 @@ const int output_channels = output_shape.dims[3]; auto buffer_view = output_tensor.GetCpuWriteView(); cv::Mat dst; + const int dst_data_type = output_channels == 1 ? mat_gray_type_ : mat_type_; switch (tensor_type_) { case Tensor::ElementType::kInt8: - dst = cv::Mat(output_height, output_width, mat_type_, + dst = cv::Mat(output_height, output_width, dst_data_type, buffer_view.buffer<int8>()); break; case Tensor::ElementType::kFloat32: - dst = cv::Mat(output_height, output_width, mat_type_, + dst = cv::Mat(output_height, output_width, dst_data_type, buffer_view.buffer<float>()); break; case Tensor::ElementType::kUInt8: - dst = cv::Mat(output_height, output_width, mat_type_, + dst = cv::Mat(output_height, output_width, dst_data_type, buffer_view.buffer<uint8>()); break; default: @@ -137,7 +144,8 @@ class OpenCvProcessor : public ImageToTensorConverter { auto transform, GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, range_min, range_max)); - transformed.convertTo(dst, mat_type_, transform.scale, transform.offset); + transformed.convertTo(dst, dst_data_type, transform.scale, + transform.offset); return absl::OkStatus(); } @@ -148,7 +156,7 @@ class OpenCvProcessor : public ImageToTensorConverter { RET_CHECK_EQ(output_shape.dims[0], 1) << "Handling batch dimension not equal to 1 is not implemented in this " "converter."; - RET_CHECK_EQ(output_shape.dims[3], 3) + RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1) << "Wrong output channel: " << output_shape.dims[3]; return absl::OkStatus(); } @@ -156,6 +164,7 @@ enum cv::BorderTypes border_mode_; Tensor::ElementType tensor_type_; int mat_type_; + int mat_gray_type_; }; } // namespace diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc index 3f4c05d4e..d27c595b5 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -253,8 +253,11 @@ int GetNumOutputChannels(const mediapipe::Image& image) { } #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU - // All of the processors except for Metal expect 3 channels. - return 3; + // The output tensor has 1 channel for a 1-channel input image, and 3 + // channels for an input image with 3 or 4 channels. + // TODO: Add a unittest here to test the behavior on GPU, i.e. + // failure. + return image.channels() == 1 ? 1 : 3; } absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage( From eb8ef1ace0a2b4c84c04a468478d8eb8463daeed Mon Sep 11 00:00:00 2001 From: Camillo Lugaresi Date: Fri, 18 Nov 2022 19:41:05 -0800 Subject: [PATCH 070/137] Use shared_from_this in GlTextureBuffer::GetReadView, GetWriteView This ensures that the callbacks in GlTextureView won't call an expired object, even if user code holds a GlTextureView after releasing the buffer. Note that GlTextureBuffer is not always held by a shared_ptr, but it always is when GpuBuffer calls GetRead/WriteView on it. An alternative solution would have been to have GpuBuffer pass its shared_ptr to the view method, which could have been implemented with some compile-time logic to detect whether the method expects such an argument. However, that doesn't seem necessary.
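Reduced to a self-contained sketch (hypothetical names, not MediaPipe code), the lifetime pattern the commit adopts looks like this: deriving from std::enable_shared_from_this lets a method hand out callbacks that share ownership of the object, so those callbacks can never dangle.

#include <functional>
#include <memory>

class TextureBufferLike
    : public std::enable_shared_from_this<TextureBufferLike> {
 public:
  // The returned callback owns a shared_ptr back to this object, keeping it
  // alive until the callback itself is destroyed, even if every external
  // reference has already been dropped.
  std::function<void()> MakeDetachCallback() {
    return [self = shared_from_this()] { self->OnDetach(); };
  }

 private:
  void OnDetach() { /* e.g. create a consumer sync token */ }
};

As the DCHECKs added below point out, shared_from_this() is only valid when the object is already owned by a std::shared_ptr, which GpuBuffer guarantees at these call sites.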
PiperOrigin-RevId: 489611843 --- mediapipe/gpu/gl_texture_buffer.cc | 23 +++++++++++++++++------ mediapipe/gpu/gl_texture_buffer.h | 3 ++- mediapipe/gpu/gpu_buffer_test.cc | 22 ++++++++++++++++++++++ 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 09703d89d..7f77cd4b3 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -260,13 +260,18 @@ GlTextureView GlTextureBuffer::GetReadView(internal::types<GlTextureView>, auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); - GlTextureView::DetachFn detach = [this](GlTextureView& texture) { - // Inform the GlTextureBuffer that we have finished accessing its - // contents, and create a consumer sync point. - DidRead(texture.gl_context()->CreateSyncToken()); - }; + GlTextureView::DetachFn detach = + [texbuf = shared_from_this()](GlTextureView& texture) { + // Inform the GlTextureBuffer that we have finished accessing its + // contents, and create a consumer sync point. + texbuf->DidRead(texture.gl_context()->CreateSyncToken()); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), plane, std::move(detach), nullptr); } @@ -276,12 +281,18 @@ GlTextureView GlTextureBuffer::GetWriteView(internal::types<GlTextureView>, auto gl_context = GlContext::GetCurrent(); CHECK(gl_context); CHECK_EQ(plane, 0); + // Note that this method is only supposed to be called by GpuBuffer, which + // ensures this condition is satisfied. + DCHECK(!weak_from_this().expired()) + << "GlTextureBuffer must be held in shared_ptr to get a GlTextureView"; // Insert wait call to sync with the producer. WaitOnGpu(); Reuse(); // TODO: the producer wait should probably be part of Reuse in the // case when there are no consumers. GlTextureView::DoneWritingFn done_writing = - [this](const GlTextureView& texture) { ViewDoneWriting(texture); }; + [texbuf = shared_from_this()](const GlTextureView& texture) { + texbuf->ViewDoneWriting(texture); + }; return GlTextureView(gl_context.get(), target(), name(), width(), height(), plane, nullptr, std::move(done_writing)); } diff --git a/mediapipe/gpu/gl_texture_buffer.h b/mediapipe/gpu/gl_texture_buffer.h index c7643fd1b..f785571a1 100644 --- a/mediapipe/gpu/gl_texture_buffer.h +++ b/mediapipe/gpu/gl_texture_buffer.h @@ -35,7 +35,8 @@ class GlCalculatorHelperImpl; // Implements a GPU memory buffer as an OpenGL texture. For internal use. class GlTextureBuffer : public internal::GpuBufferStorageImpl< - GlTextureBuffer, internal::ViewProvider<GlTextureView>> { + GlTextureBuffer, internal::ViewProvider<GlTextureView>>, + public std::enable_shared_from_this<GlTextureBuffer> { public: // This is called when the texture buffer is deleted. It is passed a sync // token created at that time on the GlContext.
If the GlTextureBuffer has diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index 796cb1d9d..145b71806 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -14,6 +14,8 @@ #include "mediapipe/gpu/gpu_buffer.h" +#include + #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -206,5 +208,25 @@ TEST_F(GpuBufferTest, Overwrite) { } } +TEST_F(GpuBufferTest, GlTextureViewRetainsWhatItNeeds) { + GpuBuffer buffer(300, 200, GpuBufferFormat::kBGRA32); + { + std::shared_ptr<ImageFrame> view = buffer.GetWriteView<ImageFrame>(); + EXPECT_EQ(view->Width(), 300); + EXPECT_EQ(view->Height(), 200); + FillImageFrameRGBA(*view, 255, 0, 0, 255); + } + + RunInGlContext([buffer = std::move(buffer)]() mutable { + // This is not a recommended pattern, but let's make sure that we don't + // crash if the buffer is released before the view. The view can hold + // callbacks into its underlying storage. + auto view = buffer.GetReadView<GlTextureView>(0); + buffer = nullptr; + }); + // We're really checking that we haven't crashed. + EXPECT_TRUE(true); +} + } // anonymous namespace } // namespace mediapipe From e853f04b79bb47e9542f54ba34065de3c5dcbd73 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 18 Nov 2022 19:53:21 -0800 Subject: [PATCH 071/137] Create AudioTaskRunner PiperOrigin-RevId: 489613573 --- .../tasks/audio/core/BaseAudioTaskApi.java | 1 + .../tasks/web/audio/audio_classifier/BUILD | 4 +- .../audio_classifier/audio_classifier.ts | 53 ++++++++--------- mediapipe/tasks/web/audio/core/BUILD | 14 ++++- .../web/audio/core/audio_task_options.d.ts | 21 ------- .../tasks/web/audio/core/audio_task_runner.ts | 58 +++++++++++++++++++ 6 files changed, 98 insertions(+), 53 deletions(-) create mode 100644 mediapipe/tasks/web/audio/core/audio_task_runner.ts diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index 8eaf0adcb..2782f8d36 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -116,6 +116,7 @@ public class BaseAudioTaskApi implements AutoCloseable { defaultSampleRate = sampleRate; } } + /** * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener.
diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 9e1fcbc51..498b17845 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -17,14 +17,14 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 5533b0eaa..0c54a4718 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -18,10 +18,10 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {AudioClassifierGraphOptions} from '../../../../tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -47,9 +47,8 @@ const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; // tslint:disable:jspb-use-builder-pattern /** Performs audio classification. */ -export class AudioClassifier extends TaskRunner { +export class AudioClassifier extends AudioTaskRunner { private classificationResults: AudioClassifierResult[] = []; - private defaultSampleRate = 48000; private readonly options = new AudioClassifierGraphOptions(); /** @@ -111,6 +110,14 @@ export class AudioClassifier extends TaskRunner { wasmLoaderOptions, new Uint8Array(graphData)); } + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + /** * Sets new options for the audio classifier. 
* @@ -120,34 +127,19 @@ export class AudioClassifier extends TaskRunner { * * @param options The options for the audio classifier. */ - async setOptions(options: AudioClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: AudioClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } /** - * Sets the sample rate for all calls to `classify()` that omit an explicit - * sample rate. `48000` is used as a default if this method is not called. - * - * @param sampleRate A sample rate (e.g. `44100`). - */ - setDefaultSampleRate(sampleRate: number) { - this.defaultSampleRate = sampleRate; - } - - /** - * Performs audio classification on the provided audio data and waits + * Performs audio classification on the provided audio clip and waits * synchronously for the response. * - * @param audioData An array of raw audio capture data, like - * from a call to getChannelData on an AudioBuffer. + * @param audioData An array of raw audio capture data, like from a call to + * `getChannelData()` on an AudioBuffer. * @param sampleRate The sample rate in Hz of the provided audio data. If not * set, defaults to the sample rate set via `setDefaultSampleRate()` or * `48000` if no custom default was set. @@ -155,18 +147,21 @@ export class AudioClassifier extends TaskRunner { */ classify(audioData: Float32Array, sampleRate?: number): AudioClassifierResult[] { - sampleRate = sampleRate ?? this.defaultSampleRate; + return this.processAudioClip(audioData, sampleRate); + } + /** Sends an audio package to the graph and returns the classifications. */ + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioClassifierResult[] { // Configures the number of samples in the WASM layer. We re-configure the // number of samples and the sample rate for every frame, but ignore other // side effects of this function (such as sending the input side packet and // the input stream header). this.configureAudio( /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); - - const timestamp = performance.now(); - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); - this.addAudioToStream(audioData, timestamp); + this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.addAudioToStream(audioData, timestampMs); this.classificationResults = []; this.finishProcessing(); diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index ed60f2435..91ebbf524 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -1,6 +1,6 @@ # This package contains options shared by all MediaPipe Audio Tasks for Web. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) @@ -11,3 +11,15 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/core", ], ) + +mediapipe_ts_library( + name = "audio_task_runner", + srcs = ["audio_task_runner.ts"], + deps = [ + ":audio_task_options", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner", + ], +) diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts index 58a6e55d8..e3068625d 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts @@ -16,29 +16,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options'; -/** - * MediaPipe audio task running mode. A MediaPipe audio task can be run with - * two different modes: - * - audio_clips: The mode for running a mediapipe audio task on independent - * audio clips. - * - audio_stream: The mode for running a mediapipe audio task on an audio - * stream, such as from a microphone. - * - */ -export type RunningMode = 'audio_clips'|'audio_stream'; - /** The options for configuring a MediaPipe Audio Task. */ export declare interface AudioTaskOptions { /** Options to configure the loading of the model assets. */ baseOptions?: BaseOptions; - - /** - * The running mode of the task. Default to the audio_clips mode. - * Audio tasks have two running modes: - * 1) The mode for running a mediapipe audio task on independent - * audio clips. - * 2) The mode for running a mediapipe audio task on an audio - * stream, such as from a microphone. - */ - runningMode?: RunningMode; } diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts new file mode 100644 index 000000000..ceff3895b --- /dev/null +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -0,0 +1,58 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; + +import {AudioTaskOptions} from './audio_task_options'; + +/** Base class for all MediaPipe Audio Tasks. */ +export abstract class AudioTaskRunner<T> extends TaskRunner { + protected abstract baseOptions?: BaseOptionsProto|undefined; + private defaultSampleRate = 48000; + + /** Configures the shared options of an audio task. */ + async setOptions(options: AudioTaskOptions): Promise<void> { + this.baseOptions = this.baseOptions ??
new BaseOptionsProto(); + if (options.baseOptions) { + this.baseOptions = await convertBaseOptionsToProto( + options.baseOptions, this.baseOptions); + } + } + + /** + * Sets the sample rate for API calls that omit an explicit sample rate. + * `48000` is used as a default if this method is not called. + * + * @param sampleRate A sample rate (e.g. `44100`). + */ + setDefaultSampleRate(sampleRate: number) { + this.defaultSampleRate = sampleRate; + } + + /** Sends an audio packet to the graph and awaits results. */ + protected abstract process( + audioData: Float32Array, sampleRate: number, timestampMs: number): T; + + /** Sends a single audio clip to the graph and awaits results. */ + protected processAudioClip(audioData: Float32Array, sampleRate?: number): T { + return this.process( + audioData, sampleRate ?? this.defaultSampleRate, performance.now()); + } +} + + From bbcbd5fc6c8fcefaf45da9c126a6f7aa8b6386c2 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Sat, 19 Nov 2022 04:47:55 -0800 Subject: [PATCH 072/137] Audio Embedder for Web PiperOrigin-RevId: 489669966 --- mediapipe/tasks/web/BUILD | 1 + mediapipe/tasks/web/audio.ts | 4 +- mediapipe/tasks/web/audio/BUILD | 1 + .../tasks/web/audio/audio_embedder/BUILD | 43 ++++ .../audio/audio_embedder/audio_embedder.ts | 211 ++++++++++++++++++ .../audio_embedder_options.d.ts | 22 ++ .../audio_embedder/audio_embedder_result.d.ts | 17 ++ mediapipe/tasks/web/audio/index.ts | 1 + 8 files changed, 299 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/web/audio/audio_embedder/BUILD create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts create mode 100644 mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index e9703e37a..af76a1fe8 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -26,6 +26,7 @@ mediapipe_ts_library( srcs = ["audio.ts"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 764fd8393..056426f50 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -15,9 +15,11 @@ */ import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier'; +import {AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/audio_embedder/audio_embedder'; // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports. const AudioClassifier = AudioClassifierImpl; +const AudioEmbedder = AudioEmbedderImpl; -export {AudioClassifier}; +export {AudioClassifier, AudioEmbedder}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index 4f6e48b28..acd7494d7 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -9,5 +9,6 @@ mediapipe_ts_library( srcs = ["index.ts"], deps = [ "//mediapipe/tasks/web/audio/audio_classifier", + "//mediapipe/tasks/web/audio/audio_embedder", ], ) diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD new file mode 100644 index 000000000..7d9a994a3 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -0,0 +1,43 @@ +# This contains the MediaPipe Audio Embedder Task. +# +# This task takes audio input and performs embedding extraction.
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "audio_embedder", + srcs = ["audio_embedder.ts"], + deps = [ + ":audio_embedder_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_jspb_proto", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/audio/core:audio_task_runner", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/components/processors:embedder_options", + "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "audio_embedder_types", + srcs = [ + "audio_embedder_options.d.ts", + "audio_embedder_result.d.ts", + ], + deps = [ + "//mediapipe/tasks/web/audio/core:audio_task_options", + "//mediapipe/tasks/web/components/containers:embedding_result", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:embedder_options", + ], +) diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts new file mode 100644 index 000000000..51cb819de --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -0,0 +1,211 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../../../../tasks/cc/audio/audio_embedder/proto/audio_embedder_graph_options_pb'; +import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; +import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; +import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; +import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {AudioEmbedderOptions} from './audio_embedder_options'; +import {AudioEmbedderResult} from './audio_embedder_result'; + +export * from './audio_embedder_options'; +export * from './audio_embedder_result'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' cannot +// be changed +// TODO: Change this to `audio_in` to match the name in the CC +// implementation +const AUDIO_STREAM = 'input_audio'; +const SAMPLE_RATE_STREAM = 'sample_rate'; +const EMBEDDINGS_STREAM = 'embeddings_out'; +const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out'; +const AUDIO_EMBEDDER_CALCULATOR = + 'mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph'; + +/** Performs embedding extraction on audio. */ +export class AudioEmbedder extends AudioTaskRunner { + private embeddingResults: AudioEmbedderResult[] = []; + private readonly options = new AudioEmbedderGraphOptionsProto(); + + /** + * Initializes the Wasm runtime and creates a new audio embedder from the + * provided options. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param audioEmbedderOptions The options for the audio embedder. Note that + * either a path to the TFLite model or the model itself needs to be + * provided (via `baseOptions`). + */ + static async createFromOptions( + wasmLoaderOptions: WasmLoaderOptions, + audioEmbedderOptions: AudioEmbedderOptions): Promise { + // Create a file locator based on the loader options + const fileLocator: FileLocator = { + locateFile() { + // The only file we load is the Wasm binary + return wasmLoaderOptions.wasmBinaryPath.toString(); + } + }; + + const embedder = await createMediaPipeLib( + AudioEmbedder, wasmLoaderOptions.wasmLoaderPath, + /* assetLoaderScript= */ undefined, + /* glCanvas= */ undefined, fileLocator); + await embedder.setOptions(audioEmbedderOptions); + return embedder; + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * provided model asset buffer. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the TFLite model. 
+ */ + static createFromModelBuffer( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetBuffer: Uint8Array): Promise<AudioEmbedder> { + return AudioEmbedder.createFromOptions( + wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new audio embedder based on the + * path to the model asset. + * @param wasmLoaderOptions A configuration object that provides the location + * of the Wasm binary and its loader. + * @param modelAssetPath The path to the TFLite model. + */ + static async createFromModelPath( + wasmLoaderOptions: WasmLoaderOptions, + modelAssetPath: string): Promise<AudioEmbedder> { + const response = await fetch(modelAssetPath.toString()); + const graphData = await response.arrayBuffer(); + return AudioEmbedder.createFromModelBuffer( + wasmLoaderOptions, new Uint8Array(graphData)); + } + + protected override get baseOptions(): BaseOptionsProto|undefined { + return this.options.getBaseOptions(); + } + + protected override set baseOptions(proto: BaseOptionsProto|undefined) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the audio embedder. + * + * Calling `setOptions()` with a subset of options only affects those options. + * You can reset an option back to its default value by explicitly setting it + * to `undefined`. + * + * @param options The options for the audio embedder. + */ + override async setOptions(options: AudioEmbedderOptions): Promise<void> { + await super.setOptions(options); + this.options.setEmbedderOptions(convertEmbedderOptionsToProto( + options, this.options.getEmbedderOptions())); + this.refreshGraph(); + } + + /** + * Performs embedding extraction on the provided audio clip and waits + * synchronously for the response. + * + * @param audioData An array of raw audio capture data, like from a call to + * `getChannelData()` on an AudioBuffer. + * @param sampleRate The sample rate in Hz of the provided audio data. If not + * set, defaults to the sample rate set via `setDefaultSampleRate()` or + * `48000` if no custom default was set. + * @return The embedding results of the audio + */ + embed(audioData: Float32Array, sampleRate?: number): AudioEmbedderResult[] { + return this.processAudioClip(audioData, sampleRate); + } + + protected override process( + audioData: Float32Array, sampleRate: number, + timestampMs: number): AudioEmbedderResult[] { + // Configures the number of samples in the WASM layer. We re-configure the + // number of samples and the sample rate for every frame, but ignore other + // side effects of this function (such as sending the input side packet and + // the input stream header). + this.configureAudio( + /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate); + this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.addAudioToStream(audioData, timestampMs); + + this.embeddingResults = []; + this.finishProcessing(); + return this.embeddingResults; + } + + /** Updates the MediaPipe graph configuration.
*/ + private refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(AUDIO_STREAM); + graphConfig.addInputStream(SAMPLE_RATE_STREAM); + graphConfig.addOutputStream(EMBEDDINGS_STREAM); + graphConfig.addOutputStream(TIMESTAMPED_EMBEDDINGS_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + AudioEmbedderGraphOptionsProto.ext, this.options); + + const embedderNode = new CalculatorGraphConfig.Node(); + embedderNode.setCalculator(AUDIO_EMBEDDER_CALCULATOR); + embedderNode.addInputStream('AUDIO:' + AUDIO_STREAM); + embedderNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); + embedderNode.addOutputStream('EMBEDDINGS:' + EMBEDDINGS_STREAM); + embedderNode.addOutputStream( + 'TIMESTAMPED_EMBEDDINGS:' + TIMESTAMPED_EMBEDDINGS_STREAM); + embedderNode.setOptions(calculatorOptions); + + graphConfig.addNode(embedderNode); + + this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + }); + + this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + + diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts new file mode 100644 index 000000000..98f412d0f --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; +import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; + +/** Options to configure the MediaPipe Audio Embedder Task */ +export declare interface AudioEmbedderOptions extends EmbedderOptions, + AudioTaskOptions {} diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts new file mode 100644 index 000000000..13abc28d9 --- /dev/null +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_result.d.ts @@ -0,0 +1,17 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export {Embedding, EmbeddingResult as AudioEmbedderResult} from '../../../../tasks/web/components/containers/embedding_result'; diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index a5083b326..17a908f30 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -15,3 +15,4 @@ */ export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; +export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; From 977ee4272e90272fef0ab140036816e83e05c615 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sat, 19 Nov 2022 10:51:20 -0800 Subject: [PATCH 073/137] Add public visibility to the model maker public API. PiperOrigin-RevId: 489701768 --- mediapipe/model_maker/python/text/text_classifier/BUILD | 7 +++++++ .../model_maker/python/vision/gesture_recognizer/BUILD | 7 +++++++ mediapipe/model_maker/python/vision/image_classifier/BUILD | 7 +++++++ 3 files changed, 21 insertions(+) diff --git a/mediapipe/model_maker/python/text/text_classifier/BUILD b/mediapipe/model_maker/python/text/text_classifier/BUILD index 0c35e7966..7bb41351e 100644 --- a/mediapipe/model_maker/python/text/text_classifier/BUILD +++ b/mediapipe/model_maker/python/text/text_classifier/BUILD @@ -21,9 +21,16 @@ package( licenses(["notice"]) +###################################################################### +# Public target of the MediaPipe Model Maker TextClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/text/text_classifier/customize for +# more information about the MediaPipe Model Maker TextClassifier APIs. +###################################################################### py_library( name = "text_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":model_options", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index b7d334d9c..b9425a181 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -103,9 +103,16 @@ py_library( ], ) +###################################################################### +# Public target of the MediaPipe Model Maker GestureRecognizer APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer/customize +# for more information about the MediaPipe Model Maker GestureRecognizer APIs.
+###################################################################### py_library( name = "gesture_recognizer_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":gesture_recognizer", diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index c581d9fbc..29ae189e9 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -21,9 +21,16 @@ package( default_visibility = ["//mediapipe:__subpackages__"], ) +###################################################################### +# Public target of the MediaPipe Model Maker ImageClassifier APIs. + +# Please see https://developers.google.com/mediapipe/solutions/vision/image_classifier/customize for +# more information about the MediaPipe Model Maker ImageClassifier APIs. +###################################################################### py_library( name = "image_classifier_import", srcs = ["__init__.py"], + visibility = ["//visibility:public"], deps = [ ":dataset", ":hyperparameters", From a33cb1e05e602cb06b6e6ecdc3a12dad82f5f4e4 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Sat, 19 Nov 2022 21:03:29 -0800 Subject: [PATCH 074/137] Check that Java buffer supports direct access before using it If the buffer is not created with allocateDirect, JNI APIs will return a data pointer of nullptr and a capacity of -1. This can cause a crash when we access it. Also clean up the code to raise exceptions instead of just logging errors and returning nullptr. PiperOrigin-RevId: 489751312 --- .../framework/jni/packet_creator_jni.cc | 171 +++++++++++------- .../framework/jni/packet_getter_jni.cc | 42 +++-- 2 files changed, 133 insertions(+), 80 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index 250d7c938..2d5447401 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -17,6 +17,8 @@ #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/camera_intrinsics.h" #include "mediapipe/framework/formats/image.h" @@ -107,17 +109,18 @@ absl::StatusOr CreateGpuBuffer( // Create a 1, 3, or 4 channel 8-bit ImageFrame shared pointer from a Java // ByteBuffer. 
-std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer( - JNIEnv* env, jobject byte_buffer, jint width, jint height, - mediapipe::ImageFormat::Format format) { +absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> +CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width, + jint height, + mediapipe::ImageFormat::Format format) { switch (format) { case mediapipe::ImageFormat::SRGBA: case mediapipe::ImageFormat::SRGB: case mediapipe::ImageFormat::GRAY8: break; default: - LOG(ERROR) << "Format must be either SRGBA, SRGB, or GRAY8."; - return nullptr; + return absl::InvalidArgumentError( + "Format must be either SRGBA, SRGB, or GRAY8."); } auto image_frame = std::make_unique<mediapipe::ImageFrame>( format, width, height, mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); + } + const int num_channels = image_frame->NumberOfChannels(); const int expected_buffer_size = num_channels == 1 ? width * height : image_frame->PixelDataSize(); - if (buffer_size != expected_buffer_size) { - if (num_channels != 1) - LOG(ERROR) << "The input image buffer should have 4 bytes alignment."; - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; - return nullptr; - } + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << (num_channels != 1 + ? "The input image buffer should have 4 bytes alignment. " + : "") + << "Please check the input buffer size." + << " Buffer size: " << buffer_size + << ", Buffer size needed: " << expected_buffer_size + << ", Image width: " << width; // Copy buffer data to image frame's pixel_data_. if (num_channels == 1) { const int width_step = image_frame->WidthStep(); - const char* src_row = - reinterpret_cast<const char*>(env->GetDirectBufferAddress(byte_buffer)); + const char* src_row = reinterpret_cast<const char*>(buffer_data); char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData()); for (int i = height; i > 0; --i) { std::memcpy(dst_row, src_row, width); @@ -152,7 +160,6 @@ std::unique_ptr<mediapipe::ImageFrame> CreateImageFrameFromByteBuffer( } } else { // 3 and 4 channels.
- const void* buffer_data = env->GetDirectBufferAddress(byte_buffer); std::memcpy(image_frame->MutablePixelData(), buffer_data, image_frame->PixelDataSize()); } @@ -176,77 +183,100 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateReferencePacket)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } +absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> CreateRgbImageFromRgba( + JNIEnv* env, jobject byte_buffer, jint width, jint height) { + const uint8_t* rgba_data = + static_cast<const uint8_t*>(env->GetDirectBufferAddress(byte_buffer)); + int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + if (rgba_data == nullptr || buffer_size < 0) { + return absl::InvalidArgumentError( + "Cannot get direct access to the input buffer. It should be created " + "using allocateDirect."); + } + + const int expected_buffer_size = width * height * 4; + RET_CHECK_EQ(buffer_size, expected_buffer_size) + << "Please check the input buffer size." + << " Buffer size: " << buffer_size + << ", Buffer size needed: " << expected_buffer_size + << ", Image width: " << width; + + auto image_frame = absl::make_unique<mediapipe::ImageFrame>( + mediapipe::ImageFormat::SRGB, width, height, + mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); + mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, + image_frame->MutablePixelData(), + image_frame->WidthStep()); + return image_frame; +} + JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImageFromRgba)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - const uint8_t* rgba_data = - static_cast<const uint8_t*>(env->GetDirectBufferAddress(byte_buffer)); - auto image_frame = absl::make_unique<mediapipe::ImageFrame>( - mediapipe::ImageFormat::SRGB, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (buffer_size != width * height * 4) { - LOG(ERROR) << "Please check the input buffer size."; - LOG(ERROR) << "Buffer size: " << buffer_size - << ", Buffer size needed: " << width * height * 4 - << ", Image width: " << width; - return 0L; - } - mediapipe::android::RgbaToRgb(rgba_data, width * 4, width, height, - image_frame->MutablePixelData(), - image_frame->WidthStep()); - mediapipe::Packet packet = mediapipe::Adopt(image_frame.release()); + auto image_frame_or = CreateRgbImageFromRgba(env, byte_buffer, width, height); + if (ThrowIfError(env, image_frame_or.status())) return 0L; + + mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); + return CreatePacketWithContext(context, packet); } JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame = CreateImageFrameFromByteBuffer( + auto image_frame_or = CreateImageFrameFromByteBuffer( env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); - if (nullptr == image_frame) return 0L; -
 
 JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)(
     JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
     jint height) {
-  const void* data = env->GetDirectBufferAddress(byte_buffer);
-  auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
-      mediapipe::ImageFormat::VEC32F1, width, height,
-      mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
-  int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
-  if (buffer_size != image_frame->PixelDataSize()) {
-    LOG(ERROR) << "Please check the input buffer size.";
-    LOG(ERROR) << "Buffer size: " << buffer_size
-               << ", Buffer size needed: " << image_frame->PixelDataSize()
-               << ", Image width: " << width;
-    return 0L;
-  }
-  std::memcpy(image_frame->MutablePixelData(), data,
-              image_frame->PixelDataSize());
-  mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
+  // TODO: merge this case with CreateImageFrameFromByteBuffer.
+  auto image_frame_or =
+      [&]() -> absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>> {
+    const void* data = env->GetDirectBufferAddress(byte_buffer);
+    int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
+    if (data == nullptr || buffer_size < 0) {
+      return absl::InvalidArgumentError(
+          "input buffer does not support direct access");
+    }
+
+    auto image_frame = absl::make_unique<mediapipe::ImageFrame>(
+        mediapipe::ImageFormat::VEC32F1, width, height,
+        mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
+    RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize())
+        << "Please check the input buffer size."
+        << " Buffer size: " << buffer_size
+        << ", Buffer size needed: " << image_frame->PixelDataSize()
+        << ", Image width: " << width;
+    std::memcpy(image_frame->MutablePixelData(), data,
+                image_frame->PixelDataSize());
+    return image_frame;
+  }();
+  if (ThrowIfError(env, image_frame_or.status())) return 0L;
+  mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
   return CreatePacketWithContext(context, packet);
 }
 
 JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)(
     JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
     jint height) {
-  auto image_frame = CreateImageFrameFromByteBuffer(
+  auto image_frame_or = CreateImageFrameFromByteBuffer(
       env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA);
-  if (nullptr == image_frame) return 0L;
+  if (ThrowIfError(env, image_frame_or.status())) return 0L;
 
-  mediapipe::Packet packet = mediapipe::Adopt(image_frame.release());
+  mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
   return CreatePacketWithContext(context, packet);
 }
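The direct-buffer checks repeated throughout these functions lean on standard JNI semantics: GetDirectBufferAddress returns nullptr and GetDirectBufferCapacity returns -1 when the ByteBuffer was not created with ByteBuffer.allocateDirect. A hypothetical helper (not part of this patch) capturing the same validation:

#include <cstdint>
#include <jni.h>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"

// Hypothetical helper: returns a read-only view of a direct ByteBuffer,
// or InvalidArgumentError for heap-backed (non-direct) buffers.
absl::StatusOr<absl::Span<const uint8_t>> GetDirectBufferView(
    JNIEnv* env, jobject byte_buffer) {
  const void* data = env->GetDirectBufferAddress(byte_buffer);
  const jlong size = env->GetDirectBufferCapacity(byte_buffer);
  if (data == nullptr || size < 0) {
    return absl::InvalidArgumentError(
        "Cannot get direct access to the input buffer. It should be created "
        "using allocateDirect.");
  }
  return absl::MakeConstSpan(static_cast<const uint8_t*>(data),
                             static_cast<size_t>(size));
}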
It " + "should be created using allocateDirect.")); + return 0L; + } mediapipe::Packet packet = createAudioPacket(audio_sample, num_samples, num_channels); return CreatePacketWithContext(context, packet); @@ -360,8 +396,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEnv* env, jobject thiz, jlong context, jint rows, jint cols, jfloatArray data) { if (env->GetArrayLength(data) != rows * cols) { - LOG(ERROR) << "Please check the matrix data size, has to be rows * cols = " - << rows * cols; + ThrowIfError( + env, absl::InvalidArgumentError(absl::StrCat( + "Please check the matrix data size, has to be rows * cols = ", + rows * cols))); return 0L; } std::unique_ptr matrix(new mediapipe::Matrix(rows, cols)); @@ -392,16 +430,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( format = mediapipe::ImageFormat::GRAY8; break; default: - LOG(ERROR) << "Channels must be either 1, 3, or 4."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Channels must be either 1, 3, or 4, but are ", + num_channels))); return 0L; } - auto image_frame = + auto image_frame_or = CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); - if (nullptr == image_frame) return 0L; + if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = - mediapipe::MakePacket(std::move(image_frame)); + mediapipe::MakePacket(*std::move(image_frame_or)); return CreatePacketWithContext(context, packet); } @@ -502,7 +542,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)( jbyte* data_ref = env->GetByteArrayElements(data, nullptr); auto options = absl::make_unique(); if (!options->ParseFromArray(data_ref, count)) { - LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed."; + ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat( + "Parsing binary-encoded CalculatorOptions failed."))); return 0L; } mediapipe::Packet packet = mediapipe::Adopt(options.release()); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index c215dd929..737f6db72 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -14,6 +14,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h" +#include "absl/strings/str_cat.h" #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" @@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)( : GetFromNativeHandle(packet); int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); + void* buffer_data = env->GetDirectBufferAddress(byte_buffer); + if (buffer_data == nullptr || buffer_size < 0) { + ThrowIfError(env, absl::InvalidArgumentError( + "input buffer does not support direct access")); + return false; + } // Assume byte buffer stores pixel data contiguously. 
@@ -392,16 +430,18 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)(
       format = mediapipe::ImageFormat::GRAY8;
       break;
     default:
-      LOG(ERROR) << "Channels must be either 1, 3, or 4.";
+      ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
+                            "Channels must be either 1, 3, or 4, but are ",
+                            num_channels)));
       return 0L;
   }
 
-  auto image_frame =
+  auto image_frame_or =
       CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format);
-  if (nullptr == image_frame) return 0L;
+  if (ThrowIfError(env, image_frame_or.status())) return 0L;
 
   mediapipe::Packet packet =
-      mediapipe::MakePacket<mediapipe::Image>(std::move(image_frame));
+      mediapipe::MakePacket<mediapipe::Image>(*std::move(image_frame_or));
   return CreatePacketWithContext(context, packet);
 }
 
@@ -502,7 +542,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCalculatorOptions)(
   jbyte* data_ref = env->GetByteArrayElements(data, nullptr);
   auto options = absl::make_unique<mediapipe::CalculatorOptions>();
   if (!options->ParseFromArray(data_ref, count)) {
-    LOG(ERROR) << "Parsing binary-encoded CalculatorOptions failed.";
+    ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
+                          "Parsing binary-encoded CalculatorOptions failed.")));
     return 0L;
   }
   mediapipe::Packet packet = mediapipe::Adopt(options.release());
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
index c215dd929..737f6db72 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc
@@ -14,6 +14,7 @@
 
 #include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h"
 
+#include "absl/strings/str_cat.h"
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/image_frame.h"
@@ -299,34 +300,38 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
           : GetFromNativeHandle<mediapipe::ImageFrame>(packet);
 
   int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
+  void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
+  if (buffer_data == nullptr || buffer_size < 0) {
+    ThrowIfError(env, absl::InvalidArgumentError(
+                          "input buffer does not support direct access"));
+    return false;
+  }
 
   // Assume byte buffer stores pixel data contiguously.
   const int expected_buffer_size = image.Width() * image.Height() *
                                    image.ByteDepth() * image.NumberOfChannels();
   if (buffer_size != expected_buffer_size) {
-    LOG(ERROR) << "Expected buffer size " << expected_buffer_size
-               << " got: " << buffer_size << ", width " << image.Width()
-               << ", height " << image.Height() << ", channels "
-               << image.NumberOfChannels();
+    ThrowIfError(
+        env, absl::InvalidArgumentError(absl::StrCat(
+                 "Expected buffer size ", expected_buffer_size,
+                 " got: ", buffer_size, ", width ", image.Width(), ", height ",
+                 image.Height(), ", channels ", image.NumberOfChannels())));
     return false;
   }
 
   switch (image.ByteDepth()) {
     case 1: {
-      uint8* data =
-          static_cast<uint8*>(env->GetDirectBufferAddress(byte_buffer));
+      uint8* data = static_cast<uint8*>(buffer_data);
       image.CopyToBuffer(data, expected_buffer_size);
       break;
    }
    case 2: {
-      uint16* data =
-          static_cast<uint16*>(env->GetDirectBufferAddress(byte_buffer));
+      uint16* data = static_cast<uint16*>(buffer_data);
       image.CopyToBuffer(data, expected_buffer_size);
       break;
    }
    case 4: {
-      float* data =
-          static_cast<float*>(env->GetDirectBufferAddress(byte_buffer));
+      float* data = static_cast<float*>(buffer_data);
       image.CopyToBuffer(data, expected_buffer_size);
       break;
    }
@@ -351,12 +356,19 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
   uint8_t* rgba_data =
       static_cast<uint8_t*>(env->GetDirectBufferAddress(byte_buffer));
   int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
+  if (rgba_data == nullptr || buffer_size < 0) {
+    ThrowIfError(env, absl::InvalidArgumentError(
+                          "input buffer does not support direct access"));
+    return false;
+  }
   if (buffer_size != image.Width() * image.Height() * 4) {
-    LOG(ERROR) << "Buffer size has to be width*height*4\n"
-               << "Image width: " << image.Width()
-               << ", Image height: " << image.Height()
-               << ", Buffer size: " << buffer_size << ", Buffer size needed: "
-               << image.Width() * image.Height() * 4;
+    ThrowIfError(env,
+                 absl::InvalidArgumentError(absl::StrCat(
+                     "Buffer size has to be width*height*4\n"
+                     "Image width: ",
+                     image.Width(), ", Image height: ", image.Height(),
+                     ", Buffer size: ", buffer_size, ", Buffer size needed: ",
+                     image.Width() * image.Height() * 4)));
     return false;
   }
   mediapipe::android::RgbToRgba(image.PixelData(), image.WidthStep(),

From bdf4078e89cb11e01da0c5eda6322a22ad74e127 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Sat, 19 Nov 2022 21:12:23 -0800
Subject: [PATCH 075/137] Internal change

PiperOrigin-RevId: 489752009
---
 mediapipe/model_maker/python/core/utils/BUILD |  1 +
 .../python/core/utils/model_util_test.py     | 15 ++++++++++++---
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD
index 12fef631f..492bba0a9 100644
--- a/mediapipe/model_maker/python/core/utils/BUILD
+++ b/mediapipe/model_maker/python/core/utils/BUILD
@@ -45,6 +45,7 @@ py_test(
     name = "model_util_test",
     srcs = ["model_util_test.py"],
     deps = [
+        ":file_util",
         ":model_util",
         ":quantization",
         ":test_util",
diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py
index 05c6ffe3f..f0020db25 100644
--- a/mediapipe/model_maker/python/core/utils/model_util_test.py
+++ b/mediapipe/model_maker/python/core/utils/model_util_test.py
@@ -14,10 +14,12 @@
 
 import os
 from typing import Optional
+from unittest import mock as unittest_mock
 
 from absl.testing import parameterized
 import tensorflow as tf
 
+from mediapipe.model_maker.python.core.utils import file_util
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import test_util @@ -25,11 +27,15 @@ from mediapipe.model_maker.python.core.utils import test_util class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): - def test_load_keras_model(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_keras_model(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') model.save(saved_model_path) + # model_util.load_keras_model takes in a relative path to files within the + # model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = saved_model_path loaded_model = model_util.load_keras_model(saved_model_path) input_tensors = test_util.create_random_sample(size=[1, input_dim]) @@ -37,13 +43,16 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): loaded_model_output = loaded_model.predict_on_batch(input_tensors) self.assertTrue((model_output == loaded_model_output).all()) - def test_load_tflite_model_buffer(self): + @unittest_mock.patch.object(file_util, 'get_absolute_path', autospec=True) + def test_load_tflite_model_buffer(self, mock_get_absolute_path): input_dim = 4 model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_model = model_util.convert_to_tflite(model) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.save_tflite(tflite_model=tflite_model, tflite_file=tflite_file) - + # model_util.load_tflite_model_buffer takes in a relative path to files + # within the model_maker dir, so we patch the function for testing + mock_get_absolute_path.return_value = tflite_file tflite_model_buffer = model_util.load_tflite_model_buffer(tflite_file) test_util.test_tflite( keras_model=model, From a367753eda595f01a60e4ccb12845f2675cb37c5 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Sun, 20 Nov 2022 10:39:59 -0800 Subject: [PATCH 076/137] Internal change PiperOrigin-RevId: 489824381 --- .../vision/gesture_recognizer/gesture_recognizer_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 39272cbbc..9cee88362 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -14,7 +14,6 @@ import io import os -import random import tempfile from unittest import mock as unittest_mock import zipfile @@ -27,6 +26,7 @@ from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.tasks.python.test import test_utils _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +tf.keras.backend.experimental.enable_tf_random_generator() class GestureRecognizerTest(tf.test.TestCase): @@ -42,7 +42,7 @@ class GestureRecognizerTest(tf.test.TestCase): def setUp(self): super().setUp() - random.seed(1234) + tf.keras.utils.set_random_seed(87654321) all_data = self._load_data() # Splits data, 90% data for training, 10% for validation self._train_data, self._validation_data = all_data.split(0.9) From 6cf464636b00fb5039bf705319ffe09408d207b3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: 
Sun, 20 Nov 2022 14:24:21 -0800 Subject: [PATCH 077/137] Internal change PiperOrigin-RevId: 489842199 --- mediapipe/tasks/BUILD | 7 ++ .../tasks/cc/audio/audio_classifier/BUILD | 53 ++++++----- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 55 ++++++------ mediapipe/tasks/cc/audio/core/BUILD | 1 + .../tasks/cc/components/containers/BUILD | 2 +- .../tasks/cc/components/processors/BUILD | 2 + mediapipe/tasks/cc/core/BUILD | 4 +- mediapipe/tasks/cc/text/text_classifier/BUILD | 51 ++++++----- mediapipe/tasks/cc/text/text_embedder/BUILD | 3 + mediapipe/tasks/cc/vision/core/BUILD | 2 + .../tasks/cc/vision/gesture_recognizer/BUILD | 90 ++++++++++--------- .../tasks/cc/vision/hand_landmarker/BUILD | 72 ++++++++------- .../tasks/cc/vision/image_classifier/BUILD | 49 +++++----- .../tasks/cc/vision/image_embedder/BUILD | 49 +++++----- .../tasks/cc/vision/image_segmenter/BUILD | 6 +- .../tasks/cc/vision/object_detector/BUILD | 65 +++++++------- 16 files changed, 278 insertions(+), 233 deletions(-) diff --git a/mediapipe/tasks/BUILD b/mediapipe/tasks/BUILD index 242a88cfc..98ddd5777 100644 --- a/mediapipe/tasks/BUILD +++ b/mediapipe/tasks/BUILD @@ -21,3 +21,10 @@ package_group( "//mediapipe/tasks/...", ], ) + +package_group( + name = "users", + includes = [ + ":internal", + ], +) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 1955adfe7..a817bcc3b 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -16,6 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Classifier +# https://developers.google.com/mediapipe/solutions/audio/audio_classifier +cc_library( + name = "audio_classifier", + srcs = ["audio_classifier.cc"], + hdrs = ["audio_classifier.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_classifier_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_classifier_graph", srcs = ["audio_classifier_graph.cc"], @@ -52,28 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_classifier", - srcs = ["audio_classifier.cc"], - hdrs = ["audio_classifier.h"], - deps = [ - ":audio_classifier_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - 
"//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index b982ef39a..adba28e6a 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -16,6 +16,36 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Audio Embedder +# https://developers.google.com/mediapipe/solutions/audio/audio_embedder +cc_library( + name = "audio_embedder", + srcs = ["audio_embedder.cc"], + hdrs = ["audio_embedder.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":audio_embedder_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:matrix", + "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", + "//mediapipe/tasks/cc/audio/core:base_audio_task_api", + "//mediapipe/tasks/cc/audio/core:running_mode", + "//mediapipe/tasks/cc/components/containers:embedding_result", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedder_options", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:cosine_similarity", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "audio_embedder_graph", srcs = ["audio_embedder_graph.cc"], @@ -51,29 +81,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "audio_embedder", - srcs = ["audio_embedder.cc"], - hdrs = ["audio_embedder.h"], - deps = [ - ":audio_embedder_graph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:matrix", - "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_cc_proto", - "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", - "//mediapipe/tasks/cc/audio/core:base_audio_task_api", - "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components/containers:embedding_result", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedder_options", - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:cosine_similarity", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: mediapipe/tasks/cc/audio/utils:test_utils does not compile in the OSS build 
diff --git a/mediapipe/tasks/cc/audio/core/BUILD b/mediapipe/tasks/cc/audio/core/BUILD index 93362fd3d..016faa10f 100644 --- a/mediapipe/tasks/cc/audio/core/BUILD +++ b/mediapipe/tasks/cc/audio/core/BUILD @@ -19,6 +19,7 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 2f5f8be5b..dec977fb8 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//mediapipe/tasks:internal"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 7845a3dae..32a628db7 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -20,6 +20,7 @@ cc_library( name = "classifier_options", srcs = ["classifier_options.cc"], hdrs = ["classifier_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], ) @@ -67,6 +68,7 @@ cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], hdrs = ["embedder_options.h"], + visibility = ["//visibility:public"], deps = ["//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto"], ) diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index f14457073..202f3ea3c 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -22,9 +22,7 @@ cc_library( name = "base_options", srcs = ["base_options.cc"], hdrs = ["base_options.h"], - visibility = [ - "//mediapipe/tasks:internal", - ], + visibility = ["//visibility:public"], deps = [ ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 52b0c0e4b..01adc9fc3 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -16,6 +16,33 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Classifier +# https://developers.google.com/mediapipe/solutions/text/text_classifier +cc_library( + name = "text_classifier", + srcs = ["text_classifier.cc"], + hdrs = ["text_classifier.h"], + visibility = ["//visibility:public"], + deps = [ + ":text_classifier_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers:category", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + 
"@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "text_classifier_graph", srcs = ["text_classifier_graph.cc"], @@ -41,30 +68,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "text_classifier", - srcs = ["text_classifier.cc"], - hdrs = ["text_classifier.h"], - deps = [ - ":text_classifier_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/tasks/cc/components/containers:category", - "//mediapipe/tasks/cc/components/containers:classification_result", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", - "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - cc_test( name = "text_classifier_test", srcs = ["text_classifier_test.cc"], diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index e2e16c9c1..27c9cb730 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -16,10 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Text Embedder +# https://developers.google.com/mediapipe/solutions/text/text_embedder cc_library( name = "text_embedder", srcs = ["text_embedder.cc"], hdrs = ["text_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":text_embedder_graph", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index e8e197a1d..1f5ab5faf 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -19,11 +19,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "running_mode", hdrs = ["running_mode.h"], + visibility = ["//visibility:public"], ) cc_library( name = "image_processing_options", hdrs = ["image_processing_options.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/tasks/cc/components/containers:rect", ], diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 75289b1e8..7b144e7aa 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -18,6 +18,52 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Gesture Recognizer +# https://developers.google.com/mediapipe/solutions/vision/gesture_recognizer +cc_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.cc"], + hdrs = ["gesture_recognizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":gesture_recognizer_graph", + ":gesture_recognizer_result", + ":hand_gesture_recognizer_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + 
"//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + cc_library( name = "handedness_util", srcs = ["handedness_util.cc"], @@ -127,51 +173,9 @@ cc_library( cc_library( name = "gesture_recognizer_result", hdrs = ["gesture_recognizer_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) - -cc_library( - name = "gesture_recognizer", - srcs = ["gesture_recognizer.cc"], - hdrs = ["gesture_recognizer.h"], - deps = [ - ":gesture_recognizer_graph", - ":gesture_recognizer_result", - ":hand_gesture_recognizer_graph", - "//mediapipe/framework:packet", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", - 
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 5c5073fc2..3b869eab4 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -18,6 +18,43 @@ package(default_visibility = [ licenses(["notice"]) +# Docs for Mediapipe Tasks Hand Landmarker +# https://developers.google.com/mediapipe/solutions/vision/hand_landmarker +cc_library( + name = "hand_landmarker", + srcs = ["hand_landmarker.cc"], + hdrs = ["hand_landmarker.h"], + visibility = ["//visibility:public"], + deps = [ + ":hand_landmarker_graph", + ":hand_landmarker_result", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], @@ -113,44 +150,11 @@ cc_library( cc_library( name = "hand_landmarker_result", hdrs = ["hand_landmarker_result.h"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", ], ) -cc_library( - name = "hand_landmarker", - srcs = ["hand_landmarker.cc"], - hdrs = ["hand_landmarker.h"], - deps = [ - ":hand_landmarker_graph", - ":hand_landmarker_result", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - 
"//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/processors:classifier_options", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:task_runner", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], -) - # TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index b59d8d682..2b93aa262 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_classifier_graph", - srcs = ["image_classifier_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", - "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Classifier +# https://developers.google.com/mediapipe/solutions/vision/image_classifier cc_library( name = "image_classifier", srcs = ["image_classifier.cc"], hdrs = ["image_classifier.h"], + visibility = ["//visibility:public"], deps = [ ":image_classifier_graph", "//mediapipe/framework:packet", @@ -69,4 +49,27 @@ cc_library( ], ) +cc_library( + name = "image_classifier_graph", + srcs = ["image_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + 
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index ea7f40261..8fdb97ccd 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -16,33 +16,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -cc_library( - name = "image_embedder_graph", - srcs = ["image_embedder_graph.cc"], - deps = [ - "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", - "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", - "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) - +# Docs for Mediapipe Tasks Image Embedder +# https://developers.google.com/mediapipe/solutions/vision/image_embedder cc_library( name = "image_embedder", srcs = ["image_embedder.cc"], hdrs = ["image_embedder.h"], + visibility = ["//visibility:public"], deps = [ ":image_embedder_graph", "//mediapipe/framework/api2:builder", @@ -67,4 +47,27 @@ cc_library( ], ) +cc_library( + name = "image_embedder_graph", + srcs = ["image_embedder_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", + "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", + "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", + "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 7206a45ea..595eef568 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -16,13 +16,13 @@ package(default_visibility = 
["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Image Segmenter +# https://developers.google.com/mediapipe/solutions/vision/image_segmenter cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], - visibility = [ - "//mediapipe/tasks:internal", - ], + visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 8220d8b7f..b8002fa96 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -16,6 +16,41 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +# Docs for Mediapipe Tasks Object Detector +# https://developers.google.com/mediapipe/solutions/vision/object_detector +cc_library( + name = "object_detector", + srcs = ["object_detector.cc"], + hdrs = ["object_detector.h"], + visibility = [ + "//mediapipe/tasks:users", + ], + deps = [ + ":object_detector_graph", + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "object_detector_graph", srcs = ["object_detector_graph.cc"], @@ -56,34 +91,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "object_detector", - srcs = ["object_detector.cc"], - hdrs = ["object_detector.h"], - deps = [ - ":object_detector_graph", - "//mediapipe/calculators/core:concatenate_vector_calculator", - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", - "//mediapipe/tasks/cc/core:base_options", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/core:base_vision_task_api", - "//mediapipe/tasks/cc/vision/core:image_processing_options", - "//mediapipe/tasks/cc/vision/core:running_mode", - "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - 
"@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", - ], -) - # TODO: This test fails in OSS From 3ac7f6a216c12d617edd6549ace59f4f76e085c7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Sun, 20 Nov 2022 19:30:05 -0800 Subject: [PATCH 078/137] Simplify image creation in PacketCreator Use more existing functions, remove redundant code, remove direct use of RuntimeException. PiperOrigin-RevId: 489868983 --- .../mediapipe/framework/PacketCreator.java | 53 +++++---- .../framework/jni/packet_creator_jni.cc | 104 +++++------------- .../framework/jni/packet_creator_jni.h | 2 +- 3 files changed, 64 insertions(+), 95 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java index d93eea7b5..04265cab5 100644 --- a/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/PacketCreator.java @@ -55,7 +55,11 @@ public class PacketCreator { public Packet createRgbImage(ByteBuffer buffer, int width, int height) { int widthStep = (((width * 3) + 3) / 4) * 4; if (widthStep * height != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + widthStep * height); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + widthStep * height + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImage(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -123,7 +127,11 @@ public class PacketCreator { */ public Packet createRgbImageFromRgba(ByteBuffer buffer, int width, int height) { if (width * height * 4 != buffer.capacity()) { - throw new RuntimeException("The size of the buffer should be: " + width * height * 4); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbImageFromRgba(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -136,7 +144,7 @@ public class PacketCreator { */ public Packet createGrayscaleImage(ByteBuffer buffer, int width, int height) { if (width * height != buffer.capacity()) { - throw new RuntimeException( + throw new IllegalArgumentException( "The size of the buffer should be: " + width * height + " but is " + buffer.capacity()); } return Packet.create( @@ -150,7 +158,11 @@ public class PacketCreator { */ public Packet createRgbaImageFrame(ByteBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateRgbaImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -163,7 +175,11 @@ public class PacketCreator { */ public Packet createFloatImageFrame(FloatBuffer buffer, int width, int height) { if (buffer.capacity() != width * height * 4) { - throw new RuntimeException("buffer doesn't have the correct size."); + throw new IllegalArgumentException( + "The size of the buffer should be: " + + width * height * 4 + + " but is " + + buffer.capacity()); } return Packet.create( nativeCreateFloatImageFrame(mediapipeGraph.getNativeHandle(), buffer, width, height)); @@ -354,25 +370,24 @@ public class PacketCreator { *

    * <p>For 3 and 4 channel images, the pixel rows should have 4-byte alignment.
    */
   public Packet createImage(ByteBuffer buffer, int width, int height, int numChannels) {
+    int widthStep;
     if (numChannels == 4) {
-      if (buffer.capacity() != width * height * 4) {
-        throw new RuntimeException("buffer doesn't have the correct size.");
-      }
+      widthStep = width * 4;
     } else if (numChannels == 3) {
-      int widthStep = (((width * 3) + 3) / 4) * 4;
-      if (widthStep * height != buffer.capacity()) {
-        throw new RuntimeException("The size of the buffer should be: " + widthStep * height);
-      }
+      widthStep = (((width * 3) + 3) / 4) * 4;
     } else if (numChannels == 1) {
-      if (width * height != buffer.capacity()) {
-        throw new RuntimeException(
-            "The size of the buffer should be: " + width * height + " but is " + buffer.capacity());
-      }
+      widthStep = width;
     } else {
-      throw new RuntimeException("Channels should be: 1, 3, or 4, but is " + numChannels);
+      throw new IllegalArgumentException("Channels should be: 1, 3, or 4, but is " + numChannels);
+    }
+    int expectedSize = widthStep * height;
+    if (buffer.capacity() != expectedSize) {
+      throw new IllegalArgumentException(
+          "The size of the buffer should be: " + expectedSize + " but is " + buffer.capacity());
     }
     return Packet.create(
-        nativeCreateCpuImage(mediapipeGraph.getNativeHandle(), buffer, width, height, numChannels));
+        nativeCreateCpuImage(
+            mediapipeGraph.getNativeHandle(), buffer, width, height, widthStep, numChannels));
   }
 
   /** Helper callback adaptor to create the Java {@link GlSyncToken}. This is called by JNI code. */
@@ -430,7 +445,7 @@
       long context, int name, int width, int height, TextureReleaseCallback releaseCallback);
 
   private native long nativeCreateCpuImage(
-      long context, ByteBuffer buffer, int width, int height, int numChannels);
+      long context, ByteBuffer buffer, int width, int height, int rowBytes, int numChannels);
 
   private native long nativeCreateInt32Array(long context, int[] data);
 
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc
index 2d5447401..46ea1ce41 100644
--- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc
+++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc
@@ -111,22 +111,8 @@ absl::StatusOr<mediapipe::GpuBuffer> CreateGpuBuffer(
 // ByteBuffer.
 absl::StatusOr<std::unique_ptr<mediapipe::ImageFrame>>
 CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
-                               jint height,
+                               jint height, jint width_step,
                                mediapipe::ImageFormat::Format format) {
-  switch (format) {
-    case mediapipe::ImageFormat::SRGBA:
-    case mediapipe::ImageFormat::SRGB:
-    case mediapipe::ImageFormat::GRAY8:
-      break;
-    default:
-      return absl::InvalidArgumentError(
-          "Format must be either SRGBA, SRGB, or GRAY8.");
-  }
-
-  auto image_frame = std::make_unique<mediapipe::ImageFrame>(
-      format, width, height,
-      mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
-
   const int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
   const void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
   if (buffer_data == nullptr || buffer_size < 0) {
@@ -135,34 +121,19 @@ CreateImageFrameFromByteBuffer(JNIEnv* env, jobject byte_buffer, jint width,
         "using allocateDirect.");
   }
 
-  const int num_channels = image_frame->NumberOfChannels();
-  const int expected_buffer_size =
-      num_channels == 1 ? width * height : image_frame->PixelDataSize();
-
+  const int expected_buffer_size = height * width_step;
   RET_CHECK_EQ(buffer_size, expected_buffer_size)
-      << (num_channels != 1
-              ? "The input image buffer should have 4 bytes alignment. "
-              : "")
-      << "Please check the input buffer size."
-      << " Buffer size: " << buffer_size
-      << ", Buffer size needed: " << expected_buffer_size
-      << ", Image width: " << width;
+      << "Input buffer size should be " << expected_buffer_size
+      << " but is: " << buffer_size;
 
-  // Copy buffer data to image frame's pixel_data_.
-  if (num_channels == 1) {
-    const int width_step = image_frame->WidthStep();
-    const char* src_row = reinterpret_cast<const char*>(buffer_data);
-    char* dst_row = reinterpret_cast<char*>(image_frame->MutablePixelData());
-    for (int i = height; i > 0; --i) {
-      std::memcpy(dst_row, src_row, width);
-      src_row += width;
-      dst_row += width_step;
-    }
-  } else {
-    // 3 and 4 channels.
-    std::memcpy(image_frame->MutablePixelData(), buffer_data,
-                image_frame->PixelDataSize());
-  }
+  auto image_frame = std::make_unique<mediapipe::ImageFrame>();
+  // TODO: we could retain the buffer with a special deleter and use
+  // the data directly without a copy. May need a new Java API since existing
+  // code might expect to be able to overwrite the buffer after creating an
+  // ImageFrame from it.
+  image_frame->CopyPixelData(
+      format, width, height, width_step, static_cast<const uint8*>(buffer_data),
+      mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
   return image_frame;
 }
 
@@ -183,8 +154,12 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbImage)(
     JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width,
     jint height) {
-  auto image_frame_or = CreateImageFrameFromByteBuffer(
-      env, byte_buffer, width, height, mediapipe::ImageFormat::SRGB);
+  // We require 4-byte alignment. See Java method.
+  constexpr int kAlignment = 4;
+  int width_step = ((width * 3 - 1) | (kAlignment - 1)) + 1;
+  auto image_frame_or =
+      CreateImageFrameFromByteBuffer(env, byte_buffer, width, height,
+                                     width_step, mediapipe::ImageFormat::SRGB);
   if (ThrowIfError(env, image_frame_or.status())) return 0L;
 
   mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release());
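The width_step expression just above is a power-of-two round-up: ((width * 3 - 1) | (kAlignment - 1)) + 1 rounds the 3-byte-per-pixel row size up to the next multiple of kAlignment, and is equivalent to the Java side's (((width * 3) + 3) / 4) * 4. A quick standalone check:

#include <cassert>

int main() {
  constexpr int kAlignment = 4;
  for (int width = 1; width <= 4096; ++width) {
    int via_or = ((width * 3 - 1) | (kAlignment - 1)) + 1;
    int via_div = (((width * 3) + 3) / 4) * 4;
    assert(via_or == via_div);  // Both round width * 3 up to a multiple of 4.
  }
  return 0;
}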
- << " Buffer size: " << buffer_size - << ", Buffer size needed: " << expected_buffer_size - << ", Image width: " << width; + << "Input buffer size should be " << expected_buffer_size + << " but is: " << buffer_size; auto image_frame = absl::make_unique( mediapipe::ImageFormat::SRGB, width, height, @@ -232,7 +205,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::GRAY8); + env, byte_buffer, width, height, width, mediapipe::ImageFormat::GRAY8); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); @@ -242,28 +215,9 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGrayscaleImage)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - // TODO: merge this case with CreateImageFrameFromByteBuffer. auto image_frame_or = - [&]() -> absl::StatusOr> { - const void* data = env->GetDirectBufferAddress(byte_buffer); - int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer); - if (data == nullptr || buffer_size < 0) { - return absl::InvalidArgumentError( - "input buffer does not support direct access"); - } - - auto image_frame = absl::make_unique( - mediapipe::ImageFormat::VEC32F1, width, height, - mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); - RET_CHECK_EQ(buffer_size, image_frame->PixelDataSize()) - << "Please check the input buffer size." - << " Buffer size: " << buffer_size - << ", Buffer size needed: " << image_frame->PixelDataSize() - << ", Image width: " << width; - std::memcpy(image_frame->MutablePixelData(), data, - image_frame->PixelDataSize()); - return image_frame; - }(); + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::VEC32F1); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); @@ -272,10 +226,10 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateFloatImageFrame)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateRgbaImageFrame)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, jint height) { - auto image_frame_or = CreateImageFrameFromByteBuffer( - env, byte_buffer, width, height, mediapipe::ImageFormat::SRGBA); + auto image_frame_or = + CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, width * 4, + mediapipe::ImageFormat::SRGBA); if (ThrowIfError(env, image_frame_or.status())) return 0L; - mediapipe::Packet packet = mediapipe::Adopt(image_frame_or->release()); return CreatePacketWithContext(context, packet); } @@ -417,7 +371,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels) { + jint height, jint width_step, jint num_channels) { mediapipe::ImageFormat::Format format; switch (num_channels) { case 4: @@ -436,8 +390,8 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( return 0L; } - auto image_frame_or = - CreateImageFrameFromByteBuffer(env, byte_buffer, width, height, format); + auto 
image_frame_or = CreateImageFrameFromByteBuffer( + env, byte_buffer, width, height, width_step, format); if (ThrowIfError(env, image_frame_or.status())) return 0L; mediapipe::Packet packet = diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h index d6f44b0a3..b3b1043fb 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.h @@ -99,7 +99,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateCpuImage)( JNIEnv* env, jobject thiz, jlong context, jobject byte_buffer, jint width, - jint height, jint num_channels); + jint height, jint width_step, jint num_channels); JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuImage)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, From 13c6b9a8c6ce6fc9d0e34316821d497bb7f4f9f2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sun, 20 Nov 2022 22:18:49 -0800 Subject: [PATCH 079/137] Allow kernel cache path to be specified without trailing path delimiter PiperOrigin-RevId: 489891079 --- .../calculators/tensor/inference_calculator_gl_advanced.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index ad5df849f..c2c723402 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -241,9 +241,9 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( gpu_delegate_options.has_model_token(); if (use_kernel_caching_) { - cached_kernel_filename_ = gpu_delegate_options.cached_kernel_path() + - mediapipe::File::Basename(options.model_path()) + - ".ker"; + cached_kernel_filename_ = mediapipe::file::JoinPath( + gpu_delegate_options.cached_kernel_path(), + mediapipe::File::Basename(options.model_path()) + ".ker"); } if (use_serialized_model_) { serialized_model_path_ = From 7acbf557a1294e3809e8671ac769c855dd3336c4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 01:55:49 -0800 Subject: [PATCH 080/137] Cleanup after migration to new classification output format. 
PiperOrigin-RevId: 489921603 --- .../tasks/cc/components/calculators/BUILD | 1 - .../classification_aggregation_calculator.cc | 68 +--- .../cc/components/containers/proto/BUILD | 6 - .../containers/proto/category.proto | 41 --- .../containers/proto/classifications.proto | 17 +- .../classification_postprocessing_graph.cc | 9 - .../classification_postprocessing_graph.h | 3 - ...lassification_postprocessing_graph_test.cc | 322 ------------------ .../text_classifier/text_classifier_graph.cc | 27 +- .../image_classifier_graph.cc | 9 - .../com/google/mediapipe/tasks/text/BUILD | 1 - .../com/google/mediapipe/tasks/vision/BUILD | 1 - .../tasks/python/components/containers/BUILD | 2 +- .../python/components/containers/category.py | 16 +- .../containers/classification_result.py | 15 +- 15 files changed, 23 insertions(+), 515 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/containers/proto/category.proto diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 1f726a018..16931811c 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index 1a83fdad2..ad2c668c3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,14 +25,12 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into either a ClassificationResult object // representing the classification results aggregated by classifier head, or @@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications; // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example without timestamp aggregation: // node { @@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node { ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); std::vector ConvertToTimestampedClassificationResults( CalculatorContext* cc); - // TODO: deprecate this function once migration is over. 
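Aside: after this cleanup, consumers read per-head results through the standard mediapipe.ClassificationList instead of the deprecated entries field. A minimal reader sketch — the accessors follow from the proto fields shown in this patch, but the helper itself is illustrative only, not part of the change:

#include <cstdio>

#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

void PrintResult(const mediapipe::tasks::components::containers::proto::
                     ClassificationResult& result) {
  for (const auto& head : result.classifications()) {
    // head_index()/head_name() identify the classifier head (output tensor).
    for (const auto& c : head.classification_list().classification()) {
      std::printf("head %d (%s): %s = %.2f\n", head.head_index(),
                  head.head_name().c_str(), c.label().c_str(), c.score());
    }
  }
}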
- ClassificationResult LegacyConvertToClassificationResult( - CalculatorContext* cc); }; absl::Status ClassificationAggregationCalculator::UpdateContract( @@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract( << "The size of classifications input streams should match the " "size of head names specified in the calculator options"; } - // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if - // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is - // not connected. All dependent tasks must be updated to use these outputs - // first. + if (kTimestampsIn(cc).IsConnected()) { + RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected()); + } else { + RET_CHECK(kClassificationsOut(cc).IsConnected()); + } return absl::OkStatus(); } @@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process( if (kTimestampsIn(cc).IsEmpty()) { return absl::OkStatus(); } - classification_result = LegacyConvertToClassificationResult(cc); kTimestampedClassificationsOut(cc).Send( ConvertToTimestampedClassificationResults(cc)); } else { - classification_result = LegacyConvertToClassificationResult(cc); kClassificationsOut(cc).Send(ConvertToClassificationResult(cc)); } kClassificationResultOut(cc).Send(classification_result); @@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults( return results; } -ClassificationResult -ClassificationAggregationCalculator::LegacyConvertToClassificationResult( - CalculatorContext* cc) { - ClassificationResult result; - Timestamp first_timestamp(0); - std::vector timestamps; - if (time_aggregation_enabled_) { - timestamps = kTimestampsIn(cc).Get(); - first_timestamp = timestamps[0]; - } else { - timestamps = {cc->InputTimestamp()}; - } - for (Timestamp timestamp : timestamps) { - int count = cached_classifications_[timestamp.Value()].size(); - for (int i = 0; i < count; ++i) { - Classifications* c; - if (result.classifications_size() <= i) { - c = result.add_classifications(); - if (!head_names_.empty()) { - c->set_head_index(i); - c->set_head_name(head_names_[i]); - } - } else { - c = result.mutable_classifications(i); - } - auto* entry = c->add_entries(); - for (const auto& elem : - cached_classifications_[timestamp.Value()][i].classification()) { - auto* category = entry->add_categories(); - if (elem.has_index()) { - category->set_index(elem.index()); - } - if (elem.has_score()) { - category->set_score(elem.score()); - } - if (elem.has_label()) { - category->set_category_name(elem.label()); - } - if (elem.has_display_name()) { - category->set_display_name(elem.display_name()); - } - } - entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / - 1000); - } - } - return result; -} - MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator); } // namespace api2 diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 7b455c0c4..27d2357b5 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], -) - mediapipe_proto_library( name = "classifications_proto", srcs = ["classifications.proto"], deps = [ - ":category_proto", "//mediapipe/framework/formats:classification_proto", ], ) diff --git 
a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto deleted file mode 100644 index 412e71428..000000000 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto2"; - -package mediapipe.tasks.components.containers.proto; - -option java_package = "com.google.mediapipe.tasks.components.containers.proto"; -option java_outer_classname = "CategoryProto"; - -// TODO: deprecate this message once migration is over. -// A single classification result. -message Category { - // The index of the category in the corresponding label map, usually packed in - // the TFLite Model Metadata [1]. - // - // [1]: https://www.tensorflow.org/lite/convert/metadata - optional int32 index = 1; - // The score for this category, e.g. (but not necessarily) a probability in - // [0,1]. - optional float score = 2; - // A human readable name of the category filled from the label map. - optional string display_name = 3; - // An ID for the category, not necessarily human-readable, e.g. a Google - // Knowledge Graph ID [1], filled from the label map. - // - // [1]: https://developers.google.com/knowledge-graph - optional string category_name = 4; -} diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index f098ed0e4..2b2306829 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -18,27 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; import "mediapipe/framework/formats/classification.proto"; -import "mediapipe/tasks/cc/components/containers/proto/category.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; -// TODO: deprecate this message once migration is over. -// List of predicted categories with an optional timestamp. -message ClassificationEntry { - // The array of predicted categories, usually sorted by descending scores, - // e.g., from high to low probability. - repeated Category categories = 1; - // The optional timestamp (in milliseconds) associated to the classifcation - // entry. This is useful for time series use cases, e.g., audio - // classification. - optional int64 timestamp_ms = 2; -} - // Classifications for a given classifier head, i.e. for a given output tensor. message Classifications { - // TODO: deprecate this field once migration is over. - repeated ClassificationEntry entries = 1; // The classification results for this head. optional mediapipe.ClassificationList classification_list = 4; // The index of the classifier head these categories refer to. 
This is useful @@ -48,6 +33,8 @@ message Classifications { // name. // TODO: Add github link to metadata_schema.fbs. optional string head_name = 3; + // Reserved fields. + reserved 1; } // Classifications for a given classifier model. diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 0fb62afaf..5a0472f5c 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorsTag[] = "TENSORS"; @@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; // Struct holding the different output streams produced by the graph. struct ClassificationPostprocessingOutputStreams { - Source classification_result; Source classifications; Source> timestamped_classifications; }; @@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph( // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // The recommended way of using this graph is through the GraphBuilder API // using the 'ConfigureClassificationPostprocessingGraph()' function. See header @@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.timestamped_classifications >> @@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { // Connects output. ClassificationPostprocessingOutputStreams output_streams{ - /*classification_result=*/result_aggregation - [Output(kClassificationResultTag)], /*classifications=*/ result_aggregation[Output(kClassificationsTag)], /*timestamped_classifications=*/ diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index 48575ceb0..03ae91130 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -58,9 +58,6 @@ namespace processors { // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. 
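With the deprecated stream removed, a task graph connects only the CLASSIFICATIONS output of this subgraph. A minimal builder-API sketch of the post-cleanup wiring, assuming the surrounding graph, model_resources, classifier_options and a tensors source are already set up (illustrative, not part of the patch):

auto& postprocessing = graph.AddNode(
    "mediapipe.tasks.components.processors."
    "ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
    model_resources, classifier_options,
    &postprocessing
         .GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
tensors >> postprocessing.In(kTensorsTag);
return postprocessing[Output<ClassificationResult>(kClassificationsTag)];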
absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index d4728e725..8eb6f3c3b 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsName[] = "tensors"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsName[] = "timestamps"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; -constexpr char kClassificationResultName[] = "classification_result"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kClassificationsName[] = "classifications"; constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; @@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { })pb")})); } -// TODO: remove these tests once migration is over. -class LegacyPostprocessingTest : public tflite_shims::testing::Test { - protected: - absl::StatusOr BuildGraph( - absl::string_view model_name, const proto::ClassifierOptions& options, - bool connect_timestamps = false) { - ASSIGN_OR_RETURN(auto model_resources, - CreateModelResourcesForModel(model_name)); - - Graph graph; - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.processors." - "ClassificationPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( - *model_resources, options, - &postprocessing - .GetOptions())); - graph[Input>(kTensorsTag)].SetName(kTensorsName) >> - postprocessing.In(kTensorsTag); - if (connect_timestamps) { - graph[Input>(kTimestampsTag)].SetName( - kTimestampsName) >> - postprocessing.In(kTimestampsTag); - } - postprocessing.Out(kClassificationResultTag) - .SetName(kClassificationResultName) >> - graph[Output(kClassificationResultTag)]; - - MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); - ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( - kClassificationResultName)); - MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); - return poller; - } - - template - void AddTensor( - const std::vector& tensor, const Tensor::ElementType& element_type, - const Tensor::QuantizationParameters& quantization_parameters = {}) { - tensors_->emplace_back(element_type, - Tensor::Shape{1, static_cast(tensor.size())}, - quantization_parameters); - auto view = tensors_->back().GetCpuWriteView(); - T* buffer = view.buffer(); - std::copy(tensor.begin(), tensor.end(), buffer); - } - - absl::Status Run( - std::optional> aggregation_timestamps = std::nullopt, - int timestamp = 0) { - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); - // Reset tensors for future calls. 
- tensors_ = absl::make_unique>(); - if (aggregation_timestamps.has_value()) { - auto packet = absl::make_unique>(); - for (const auto& timestamp : *aggregation_timestamps) { - packet->emplace_back(Timestamp(timestamp)); - } - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); - } - return absl::OkStatus(); - } - - absl::StatusOr GetClassificationResult( - OutputStreamPoller& poller) { - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); - MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); - - Packet packet; - if (!poller.Next(&packet)) { - return absl::InternalError("Unable to get output packet"); - } - auto result = packet.Get(); - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); - return result; - } - - private: - CalculatorGraph calculator_graph_; - std::unique_ptr> tensors_ = - absl::make_unique>(); -}; - -TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - options.set_score_threshold(0.5); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 18; - tensor[2] = 16; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto(R"pb(classifications { - entries { - categories { index: 1 score: 0.8 } - categories { index: 2 score: 0.6 } - timestamp_ms: 0 - } - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.8 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. 
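Two notes for readers verifying the expected values in these removed tests: (1) with quantization parameters {scale = 0.1, zero_point = 10}, standard TFLite affine dequantization

    real_value = scale * (quantized_value - zero_point)

maps the quantized inputs 18, 16, 14, 12 to the scores 0.8, 0.6, 0.4, 0.2 asserted in the tests above; (2) the "dummy" score calibration, judging by the expected values, applies a plain sigmoid to those dequantized scores: 1 / (1 + exp(-0.8)) ≈ 0.6899745, 1 / (1 + exp(-0.6)) ≈ 0.6456563, 1 / (1 + exp(-0.4)) ≈ 0.5986877, matching the assertion that follows.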
- EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.6899744811 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6456563062 - category_name: "great white shark" - } - categories { - index: 2 - score: 0.5986876601 - category_name: "goldfish" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor_0(kTwoHeadsNumClasses[0], 0); - tensor_0[1] = 0.2; - tensor_0[2] = 0.4; - tensor_0[3] = 0.6; - std::vector tensor_1(kTwoHeadsNumClasses[1], 0); - tensor_1[1] = 0.2; - tensor_1[2] = 0.4; - tensor_1[3] = 0.6; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kFloat32); - AddTensor(tensor_1, Tensor::ElementType::kFloat32); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Narration, monologue" - } - categories { - index: 2 - score: 0.4 - category_name: "Conversation" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "yamnet_classification" - } - classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Azara\'s Spinetail" - } - categories { - index: 2 - score: 0.4 - category_name: "House Sparrow" - } - timestamp_ms: 0 - } - head_index: 1 - head_name: "bird_classification" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, - /*connect_timestamps=*/true)); - // Build input tensors. - std::vector tensor_0(kMobileNetNumClasses, 0); - tensor_0[1] = 12; - tensor_0[2] = 14; - tensor_0[3] = 16; - std::vector tensor_1(kMobileNetNumClasses, 0); - tensor_1[5] = 12; - tensor_1[6] = 14; - tensor_1[7] = 16; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - AddTensor(tensor_1, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run( - /*aggregation_timestamps=*/std::optional>({0, 1000}), - /*timestamp=*/1000)); - - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. 
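In this last legacy test, the timestamp_ms values asserted below follow from the conversion code removed above: entry timestamps are milliseconds relative to the first aggregation timestamp, i.e. (timestamp - first_timestamp) / 1000, so the input timestamps 0 and 1000 microseconds become timestamp_ms 0 and 1.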
- EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - entries { - categories { index: 7 score: 0.6 category_name: "stingray" } - categories { - index: 6 - score: 0.4 - category_name: "electric ray" - } - timestamp_ms: 1 - } - head_index: 0 - head_name: "probability" - })pb")); -} - } // namespace } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 36ff68a07..9a7dce1aa 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -// TODO: remove once Java API migration is over. -// Struct holding the different output streams produced by the text classifier. -struct TextClassifierOutputStreams { - Source classification_result; - Source classifications; -}; - } // namespace // A "TextClassifierGraph" performs Natural Language classification (including @@ -72,10 +64,6 @@ struct TextClassifierOutputStreams { // Outputs: // CLASSIFICATIONS - ClassificationResult @Optional // The classification results aggregated by classifier head. -// TODO: remove once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result object that has 3 dimensions: -// (classification head, classification timestamp, classification category). // // Example: // node { @@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - auto output_streams, + auto classifications, BuildTextClassifierTask( sc->Options(), *model_resources, graph[Input(kTextTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; - output_streams.classifications >> - graph[Output(kClassificationsTag)]; + classifications >> graph[Output(kClassificationsTag)]; return graph.GetConfig(); } @@ -124,7 +109,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // TextClassifier model file with model metadata. // text_in: (std::string) stream to run text classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr BuildTextClassifierTask( + absl::StatusOr> BuildTextClassifierTask( const proto::TextClassifierGraphOptions& task_options, const ModelResources& model_resources, Source text_in, Graph& graph) { @@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. 
- return TextClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], - /*classifications=*/postprocessing[Output( - kClassificationsTag)]}; + return postprocessing[Output(kClassificationsTag)]; } }; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 8fa1a0d2a..2fc88bcb6 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS"; // Struct holding the different output streams produced by the image classifier // subgraph. struct ImageClassifierOutputStreams { - Source classification_result; Source classifications; Source image; }; @@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams { // The classification results aggregated by classifier head. // IMAGE - Image // The image that object detection runs on. -// TODO: remove this output once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example: // node { @@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.image >> graph[Output(kImageTag)]; @@ -174,8 +167,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. 
return ImageClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], /*classifications=*/ postprocessing[Output(kClassificationsTag)], /*image=*/preprocessing[Output(kImageTag)]}; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 0e72878ab..023a1f286 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -48,7 +48,6 @@ android_library( deps = [ "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 289e3000d..72cee133f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -97,7 +97,6 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index d931c26c7..9d275e167 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -68,7 +68,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", + "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index cfdb83740..9b5419883 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,10 +16,10 @@ import dataclasses from typing import Any, Optional -from mediapipe.tasks.cc.components.containers.proto import category_pb2 +from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls -_CategoryProto = category_pb2.Category +_ClassificationProto = classification_pb2.Classification @dataclasses.dataclass @@ -45,23 +45,23 @@ class Category: category_name: Optional[str] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _CategoryProto: + def to_pb2(self) -> _ClassificationProto: """Generates a Category protobuf object.""" - return _CategoryProto( + return _ClassificationProto( index=self.index, score=self.score, - display_name=self.display_name, - category_name=self.category_name) + label=self.category_name, + 
display_name=self.display_name) @classmethod @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category': + def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category': """Creates a `Category` object from the given protobuf object.""" return Category( index=pb2_obj.index, score=pb2_obj.score, display_name=pb2_obj.display_name, - category_name=pb2_obj.category_name) + category_name=pb2_obj.label) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py index 6ffdabe51..000468041 100644 --- a/mediapipe/tasks/python/components/containers/classification_result.py +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -49,11 +49,7 @@ class Classifications: """Generates a Classifications protobuf object.""" classification_list_proto = _ClassificationListProto() for category in self.categories: - classification_proto = _ClassificationProto( - index=category.index, - score=category.score, - label=category.category_name, - display_name=category.display_name) + classification_proto = category.to_pb2() classification_list_proto.classification.append(classification_proto) return _ClassificationsProto( classification_list=classification_list_proto, @@ -65,14 +61,9 @@ class Classifications: def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': """Creates a `Classifications` object from the given protobuf object.""" categories = [] - for entry in pb2_obj.classification_list.classification: + for classification in pb2_obj.classification_list.classification: categories.append( - category_module.Category( - index=entry.index, - score=entry.score, - display_name=entry.display_name, - category_name=entry.label)) - + category_module.Category.create_from_pb2(classification)) return Classifications( categories=categories, head_index=pb2_obj.head_index, From 7f0134eecbe75a94bcda7cf113e1ae8aa47cd916 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 12:13:38 -0800 Subject: [PATCH 081/137] Internal change PiperOrigin-RevId: 490041386 --- mediapipe/tasks/python/core/BUILD | 1 + mediapipe/tasks/python/text/BUILD | 1 + 2 files changed, 2 insertions(+) diff --git a/mediapipe/tasks/python/core/BUILD b/mediapipe/tasks/python/core/BUILD index 76e2f4f4a..fc0018ab1 100644 --- a/mediapipe/tasks/python/core/BUILD +++ b/mediapipe/tasks/python/core/BUILD @@ -31,6 +31,7 @@ py_library( py_library( name = "base_options", srcs = ["base_options.py"], + visibility = ["//mediapipe/tasks:users"], deps = [ ":optional_dependencies", "//mediapipe/tasks/cc/core/proto:base_options_py_pb2", diff --git a/mediapipe/tasks/python/text/BUILD b/mediapipe/tasks/python/text/BUILD index bb42da912..10b4b8a6e 100644 --- a/mediapipe/tasks/python/text/BUILD +++ b/mediapipe/tasks/python/text/BUILD @@ -23,6 +23,7 @@ py_library( srcs = [ "text_classifier.py", ], + visibility = ["//mediapipe/tasks:users"], deps = [ "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", From 652423a23d9a69d5c3dabe61926a55bd77d6d610 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 13:04:53 -0800 Subject: [PATCH 082/137] Internal change PiperOrigin-RevId: 490053179 --- mediapipe/calculators/tensor/image_to_tensor_utils.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git 
a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc index d27c595b5..3f91f3dc2 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -253,11 +253,15 @@ int GetNumOutputChannels(const mediapipe::Image& image) { } #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU - // The output tensor channel is 1 for the input image with 1 channel; And the - // output tensor channels is 3 for the input image with 3 or 4 channels. // TODO: Add a unittest here to test the behavior on GPU, i.e. // failure. - return image.channels() == 1 ? 1 : 3; + // Only output channel == 1 when running on CPU and the input image channel + // is 1. Ideally, we want to also support GPU for output channel == 1. But + // setting this on the safer side to prevent unintentional failure. + if (!image.UsesGpu() && image.channels() == 1) { + return 1; + } + return 3; } absl::StatusOr> GetInputImage( From adddf2c2abe953b0280507b6168a41bcbb5a08f3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 14:37:42 -0800 Subject: [PATCH 083/137] Extracted common test helper functions out from the unittest into a sharable library. Also migrated away from OpenCVX. PiperOrigin-RevId: 490074410 --- mediapipe/calculators/tensor/BUILD | 2 + .../tensor/image_to_tensor_calculator_test.cc | 169 ++++++------------ mediapipe/util/BUILD | 18 ++ mediapipe/util/image_test_utils.cc | 57 ++++++ mediapipe/util/image_test_utils.h | 32 ++++ 5 files changed, 166 insertions(+), 112 deletions(-) create mode 100644 mediapipe/util/image_test_utils.cc create mode 100644 mediapipe/util/image_test_utils.h diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 2a573fc44..645189a07 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -30,6 +30,7 @@ exports_files( glob(["testdata/image_to_tensor/*"]), visibility = [ "//mediapipe/calculators/image:__subpackages__", + "//mediapipe/util:__subpackages__", ], ) @@ -1133,6 +1134,7 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/util:image_test_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index 7ea60d98e..ceb1fc502 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -36,29 +36,17 @@ #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/image_test_utils.h" namespace mediapipe { namespace { -cv::Mat GetRgb(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB); - return rgb; -} +constexpr char kTestDataDir[] = + "/mediapipe/calculators/tensor/testdata/" + "image_to_tensor/"; -cv::Mat GetRgba(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - cv::Mat rgb; - cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA); - return rgb; -} - -cv::Mat GetGray(absl::string_view path) { - cv::Mat bgr = cv::imread(file::JoinPath("./", path)); - 
cv::Mat gray; - cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY); - return gray; +std::string GetFilePath(absl::string_view filename) { + return file::JoinPath("./", kTestDataDir, filename); } // Image to tensor test template. @@ -259,15 +247,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, - /*border mode*/ {}, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, + /*border mode*/ {}, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { @@ -277,11 +262,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -295,11 +277,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -314,11 +293,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * 90.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "medium_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, @@ -332,16 +309,12 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb( - "/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), - /*float_ranges=*/{{-1.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, - BorderMode::kReplicate, roi); + 
RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation.png")), + /*float_ranges=*/{{-1.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { @@ -351,11 +324,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) { roi.set_width(0.5f); roi.set_height(0.5f); roi.set_rotation(M_PI * -45.0f / 180.0f); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "medium_sub_rect_with_rotation_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("medium_sub_rect_with_rotation_border_zero.png")), /*float_ranges=*/{{-1.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, @@ -369,10 +339,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, @@ -386,15 +354,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, - BorderMode::kZero, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_border_zero.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, + BorderMode::kZero, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { @@ -404,15 +369,12 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest( - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), - /*float_ranges=*/{{0.0f, 1.0f}}, - /*int_ranges=*/{{0, 255}, {-128, 127}}, - /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, - BorderMode::kReplicate, roi); + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect.png")), + /*float_ranges=*/{{0.0f, 1.0f}}, + /*int_ranges=*/{{0, 255}, {-128, 127}}, + /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, + BorderMode::kReplicate, roi); } TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { @@ -422,11 +384,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(0); - RunTest(GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - 
"tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_border_zero.png"), + RunTest(GetRgb(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -440,11 +399,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -458,11 +414,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotationGray) { roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation.png"), + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath("large_sub_rect_keep_aspect_with_rotation.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -477,11 +430,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -496,11 +447,9 @@ TEST(ImageToTensorCalculatorTest, roi.set_width(1.5f); roi.set_height(1.1f); roi.set_rotation(M_PI * -15.0f / 180.0f); - RunTest(GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetGray("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/" - "large_sub_rect_keep_aspect_with_rotation_border_zero.png"), + RunTest(GetGray(GetFilePath("input.jpg")), + GetGray(GetFilePath( + "large_sub_rect_keep_aspect_with_rotation_border_zero.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}}, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -514,10 +463,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) { roi.set_width(1.0f); roi.set_height(1.0f); roi.set_rotation(0); - RunTest(GetRgba("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/input.jpg"), - GetRgb("/mediapipe/calculators/" - "tensor/testdata/image_to_tensor/noop_except_range.png"), + RunTest(GetRgba(GetFilePath("input.jpg")), + GetRgb(GetFilePath("noop_except_range.png")), /*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, @@ -531,10 +478,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) { roi.set_width(1.0f); 
roi.set_height(1.0f); roi.set_rotation(0);
- RunTest(GetRgba("/mediapipe/calculators/"
- "tensor/testdata/image_to_tensor/input.jpg"),
- GetRgb("/mediapipe/calculators/"
- "tensor/testdata/image_to_tensor/noop_except_range.png"),
+ RunTest(GetRgba(GetFilePath("input.jpg")),
+ GetRgb(GetFilePath("noop_except_range.png")),
/*float_ranges=*/{{0.0f, 1.0f}}, /*int_ranges=*/{{0, 255}, {-128, 127}}, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD
index 15835aea5..55c1df59f 100644
--- a/mediapipe/util/BUILD
+++ b/mediapipe/util/BUILD
@@ -368,3 +368,21 @@ cc_test(
"//mediapipe/framework/port:gtest_main", ], )
+
+cc_library(
+ name = "image_test_utils",
+ testonly = 1,
+ srcs = ["image_test_utils.cc"],
+ hdrs = ["image_test_utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//mediapipe/framework:packet",
+ "//mediapipe/framework:timestamp",
+ "//mediapipe/framework/formats:image",
+ "//mediapipe/framework/formats:image_frame",
+ "//mediapipe/framework/formats:image_frame_opencv",
+ "//mediapipe/framework/port:opencv_core",
+ "//mediapipe/framework/port:opencv_imgcodecs",
+ "//mediapipe/framework/port:opencv_imgproc",
+ ],
+)
diff --git a/mediapipe/util/image_test_utils.cc b/mediapipe/util/image_test_utils.cc
new file mode 100644
index 000000000..815666985
--- /dev/null
+++ b/mediapipe/util/image_test_utils.cc
@@ -0,0 +1,57 @@
+#include "mediapipe/util/image_test_utils.h"
+
+#include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
+#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
+#include "mediapipe/framework/port/opencv_imgproc_inc.h"
+#include "mediapipe/framework/timestamp.h"
+
+namespace mediapipe {
+
+cv::Mat GetRgb(const std::string& path) {
+ cv::Mat bgr = cv::imread(path);
+ cv::Mat rgb;
+ cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB);
+ return rgb;
+}
+
+cv::Mat GetRgba(const std::string& path) {
+ cv::Mat bgr = cv::imread(path);
+ cv::Mat rgb;
+ cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGBA);
+ return rgb;
+}
+
+cv::Mat GetGray(const std::string& path) {
+ cv::Mat bgr = cv::imread(path);
+ cv::Mat gray;
+ cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
+ return gray;
+}
+
+mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
+ if (image_channels == 4) {
+ return ImageFormat::SRGBA;
+ } else if (image_channels == 3) {
+ return ImageFormat::SRGB;
+ } else if (image_channels == 1) {
+ return ImageFormat::GRAY8;
+ }
+ LOG(FATAL) << "Unsupported input image channels: " << image_channels;
+}
+
+Packet MakeImageFramePacket(cv::Mat input, int timestamp) {
+ ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
+ input.rows, input.step, input.data, [](uint8*) {});
+ return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(timestamp));
+}
+
+Packet MakeImagePacket(cv::Mat input, int timestamp) {
+ mediapipe::Image input_image(std::make_shared<ImageFrame>(
+ GetImageFormat(input.channels()), input.cols, input.rows, input.step,
+ input.data, [](uint8*) {}));
+ return MakePacket<mediapipe::Image>(std::move(input_image)).At(Timestamp(timestamp));
+}
+
+} // namespace mediapipe
diff --git a/mediapipe/util/image_test_utils.h b/mediapipe/util/image_test_utils.h
new file mode 100644
index 000000000..6df9644d2
--- /dev/null
+++ b/mediapipe/util/image_test_utils.h
@@ -0,0 +1,32 @@
+#ifndef MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
+#define MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_
+
+#include <string>
+
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/opencv_core_inc.h" + +namespace mediapipe { + +// Reads the image file into cv::Mat with RGB channels. +cv::Mat GetRgb(const std::string& path); + +// Reads the image file into cv::Mat with RGBA channels. +cv::Mat GetRgba(const std::string& path); + +// Reads the image file into cv::Mat with Gray channel. +cv::Mat GetGray(const std::string& path); + +// Converts the image channels into corresponding ImageFormat. +mediapipe::ImageFormat::Format GetImageFormat(int image_channels); + +// Converts the cv::Mat into ImageFrame packet. +Packet MakeImageFramePacket(cv::Mat input, int timestamp = 0); + +// Converts the cv::Mat into Image packet. +Packet MakeImagePacket(cv::Mat input, int timestamp = 0); + +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_IMAGE_TEST_UTILS_H_ From d43d0ff615030abb9241c28e6de6e345a8dba7eb Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 15:45:29 -0800 Subject: [PATCH 084/137] Internal change PiperOrigin-RevId: 490089940 --- .../image_to_tensor_converter_opencv.cc | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index 76e46f99d..95e38f89c 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -76,31 +76,49 @@ class OpenCvProcessor : public ImageToTensorConverter { return InvalidArgumentError(absl::StrCat( "Unsupported format: ", static_cast(input.image_format()))); } - // TODO: Remove the check once tensor_buffer_offset > 0 is - // supported. - RET_CHECK_EQ(tensor_buffer_offset, 0) - << "The non-zero tensor_buffer_offset input is not supported yet."; + + RET_CHECK_GE(tensor_buffer_offset, 0) + << "The input tensor_buffer_offset needs to be non-negative."; const auto& output_shape = output_tensor.shape(); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); const int output_height = output_shape.dims[1]; const int output_width = output_shape.dims[2]; const int output_channels = output_shape.dims[3]; + const int num_elements_per_img = + output_height * output_width * output_channels; auto buffer_view = output_tensor.GetCpuWriteView(); cv::Mat dst; const int dst_data_type = output_channels == 1 ? 
mat_gray_type_ : mat_type_;
    switch (tensor_type_) {
      case Tensor::ElementType::kInt8:
-        dst = cv::Mat(output_height, output_width, dst_data_type,
-                      buffer_view.buffer<int8>());
+        RET_CHECK_GE(output_shape.num_elements(),
+                     tensor_buffer_offset / sizeof(int8) + num_elements_per_img)
+            << "The buffer offset + the input image size is larger than the "
+               "allocated tensor buffer.";
+        dst = cv::Mat(
+            output_height, output_width, dst_data_type,
+            buffer_view.buffer<int8>() + tensor_buffer_offset / sizeof(int8));
        break;
      case Tensor::ElementType::kFloat32:
-        dst = cv::Mat(output_height, output_width, dst_data_type,
-                      buffer_view.buffer<float>());
+        RET_CHECK_GE(
+            output_shape.num_elements(),
+            tensor_buffer_offset / sizeof(float) + num_elements_per_img)
+            << "The buffer offset + the input image size is larger than the "
+               "allocated tensor buffer.";
+        dst = cv::Mat(
+            output_height, output_width, dst_data_type,
+            buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float));
        break;
      case Tensor::ElementType::kUInt8:
-        dst = cv::Mat(output_height, output_width, dst_data_type,
-                      buffer_view.buffer<uint8>());
+        RET_CHECK_GE(
+            output_shape.num_elements(),
+            tensor_buffer_offset / sizeof(uint8) + num_elements_per_img)
+            << "The buffer offset + the input image size is larger than the "
+               "allocated tensor buffer.";
+        dst = cv::Mat(
+            output_height, output_width, dst_data_type,
+            buffer_view.buffer<uint8>() + tensor_buffer_offset / sizeof(uint8));
        break;
      default:
        return InvalidArgumentError(
@@ -153,9 +171,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
-    RET_CHECK_EQ(output_shape.dims[0], 1)
-        << "Handling batch dimension not equal to 1 is not implemented in this "
-           "converter.";
+    RET_CHECK_GE(output_shape.dims[0], 1)
+        << "The batch dimension needs to be equal to or larger than 1.";
    RET_CHECK(output_shape.dims[3] == 3 || output_shape.dims[3] == 1)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
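[Editor's note: a sketch of the offset arithmetic this change enables, assuming a float32 output tensor holding a batch of images; the loop and names outside the patch are illustrative only.]

    // Each converted image occupies num_elements_per_img elements, so image i
    // of a batch is written at a byte offset that the converter divides back
    // down to an element offset:
    //   buffer_view.buffer<float>() + tensor_buffer_offset / sizeof(float)
    const int num_elements_per_img =
        output_height * output_width * output_channels;
    for (int i = 0; i < batch_size; ++i) {
      size_t tensor_buffer_offset = i * num_elements_per_img * sizeof(float);
      // converter->Convert(images[i], ..., tensor_buffer_offset, &tensor);
    }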
From 7c9fc9a6428b1c40738b5dce80abbacd627c4bdf Mon Sep 17 00:00:00 2001
From: Mark McDonald
Date: Mon, 21 Nov 2022 21:45:58 -0800
Subject: [PATCH 085/137] Remove `mp.solutions` from doc generation.

These need to be excluded from the current package, so do it automatically.

PiperOrigin-RevId: 490146934
---
 docs/build_py_api_docs.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/docs/build_py_api_docs.py b/docs/build_py_api_docs.py
index fa1e4314f..fe706acd3 100644
--- a/docs/build_py_api_docs.py
+++ b/docs/build_py_api_docs.py
@@ -30,7 +30,7 @@ from tensorflow_docs.api_generator import public_api

 try:
   # mediapipe has not been set up to work with bazel yet, so catch & report.
-  import mediapipe  # pytype: disable=import-error
+  import mediapipe as mp  # pytype: disable=import-error
 except ImportError as e:
   raise ImportError('Please `pip install mediapipe`.') from e

@@ -58,11 +58,13 @@ _SITE_PATH = flags.DEFINE_string('site_path', '/mediapipe/api_docs/python',

 def gen_api_docs():
   """Generates API docs for the mediapipe package."""
+  if hasattr(mp, 'solutions'):
+    del mp.solutions
   doc_generator = generate_lib.DocGenerator(
       root_title=PROJECT_FULL_NAME,
-      py_modules=[(PROJECT_SHORT_NAME, mediapipe)],
-      base_dir=os.path.dirname(mediapipe.__file__),
+      py_modules=[(PROJECT_SHORT_NAME, mp)],
+      base_dir=os.path.dirname(mp.__file__),
       code_url_prefix=_URL_PREFIX.value,
       search_hints=_SEARCH_HINTS.value,
       site_path=_SITE_PATH.value,

From 54a684717fa39cd39315f8f6cb60b6c5a7fa76aa Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 21 Nov 2022 23:22:49 -0800
Subject: [PATCH 086/137] Internal change

PiperOrigin-RevId: 490159674
---
 mediapipe/gpu/attachments.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/gpu/attachments.h b/mediapipe/gpu/attachments.h
index ca9f074c4..3a73e4676 100644
--- a/mediapipe/gpu/attachments.h
+++ b/mediapipe/gpu/attachments.h
@@ -31,8 +31,8 @@ class AttachmentBase {};
 template <class Context, class T>
 class Attachment : public AttachmentBase<Context> {
  public:
-  using FactoryT = std::function<AttachmentPtr<T>(Context&)>;
-  Attachment(FactoryT factory) : factory_(factory) {}
+  using FactoryT = AttachmentPtr<T> (*)(Context&);
+  explicit constexpr Attachment(FactoryT factory) : factory_(factory) {}

   Attachment(const Attachment&) = delete;
   Attachment(Attachment&&) = delete;

From a8b776102240ecb73f1a7aeb8ace9db42eb05f96 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 21 Nov 2022 23:27:55 -0800
Subject: [PATCH 087/137] Define a kUtilityFramebuffer context attachment

A framebuffer object is often needed to render to a texture or read data from it. Currently we create one in each GlCalculatorHelper, but that is redundant (we only need one per context, and multiple calculators can share the same context). Other times, the code that needs to use this doesn't own a helper.

For both reasons, this should be attached to the context. We could just make this a member of GlContext since it's so common. However, I figured we might as well use the attachment system.

PiperOrigin-RevId: 490160214
---
 mediapipe/gpu/gl_context.cc | 12 ++++++++++++
 mediapipe/gpu/gl_context.h | 6 ++++++
 2 files changed, 18 insertions(+)

diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc
index 53e3ff8b7..99b995dda 100644
--- a/mediapipe/gpu/gl_context.cc
+++ b/mediapipe/gpu/gl_context.cc
@@ -1054,4 +1054,16 @@ void GlContext::SetStandardTextureParams(GLenum target, GLint internal_format) {
   glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
 }

+const GlContext::Attachment<GLuint> kUtilityFramebuffer(
+    [](GlContext&) -> GlContext::Attachment<GLuint>::Ptr {
+      GLuint framebuffer;
+      glGenFramebuffers(1, &framebuffer);
+      if (!framebuffer) return nullptr;
+      return {new GLuint(framebuffer), [](void* ptr) {
+                GLuint* fb = static_cast<GLuint*>(ptr);
+                glDeleteFramebuffers(1, fb);
+                delete fb;
+              }};
+    });
+
 }  // namespace mediapipe
diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h
index 7f5168d8b..4f2390404 100644
--- a/mediapipe/gpu/gl_context.h
+++ b/mediapipe/gpu/gl_context.h
@@ -474,6 +474,12 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
   bool destructing_ = false;
 };

+// A framebuffer that the framework can use to attach textures for rendering
+// etc.
+// This could just be a member of GlContext, but it serves as a basic example
+// of an attachment.
+ABSL_CONST_INIT extern const GlContext::Attachment<GLuint> kUtilityFramebuffer;
+
 // For backward compatibility. TODO: migrate remaining callers.
 ABSL_DEPRECATED(
     "Prefer passing an explicit GlVersion argument (use "
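[Editor's note: a short sketch of how code already running on a context's GL thread might use the new attachment; the gl_context reference and texture_name variable are illustrative only.]

    // Lazily creates the per-context framebuffer on first use, then reuses it.
    GLuint fbo = kUtilityFramebuffer.Get(gl_context);  // gl_context: GlContext&
    glBindFramebuffer(GL_FRAMEBUFFER, fbo);
    glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
                           GL_TEXTURE_2D, texture_name, 0);
    // ... read pixels or render ...
    glBindFramebuffer(GL_FRAMEBUFFER, 0);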
From bacbac8d926d769bf51f770914d603b942094ebb Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Mon, 21 Nov 2022 23:57:33 -0800
Subject: [PATCH 088/137] Use kUtilityFramebuffer in ReadTexture

This avoids creating a temporary framebuffer each time.

PiperOrigin-RevId: 490163892
---
 mediapipe/gpu/gl_texture_buffer.cc | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc
index 7f77cd4b3..3d2642552 100644
--- a/mediapipe/gpu/gl_texture_buffer.cc
+++ b/mediapipe/gpu/gl_texture_buffer.cc
@@ -15,6 +15,7 @@
 #include "mediapipe/gpu/gl_texture_buffer.h"

 #include "mediapipe/framework/formats/image_frame.h"
+#include "mediapipe/gpu/gl_context.h"
 #include "mediapipe/gpu/gl_texture_view.h"
 #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"

@@ -333,8 +334,8 @@ void GlTextureBuffer::ViewDoneWriting(const GlTextureView& view) {
 #endif  // __ANDROID__
 }

-static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
-                        void* output, size_t size) {
+static void ReadTexture(GlContext& ctx, const GlTextureView& view,
+                        GpuBufferFormat format, void* output, size_t size) {
   // TODO: check buffer size? We could use glReadnPixels where available
   // (OpenGL ES 3.2, i.e. nowhere). Note that, to fully check that the read
   // won't overflow the buffer with glReadPixels, we'd also need to check or
@@ -347,10 +348,7 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
   GLint previous_fbo;
   glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo);

-  // We use a temp fbo to avoid depending on the app having an existing one.
-  // TODO: keep a utility fbo around in the context?
-  GLuint fbo = 0;
-  glGenFramebuffers(1, &fbo);
+  GLuint fbo = kUtilityFramebuffer.Get(ctx);
   glBindFramebuffer(GL_FRAMEBUFFER, fbo);
   glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
                          view.name(), 0);
@@ -360,7 +358,6 @@ static void ReadTexture(const GlTextureView& view, GpuBufferFormat format,
                          0);
   // TODO: just set the binding to 0 to avoid the get call?
   glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo);
-  glDeleteFramebuffers(1, &fbo);
 }

 static std::shared_ptr<ImageFrame> ConvertToImageFrame(
@@ -370,9 +367,10 @@ static std::shared_ptr<ImageFrame> ConvertToImageFrame(
   auto output =
       absl::make_unique<ImageFrame>(image_format, buf->width(), buf->height(),
                                     ImageFrame::kGlDefaultAlignmentBoundary);
-  buf->GetProducerContext()->Run([buf, &output] {
+  auto ctx = buf->GetProducerContext();
+  ctx->Run([buf, &output, &ctx] {
     auto view = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0);
-    ReadTexture(view, buf->format(), output->MutablePixelData(),
+    ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(),
                 output->PixelDataSize());
   });
   return std::make_shared<ImageFrame>(std::move(output));

From d648926155d19cb6665895661624ec19cc7d33c6 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 00:35:27 -0800
Subject: [PATCH 089/137] Just reset the fb binding to 0 in ReadTexture

This saves a get operation. We already have precedent in lots of other MediaPipe code where we just reset bindings to 0.

PiperOrigin-RevId: 490170691
---
 mediapipe/gpu/gl_texture_buffer.cc | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc
index 3d2642552..d530d5d12 100644
--- a/mediapipe/gpu/gl_texture_buffer.cc
+++ b/mediapipe/gpu/gl_texture_buffer.cc
@@ -345,9 +345,6 @@ static void ReadTexture(GlContext& ctx, const GlTextureView& view,
   GlTextureInfo info = GlTextureInfoForGpuBufferFormat(
       format, view.plane(), view.gl_context()->GetGlVersion());

-  GLint previous_fbo;
-  glGetIntegerv(GL_FRAMEBUFFER_BINDING, &previous_fbo);
-
   GLuint fbo = kUtilityFramebuffer.Get(ctx);
   glBindFramebuffer(GL_FRAMEBUFFER, fbo);
   glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, view.target(),
                          view.name(), 0);
@@ -356,8 +353,7 @@ static void ReadTexture(GlContext& ctx, const GlTextureView& view,
                          output);
   glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, 0,
                          0);
-  // TODO: just set the binding to 0 to avoid the get call?
-  glBindFramebuffer(GL_FRAMEBUFFER, previous_fbo);
+  glBindFramebuffer(GL_FRAMEBUFFER, 0);
 }

 static std::shared_ptr<ImageFrame> ConvertToImageFrame(

From 872d1afda7f8a465db59dfcf9ab56e6d60832646 Mon Sep 17 00:00:00 2001
From: vrabaud
Date: Tue, 22 Nov 2022 03:10:35 -0800
Subject: [PATCH 090/137] Internal change

PiperOrigin-RevId: 490196129
---
 mediapipe/framework/port/BUILD | 11 ++++++++++
 mediapipe/framework/port/opencv_videoio_inc.h | 21 +++++++++++++++++++
 2 files changed, 32 insertions(+)
 create mode 100644 mediapipe/framework/port/opencv_videoio_inc.h

diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD
index 87944d80f..e499ca3a6 100644
--- a/mediapipe/framework/port/BUILD
+++ b/mediapipe/framework/port/BUILD
@@ -311,6 +311,17 @@ cc_library(
     ],
 )

+cc_library(
+    name = "opencv_videoio",
+    hdrs = ["opencv_videoio_inc.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":opencv_core",
+        "//mediapipe/framework:port",
+        "//third_party:opencv",
+    ],
+)
+
 cc_library(
     name = "parse_text_proto",
     hdrs = [
diff --git a/mediapipe/framework/port/opencv_videoio_inc.h b/mediapipe/framework/port/opencv_videoio_inc.h
new file mode 100644
index 000000000..63029b69f
--- /dev/null
+++ b/mediapipe/framework/port/opencv_videoio_inc.h
@@ -0,0 +1,21 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_
+#define MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_
+
+#include "mediapipe/framework/port/opencv_core_inc.h"
+#include "third_party/OpenCV/videoio.hpp"
+
+#endif  // MEDIAPIPE_PORT_OPENCV_VIDEOIO_INC_H_
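[Editor's note: a minimal sketch of what the new port target enables; the video path is hypothetical.]

    #include "mediapipe/framework/port/opencv_videoio_inc.h"

    // Read frames from a video file with OpenCV's VideoCapture.
    cv::VideoCapture capture("/path/to/input.mp4");
    cv::Mat frame;
    while (capture.read(frame)) {
      // ... process each BGR frame ...
    }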
From 515d00fc22100bfb948aecfa39408a0b599a0c89 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 15:16:52 -0800
Subject: [PATCH 091/137] Internal change

PiperOrigin-RevId: 490349260
---
 mediapipe/framework/formats/BUILD | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD
index e13bb2704..4276ffc3a 100644
--- a/mediapipe/framework/formats/BUILD
+++ b/mediapipe/framework/formats/BUILD
@@ -312,9 +312,7 @@ mediapipe_register_type(
 mediapipe_proto_library(
     name = "landmark_proto",
     srcs = ["landmark.proto"],
-    visibility = [
-        "//visibility:public",
-    ],
+    visibility = ["//visibility:public"],
 )

 mediapipe_register_type(

From 7ce4aa6592c30c2ac5d0c075304e50ae7d01b38f Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 22 Nov 2022 16:38:51 -0800
Subject: [PATCH 092/137] Internal change

PiperOrigin-RevId: 490366250
---
 mediapipe/util/sequence/media_sequence_test.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc
index 40a474599..42b0e3889 100644
--- a/mediapipe/util/sequence/media_sequence_test.cc
+++ b/mediapipe/util/sequence/media_sequence_test.cc
@@ -802,7 +802,7 @@ TEST(MediaSequenceTest, ReconcileMetadataImages) {
   tensorflow::SequenceExample sequence;
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {}));
   std::string encoded_image(bytes.begin(), bytes.end());
   AddImageEncoded(encoded_image, &sequence);
   AddImageEncoded(encoded_image, &sequence);
@@ -843,7 +843,7 @@ TEST(MediaSequenceTest, ReconcileMetadataFlowEncoded) {
   tensorflow::SequenceExample sequence;
   cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
   std::vector<uchar> bytes;
-  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
+  ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {}));
   std::string encoded_flow(bytes.begin(), bytes.end());
   AddForwardFlowEncoded(encoded_flow, &sequence);

From efa9e737f80e245aec4c6ef9483fc92547e6d1d9 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 17:22:18 -0800
Subject: [PATCH 093/137] Use current context if available in ConvertToImageFrame

If we're already running in a GlContext, there's no need to go back to the producer context, which may be different.
PiperOrigin-RevId: 490373829
---
 mediapipe/gpu/gl_texture_buffer.cc | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc
index d530d5d12..69b9889c7 100644
--- a/mediapipe/gpu/gl_texture_buffer.cc
+++ b/mediapipe/gpu/gl_texture_buffer.cc
@@ -363,7 +363,8 @@ static std::shared_ptr<ImageFrame> ConvertToImageFrame(
   auto output =
       absl::make_unique<ImageFrame>(image_format, buf->width(), buf->height(),
                                     ImageFrame::kGlDefaultAlignmentBoundary);
-  auto ctx = buf->GetProducerContext();
+  auto ctx = GlContext::GetCurrent();
+  if (!ctx) ctx = buf->GetProducerContext();
   ctx->Run([buf, &output, &ctx] {
     auto view = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0);
     ReadTexture(*ctx, view, buf->format(), output->MutablePixelData(),
@@ -392,7 +393,9 @@ static std::shared_ptr<GpuBufferStorageCvPixelBuffer> ConvertToCvPixelBuffer(
     std::shared_ptr<GlTextureBuffer> buf) {
   auto output = absl::make_unique<GpuBufferStorageCvPixelBuffer>(
       buf->width(), buf->height(), buf->format());
-  buf->GetProducerContext()->Run([buf, &output] {
+  auto ctx = GlContext::GetCurrent();
+  if (!ctx) ctx = buf->GetProducerContext();
+  ctx->Run([buf, &output] {
     TempGlFramebuffer framebuffer;
     auto src = buf->GetReadView(internal::types<GlTextureView>{}, /*plane=*/0);
     auto dst =

From fac97554dfb80e8c14ecbfb2cbe12e0ad26ce0b4 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 22 Nov 2022 17:23:48 -0800
Subject: [PATCH 094/137] Small TS audio API improvement

PiperOrigin-RevId: 490374083
---
 .../audio_classifier/audio_classifier.ts | 14 +-
 .../audio/audio_embedder/audio_embedder.ts | 14 +-
 mediapipe/web/graph_runner/graph_runner.ts | 129 ++++++++++++++----
 3 files changed, 105 insertions(+), 52 deletions(-)

diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
index 0c54a4718..20c745383 100644
--- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
+++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts
@@ -35,11 +35,7 @@ export * from './audio_classifier_result';
 const MEDIAPIPE_GRAPH =
     'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';

-// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' and
-// cannot be changed
-// TODO: Change this to `audio_in` to match the name in the CC
-// implementation
-const AUDIO_STREAM = 'input_audio';
+const AUDIO_STREAM = 'audio_in';
 const SAMPLE_RATE_STREAM = 'sample_rate';
 const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';

@@ -154,14 +150,8 @@ export class AudioClassifier extends AudioTaskRunner<AudioClassifierResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioClassifierResult[] {
-    // Configures the number of samples in the WASM layer. We re-configure the
-    // number of samples and the sample rate for every frame, but ignore other
-    // side effects of this function (such as sending the input side packet and
-    // the input stream header).
-    this.configureAudio(
-        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStream(audioData, timestampMs);
+    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);

     this.classificationResults = [];
     this.finishProcessing();
diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
index 51cb819de..46a7b6729 100644
--- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
+++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts
@@ -35,11 +35,7 @@ export * from './audio_embedder_result';
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern

-// Note: `input_audio` is hardcoded in 'gl_graph_runner_internal_audio' cannot
-// be changed
-// TODO: Change this to `audio_in` to match the name in the CC
-// implementation
-const AUDIO_STREAM = 'input_audio';
+const AUDIO_STREAM = 'audio_in';
 const SAMPLE_RATE_STREAM = 'sample_rate';
 const EMBEDDINGS_STREAM = 'embeddings_out';
 const TIMESTAMPED_EMBEDDINGS_STREAM = 'timestamped_embeddings_out';

@@ -151,14 +147,8 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioEmbedderResult[] {
-    // Configures the number of samples in the WASM layer. We re-configure the
-    // number of samples and the sample rate for every frame, but ignore other
-    // side effects of this function (such as sending the input side packet and
-    // the input stream header).
-    this.configureAudio(
-        /* numChannels= */ 1, /* numSamples= */ audioData.length, sampleRate);
     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs);
-    this.addAudioToStream(audioData, timestampMs);
+    this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs);

     this.embeddingResults = [];
     this.finishProcessing();
diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts
index 7de5aa33b..c4654794c 100644
--- a/mediapipe/web/graph_runner/graph_runner.ts
+++ b/mediapipe/web/graph_runner/graph_runner.ts
@@ -15,9 +15,6 @@ export declare interface FileLocator {
   locateFile: (filename: string) => string;
 }

-/** Listener to be passed in by user for handling output audio data. */
-export type AudioOutputListener = (output: Float32Array) => void;
-
 /**
  * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler
  * doesn't break our JS/C++ bridge.
 */
@@ -32,19 +29,14 @@ export declare interface WasmModule {
   _bindTextureToCanvas: () => boolean;
   _changeBinaryGraph: (size: number, dataPtr: number) => void;
   _changeTextGraph: (size: number, dataPtr: number) => void;
-  _configureAudio:
-      (channels: number, samples: number, sampleRate: number) => void;
   _free: (ptr: number) => void;
   _malloc: (size: number) => number;
-  _processAudio: (dataPtr: number, timestamp: number) => void;
   _processFrame: (width: number, height: number, timestamp: number) => void;
   _setAutoRenderToScreen: (enabled: boolean) => void;
   _waitUntilIdle: () => void;

   // Exposed so that clients of this lib can access this field
   dataFileDownloads?: {[url: string]: {loaded: number, total: number}};
-  // Wasm module will call us back at this function when given audio data.
-  onAudioOutput?: AudioOutputListener;

   // Wasm Module multistream entrypoints. Require
   // gl_graph_runner_internal_multi_input as a build dependency.
@@ -100,11 +92,14 @@ export declare interface WasmModule {
   _attachProtoVectorListener:
       (streamNamePtr: number, makeDeepCopy?: boolean) => void;

-  // Requires dependency ":gl_graph_runner_audio_out", and will register an
-  // audio output listening function which can be tapped into dynamically during
-  // graph running via onAudioOutput. This call must be made before graph is
-  // initialized, but after wasmModule is instantiated.
-  _attachAudioOutputListener: () => void;
+  // Requires dependency ":gl_graph_runner_audio_out"
+  _attachAudioListener: (streamNamePtr: number, makeDeepCopy?: boolean) => void;
+
+  // Requires dependency ":gl_graph_runner_audio"
+  _addAudioToInputStream: (dataPtr: number, numChannels: number,
+      numSamples: number, streamNamePtr: number, timestamp: number) => void;
+  _configureAudio: (channels: number, samples: number, sampleRate: number,
+      streamNamePtr: number, headerNamePtr: number) => void;

   // TODO: Refactor to just use a few numbers (perhaps refactor away
   // from gl_graph_runner_internal.cc entirely to use something a little more
@@ -235,19 +230,38 @@ export class GraphRunner {
   }

   /**
-   * Configures the current graph to handle audio in a certain way. Must be
-   * called before the graph is set/started in order to use processAudio.
+   * Configures the current graph to handle audio processing in a certain way
+   * for all its audio input streams. Additionally can configure audio headers
+   * (both input side packets as well as input stream headers), but these
+   * configurations only take effect if called before the graph is set/started.
    * @param numChannels The number of channels of audio input. Only 1
    *     is supported for now.
    * @param numSamples The number of samples that are taken in each
    *     audio capture.
    * @param sampleRate The rate, in Hz, of the sampling.
+   * @param streamName The optional name of the input stream to additionally
+   *     configure with audio information. This configuration only occurs before
+   *     the graph is set/started. If unset, a default stream name will be used.
+   * @param headerName The optional name of the header input side packet to
+   *     additionally configure with audio information. This configuration only
+   *     occurs before the graph is set/started. If unset, a default header name
+   *     will be used.
    */
-  configureAudio(numChannels: number, numSamples: number, sampleRate: number) {
-    this.wasmModule._configureAudio(numChannels, numSamples, sampleRate);
-    if (this.wasmModule._attachAudioOutputListener) {
-      this.wasmModule._attachAudioOutputListener();
+  configureAudio(numChannels: number, numSamples: number, sampleRate: number,
+      streamName?: string, headerName?: string) {
+    if (!this.wasmModule._configureAudio) {
+      console.warn(
+          'Attempting to use configureAudio without support for input audio. ' +
+          'Is build dep ":gl_graph_runner_audio" missing?');
     }
+    streamName = streamName || 'input_audio';
+    this.wrapStringPtr(streamName, (streamNamePtr: number) => {
+      headerName = headerName || 'audio_header';
+      this.wrapStringPtr(headerName, (headerNamePtr: number) => {
+        this.wasmModule._configureAudio(numChannels, numSamples, sampleRate,
+            streamNamePtr, headerNamePtr);
+      });
+    });
   }

   /**
@@ -437,9 +451,36 @@
    *     processed.
    * @param audioData An array of raw audio capture data, like
    *     from a call to getChannelData on an AudioBuffer.
+ * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. * @param timestamp The timestamp of the current frame, in ms. */ - addAudioToStream(audioData: Float32Array, timestamp: number) { + addAudioToStream( + audioData: Float32Array, streamName: string, timestamp: number) { + // numChannels and numSamples being 0 will cause defaults to be used, + // which will reflect values from last call to configureAudio. + this.addAudioToStreamWithShape(audioData, 0, 0, streamName, timestamp); + } + + /** + * Takes the raw data from a JS audio capture array, and sends it to C++ to be + * processed, shaping the audioData array into an audio matrix according to + * the numChannels and numSamples parameters. + * @param audioData An array of raw audio capture data, like + * from a call to getChannelData on an AudioBuffer. + * @param numChannels The number of audio channels this data represents. If 0 + * is passed, then the value will be taken from the last call to + * configureAudio. + * @param numSamples The number of audio samples captured in this data packet. + * If 0 is passed, then the value will be taken from the last call to + * configureAudio. + * @param streamName The name of the MediaPipe graph stream to add the audio + * data to. + * @param timestamp The timestamp of the current frame, in ms. + */ + addAudioToStreamWithShape( + audioData: Float32Array, numChannels: number, numSamples: number, + streamName: string, timestamp: number) { // 4 bytes for each F32 const size = audioData.length * 4; if (this.audioSize !== size) { @@ -450,7 +491,11 @@ export class GraphRunner { this.audioSize = size; } this.wasmModule.HEAPF32.set(audioData, this.audioPtr! / 4); - this.wasmModule._processAudio(this.audioPtr!, timestamp); + + this.wrapStringPtr(streamName, (streamNamePtr: number) => { + this.wasmModule._addAudioToInputStream( + this.audioPtr!, numChannels, numSamples, streamNamePtr, timestamp); + }); } /** @@ -943,17 +988,45 @@ export class GraphRunner { } /** - * Sets a listener to be called back with audio output packet data, as a - * Float32Array, when graph has finished processing it. - * @param audioOutputListener The caller's listener function. + * Attaches an audio packet listener to the specified output_stream, to be + * given a Float32Array as output. + * @param outputStreamName The name of the graph output stream to grab audio + * data from. + * @param callbackFcn The function that will be called back with the data, as + * it is received. Note that the data is only guaranteed to exist for the + * duration of the callback, and the callback will be called inline, so it + * should not perform overly complicated (or any async) behavior. If the + * audio data needs to be able to outlive the call, you may set the + * optional makeDeepCopy parameter to true, or can manually deep-copy the + * data yourself. + * @param makeDeepCopy Optional convenience parameter which, if set to true, + * will override the default memory management behavior and make a deep + * copy of the underlying data, rather than just returning a view into the + * C++-managed memory. At the cost of a data copy, this allows the + * returned data to outlive the callback lifetime (and it will be cleaned + * up automatically by JS garbage collection whenever the user is finished + * with it). 
*/
-  setOnAudioOutput(audioOutputListener: AudioOutputListener) {
-    this.wasmModule.onAudioOutput = audioOutputListener;
-    if (!this.wasmModule._attachAudioOutputListener) {
+  attachAudioListener(outputStreamName: string,
+      callbackFcn: (data: Float32Array) => void, makeDeepCopy?: boolean): void {
+    if (!this.wasmModule._attachAudioListener) {
       console.warn(
-          'Attempting to use AudioOutputListener without support for ' +
+          'Attempting to use attachAudioListener without support for ' +
          'output audio. Is build dep ":gl_graph_runner_audio_out" missing?');
     }
+
+    // Set up our TS listener to receive any packets for this stream, and
+    // additionally reformat our Uint8Array into a Float32Array for the user.
+    this.setListener(outputStreamName, (data: Uint8Array) => {
+      const floatArray = new Float32Array(data.buffer);  // Should be very fast
+      callbackFcn(floatArray);
+    });
+
+    // Tell our graph to listen for audio packets on this stream.
+    this.wrapStringPtr(outputStreamName, (outputStreamNamePtr: number) => {
+      this.wasmModule._attachAudioListener(
+          outputStreamNamePtr, makeDeepCopy || false);
+    });
   }

   /**

From 8ba9d87e667f0c6e67026f96aa58ee1a980b0ce1 Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 17:25:55 -0800
Subject: [PATCH 095/137] Update ImageFrameToGpuBufferCalculator to use api2 and GpuBuffer conversions

PiperOrigin-RevId: 490374387
---
 mediapipe/gpu/BUILD | 2 +
 .../image_frame_to_gpu_buffer_calculator.cc | 62 ++++++++-----------
 2 files changed, 28 insertions(+), 36 deletions(-)

diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD
index 10a8d7fff..f97eed678 100644
--- a/mediapipe/gpu/BUILD
+++ b/mediapipe/gpu/BUILD
@@ -901,6 +901,8 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":gl_calculator_helper",
+        ":gpu_buffer_storage_image_frame",
+        "//mediapipe/framework/api2:node",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:image_frame",
         "//mediapipe/framework/port:status",
diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc
index 2a8331db8..c67fb0c62 100644
--- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc
+++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc
@@ -12,73 +12,63 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/image_frame.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/gpu/gl_calculator_helper.h"

-#ifdef __APPLE__
-#include "mediapipe/objc/util.h"
-#endif
-
 namespace mediapipe {
+namespace api2 {

-// Convert ImageFrame to GpuBuffer.
-class ImageFrameToGpuBufferCalculator : public CalculatorBase {
+class ImageFrameToGpuBufferCalculator
+    : public RegisteredNode<ImageFrameToGpuBufferCalculator> {
 public:
-  ImageFrameToGpuBufferCalculator() {}
+  static constexpr Input<ImageFrame> kIn{""};
+  static constexpr Output<GpuBuffer> kOut{""};

-  static absl::Status GetContract(CalculatorContract* cc);
+  MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut);
+
+  static absl::Status UpdateContract(CalculatorContract* cc);

   absl::Status Open(CalculatorContext* cc) override;
   absl::Status Process(CalculatorContext* cc) override;

  private:
-#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   GlCalculatorHelper helper_;
-#endif  // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
 };
-REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator);

 // static
-absl::Status ImageFrameToGpuBufferCalculator::GetContract(
+absl::Status ImageFrameToGpuBufferCalculator::UpdateContract(
     CalculatorContract* cc) {
-  cc->Inputs().Index(0).Set<ImageFrame>();
-  cc->Outputs().Index(0).Set<GpuBuffer>();
   // Note: we call this method even on platforms where we don't use the helper,
   // to ensure the calculator's contract is the same. In particular, the helper
   // enables support for the legacy side packet, which several graphs still use.
-  MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
-  return absl::OkStatus();
+  return GlCalculatorHelper::UpdateContract(cc);
 }

 absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) {
-  // Inform the framework that we always output at the same timestamp
-  // as we receive a packet at.
-  cc->SetOffset(TimestampDiff(0));
-#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   MP_RETURN_IF_ERROR(helper_.Open(cc));
-#endif  // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   return absl::OkStatus();
 }

 absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) {
-#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
-  CFHolder<CVPixelBufferRef> buffer;
-  MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket(
-      cc->Inputs().Index(0).Value(), &buffer));
-  cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp());
-#else
-  const auto& input = cc->Inputs().Index(0).Get<ImageFrame>();
-  helper_.RunInGlContext([this, &input, &cc]() {
-    auto src = helper_.CreateSourceTexture(input);
-    auto output = src.GetFrame<GpuBuffer>();
-    glFlush();
-    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
-    src.Release();
-  });
-#endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
+  auto image_frame = std::const_pointer_cast<ImageFrame>(
+      mediapipe::SharedPtrWithPacket<ImageFrame>(kIn(cc).packet()));
+  auto gpu_buffer = api2::MakePacket<GpuBuffer>(
+                        std::make_shared<GpuBufferStorageImageFrame>(
+                            std::move(image_frame)))
+                        .At(cc->InputTimestamp());
+  // This calculator's behavior has been to do the texture upload eagerly, and
+  // some graphs may rely on running this on a separate GL context to avoid
+  // blocking another context with the read operation. So let's request GPU
+  // access here to ensure that the behavior stays the same.
+  // TODO: have a better way to do this, or defer until later.
+  helper_.RunInGlContext(
+      [&gpu_buffer] { auto view = gpu_buffer->GetReadView<GlTextureView>(0); });
+  kOut(cc).Send(std::move(gpu_buffer));
   return absl::OkStatus();
 }

+}  // namespace api2
 }  // namespace mediapipe

From 837225c53d55700ff485367bb0fa71890f905e2e Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Tue, 22 Nov 2022 17:30:23 -0800
Subject: [PATCH 096/137] Internal change

PiperOrigin-RevId: 490374976
---
 mediapipe/framework/validated_graph_config.cc | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc
index 16aad6e9b..01e3da83e 100644
--- a/mediapipe/framework/validated_graph_config.cc
+++ b/mediapipe/framework/validated_graph_config.cc
@@ -1048,6 +1048,14 @@ absl::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes(
   for (const auto& required_item : required_side_packets_) {
     auto iter = side_packet_types.find(required_item.first);
     if (iter == side_packet_types.end()) {
+      bool is_optional = true;
+      for (int index : required_item.second) {
+        is_optional &= input_side_packets_[index].packet_type->IsOptional();
+      }
+      if (is_optional) {
+        // Side packets that are optional and not provided are ignored.
+        continue;
+      }
       statuses.push_back(mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
                          << "Side packet \"" << required_item.first
                          << "\" is required but was not provided.");
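[Editor's note: a sketch of the calculator contract this validation change interacts with; the tag and packet type are illustrative only.]

    // In MyCalculator::GetContract(CalculatorContract* cc):
    cc->InputSidePackets().Tag("CUSTOM_CONFIG").Set<std::string>().Optional();
    // Before this change, a graph that never supplied "CUSTOM_CONFIG" failed
    // validation with "Side packet ... is required but was not provided.";
    // now such optional, unprovided side packets are simply ignored.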
From 3bbc0e9af9150797142295f47b1d87a0403d8f44 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Tue, 22 Nov 2022 17:34:58 -0800
Subject: [PATCH 097/137] Internal change

PiperOrigin-RevId: 490375672
---
 mediapipe/tasks/web/BUILD | 18 +++---------------
 mediapipe/tasks/web/audio.ts | 3 +--
 mediapipe/tasks/web/text.ts | 3 +--
 mediapipe/tasks/web/vision.ts | 6 +-----
 4 files changed, 6 insertions(+), 24 deletions(-)

diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD
index af76a1fe8..7e5d02892 100644
--- a/mediapipe/tasks/web/BUILD
+++ b/mediapipe/tasks/web/BUILD
@@ -24,10 +24,7 @@ mediapipe_files(srcs = [
 mediapipe_ts_library(
     name = "audio_lib",
     srcs = ["audio.ts"],
-    deps = [
-        "//mediapipe/tasks/web/audio/audio_classifier",
-        "//mediapipe/tasks/web/audio/audio_embedder",
-    ],
+    deps = ["//mediapipe/tasks/web/audio:audio_lib"],
 )

 rollup_bundle(
@@ -69,10 +66,7 @@ pkg_npm(
 mediapipe_ts_library(
     name = "text_lib",
     srcs = ["text.ts"],
-    deps = [
-        "//mediapipe/tasks/web/text/text_classifier",
-        "//mediapipe/tasks/web/text/text_embedder",
-    ],
+    deps = ["//mediapipe/tasks/web/text:text_lib"],
 )

 rollup_bundle(
@@ -114,13 +108,7 @@ pkg_npm(
 mediapipe_ts_library(
     name = "vision_lib",
     srcs = ["vision.ts"],
-    deps = [
-        "//mediapipe/tasks/web/vision/gesture_recognizer",
-        "//mediapipe/tasks/web/vision/hand_landmarker",
-        "//mediapipe/tasks/web/vision/image_classifier",
-        "//mediapipe/tasks/web/vision/image_embedder",
-        "//mediapipe/tasks/web/vision/object_detector",
-    ],
+    deps = ["//mediapipe/tasks/web/vision:vision_lib"],
 )

 rollup_bundle(
diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts
index 056426f50..8c522efcc 100644
--- a/mediapipe/tasks/web/audio.ts
+++ b/mediapipe/tasks/web/audio.ts
@@ -14,8 +14,7 @@
 * limitations under the License.
 */

-import {AudioClassifier as AudioClassifierImpl} from '../../tasks/web/audio/audio_classifier/audio_classifier';
-import {AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/audio_embedder/audio_embedder';
+import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/index';

 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts
index 39d101237..8f15075c5 100644
--- a/mediapipe/tasks/web/text.ts
+++ b/mediapipe/tasks/web/text.ts
@@ -14,8 +14,7 @@
 * limitations under the License.
 */

-import {TextClassifier as TextClassifierImpl} from '../../tasks/web/text/text_classifier/text_classifier';
-import {TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/text_embedder/text_embedder';
+import {TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index';

 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts
index 4e4fab43f..74a056464 100644
--- a/mediapipe/tasks/web/vision.ts
+++ b/mediapipe/tasks/web/vision.ts
@@ -14,11 +14,7 @@
 * limitations under the License.
 */

-import {GestureRecognizer as GestureRecognizerImpl} from '../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
-import {HandLandmarker as HandLandmarkerImpl} from '../../tasks/web/vision/hand_landmarker/hand_landmarker';
-import {ImageClassifier as ImageClassifierImpl} from '../../tasks/web/vision/image_classifier/image_classifier';
-import {ImageEmbedder as ImageEmbedderImpl} from '../../tasks/web/vision/image_embedder/image_embedder';
-import {ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/object_detector/object_detector';
+import {GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index';

 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.

From a55839de51dafe27b4c2b705954444895a842c3c Mon Sep 17 00:00:00 2001
From: Camillo Lugaresi
Date: Tue, 22 Nov 2022 18:07:26 -0800
Subject: [PATCH 098/137] This storage only needs a "done writing" callback on simulator, so only set it there

- When not on simulator, we pass nullptr instead of a do-nothing callback.
- The callback is no longer a method, but a function. Only the CVPixelBuffer is captured.
PiperOrigin-RevId: 490380248
---
 .../gpu/gpu_buffer_storage_cv_pixel_buffer.cc | 45 +++++++++++--------
 .../gpu/gpu_buffer_storage_cv_pixel_buffer.h | 1 -
 2 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
index f3954a6e4..014cc1c69 100644
--- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
+++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.cc
@@ -70,25 +70,9 @@ GlTextureView GpuBufferStorageCvPixelBuffer::GetReadView(
   return GetTexture(plane, nullptr);
 }

-GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView(
-    internal::types<GlTextureView>, int plane) {
-  return GetTexture(plane, [this](const mediapipe::GlTextureView& view) {
-    ViewDoneWriting(view);
-  });
-}
-
-std::shared_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::GetReadView(
-    internal::types<ImageFrame>) const {
-  return CreateImageFrameForCVPixelBuffer(**this);
-}
-std::shared_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::GetWriteView(
-    internal::types<ImageFrame>) {
-  return CreateImageFrameForCVPixelBuffer(**this);
-}
-
-void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) {
 #if TARGET_IPHONE_SIMULATOR
-  CVPixelBufferRef pixel_buffer = **this;
+static void ViewDoneWritingSimulatorWorkaround(CVPixelBufferRef pixel_buffer,
+                                               const GlTextureView& view) {
   CHECK(pixel_buffer);
   CVReturn err = CVPixelBufferLockBaseAddress(pixel_buffer, 0);
   CHECK(err == kCVReturnSuccess)
@@ -126,7 +110,30 @@ void GpuBufferStorageCvPixelBuffer::ViewDoneWriting(const GlTextureView& view) {
   err = CVPixelBufferUnlockBaseAddress(pixel_buffer, 0);
   CHECK(err == kCVReturnSuccess)
       << "CVPixelBufferUnlockBaseAddress failed: " << err;
-#endif
+}
+#endif  // TARGET_IPHONE_SIMULATOR
+
+GlTextureView GpuBufferStorageCvPixelBuffer::GetWriteView(
+    internal::types<GlTextureView>, int plane) {
+  return GetTexture(plane,
+#if TARGET_IPHONE_SIMULATOR
+                    [pixel_buffer = CFHolder<CVPixelBufferRef>(*this)](
+                        const mediapipe::GlTextureView& view) {
+                      ViewDoneWritingSimulatorWorkaround(*pixel_buffer, view);
+                    }
+#else
+                    nullptr
+#endif  // TARGET_IPHONE_SIMULATOR
+  );
+}
+
+std::shared_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::GetReadView(
+    internal::types<ImageFrame>) const {
+  return CreateImageFrameForCVPixelBuffer(**this);
+}
+std::shared_ptr<ImageFrame> GpuBufferStorageCvPixelBuffer::GetWriteView(
+    internal::types<ImageFrame>) {
+  return CreateImageFrameForCVPixelBuffer(**this);
 }

 static std::shared_ptr<GpuBufferStorageCvPixelBuffer> ConvertFromImageFrame(
diff --git a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
index a9389ab8a..8723a1087 100644
--- a/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
+++ b/mediapipe/gpu/gpu_buffer_storage_cv_pixel_buffer.h
@@ -63,7 +63,6 @@ class GpuBufferStorageCvPixelBuffer
 private:
   GlTextureView GetTexture(int plane,
                            GlTextureView::DoneWritingFn done_writing) const;
-  void ViewDoneWriting(const GlTextureView& view);
 };

 inline CFHolder<CVPixelBufferRef> GpuBufferStorageCvPixelBuffer::GetReadView(

From 05681fc0e17089a4e1d3f999bd17f3020cabb9bc Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Wed, 23 Nov 2022 01:26:15 -0800
Subject: [PATCH 099/137] Internal

PiperOrigin-RevId: 490439195
---
 .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index 8b09260bd..762184842 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -18,7 +18,6 @@ load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_build load("@build_bazel_rules_android//android:rules.bzl", "android_library") _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [ - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_java_proto_lite", From c5ce5236972a6045f42bb23d526ebb27a7e58bb7 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Wed, 23 Nov 2022 02:02:18 -0800 Subject: [PATCH 100/137] Add cosine APIs to Embedder tasks PiperOrigin-RevId: 490444597 --- .../tasks/web/audio/audio_embedder/BUILD | 1 + .../audio/audio_embedder/audio_embedder.ts | 15 +++++ mediapipe/tasks/web/components/utils/BUILD | 11 ++++ .../web/components/utils/cosine_similarity.ts | 62 +++++++++++++++++++ mediapipe/tasks/web/text/text_embedder/BUILD | 1 + .../web/text/text_embedder/text_embedder.ts | 15 +++++ .../tasks/web/vision/image_embedder/BUILD | 1 + .../vision/image_embedder/image_embedder.ts | 15 +++++ 8 files changed, 121 insertions(+) create mode 100644 mediapipe/tasks/web/components/utils/BUILD create mode 100644 mediapipe/tasks/web/components/utils/cosine_similarity.ts diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 7d9a994a3..1a66464bd 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -22,6 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", + "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/web/graph_runner:graph_runner_ts", diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 46a7b6729..9dce02862 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -20,8 +20,10 @@ import {AudioEmbedderGraphOptions as AudioEmbedderGraphOptionsProto} from '../.. 
import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner';
+import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
 import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
 import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
+import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
 import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
 import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url
@@ -144,6 +146,19 @@ export class AudioEmbedder extends AudioTaskRunner<AudioEmbedderResult[]> {
     return this.processAudioClip(audioData, sampleRate);
   }

+  /**
+   * Utility function to compute cosine similarity[1] between two `Embedding`
+   * objects.
+   *
+   * [1]: https://en.wikipedia.org/wiki/Cosine_similarity
+   *
+   * @throws if the embeddings are of different types (float vs. quantized), have
+   *     different sizes, or have an L2-norm of 0.
+   */
+  static cosineSimilarity(u: Embedding, v: Embedding): number {
+    return computeCosineSimilarity(u, v);
+  }
+
   protected override process(
       audioData: Float32Array, sampleRate: number,
       timestampMs: number): AudioEmbedderResult[] {
diff --git a/mediapipe/tasks/web/components/utils/BUILD b/mediapipe/tasks/web/components/utils/BUILD
new file mode 100644
index 000000000..1c1ba69ca
--- /dev/null
+++ b/mediapipe/tasks/web/components/utils/BUILD
@@ -0,0 +1,11 @@
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+mediapipe_ts_library(
+    name = "cosine_similarity",
+    srcs = ["cosine_similarity.ts"],
+    deps = [
+        "//mediapipe/tasks/web/components/containers:embedding_result",
+    ],
+)
diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts b/mediapipe/tasks/web/components/utils/cosine_similarity.ts
new file mode 100644
index 000000000..fb1d0c185
--- /dev/null
+++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
+
+/**
+ * Computes cosine similarity[1] between two `Embedding` objects.
+ *
+ * [1]: https://en.wikipedia.org/wiki/Cosine_similarity
+ *
+ * @throws if the embeddings are of different types (float vs. quantized),
+ *     have different sizes, or have an L2-norm of 0.
+ */
+export function computeCosineSimilarity(u: Embedding, v: Embedding): number {
+  if (u.floatEmbedding && v.floatEmbedding) {
+    return compute(u.floatEmbedding, v.floatEmbedding);
+  }
+  if (u.quantizedEmbedding && v.quantizedEmbedding) {
+    return compute(
+        convertToBytes(u.quantizedEmbedding),
+        convertToBytes(v.quantizedEmbedding));
+  }
+  throw new Error(
+      'Cannot compute cosine similarity between quantized and float embeddings.');
+}
+function convertToBytes(data: Uint8Array): number[] {
+  return Array.from(data, v => v - 128);
+}
+
+function compute(u: number[], v: number[]) {
+  if (u.length !== v.length) {
+    throw new Error(
+        `Cannot compute cosine similarity between embeddings of different sizes (${u.length} vs. ${v.length}).`);
+  }
+  let dotProduct = 0.0;
+  let normU = 0.0;
+  let normV = 0.0;
+  for (let i = 0; i < u.length; i++) {
+    dotProduct += u[i] * v[i];
+    normU += u[i] * u[i];
+    normV += v[i] * v[i];
+  }
+  if (normU <= 0 || normV <= 0) {
+    throw new Error(
+        'Cannot compute cosine similarity on embedding with 0 norm.');
+  }
+  return dotProduct / Math.sqrt(normU * normV);
+}
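[Editor's note: a small worked example of the cosine formula implemented above, written as a self-contained C++ sketch for illustration; it is not part of the patch.]

    #include <cmath>
    #include <vector>

    // cosine(u, v) = dot(u, v) / (||u|| * ||v||), in [-1, 1].
    double CosineSimilarity(const std::vector<double>& u,
                            const std::vector<double>& v) {
      double dot = 0, norm_u = 0, norm_v = 0;
      for (size_t i = 0; i < u.size(); ++i) {
        dot += u[i] * v[i];
        norm_u += u[i] * u[i];
        norm_v += v[i] * v[i];
      }
      return dot / std::sqrt(norm_u * norm_v);
    }
    // e.g. u = {1, 0}, v = {1, 1} gives 1 / sqrt(2), roughly 0.707.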
diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD
index c555f8d33..3f92b8ae1 100644
--- a/mediapipe/tasks/web/text/text_embedder/BUILD
+++ b/mediapipe/tasks/web/text/text_embedder/BUILD
@@ -22,6 +22,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/processors:base_options",
         "//mediapipe/tasks/web/components/processors:embedder_options",
         "//mediapipe/tasks/web/components/processors:embedder_result",
+        "//mediapipe/tasks/web/components/utils:cosine_similarity",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
         "//mediapipe/tasks/web/core:task_runner",
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
index 57b91d575..2042a0985 100644
--- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts
@@ -18,9 +18,11 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
 import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
 import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb';
+import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
 import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options';
 import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
 import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
+import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
 import {TaskRunner} from '../../../../tasks/web/core/task_runner';
 import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
 import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner';
@@ -143,6 +145,19 @@ export class TextEmbedder extends TaskRunner {
     return this.embeddingResult;
   }

+  /**
+   * Utility function to compute cosine similarity[1] between two `Embedding`
+   * objects.
+   *
+   * [1]: https://en.wikipedia.org/wiki/Cosine_similarity
+   *
+   * @throws if the embeddings are of different types (float vs. quantized), have
+   *     different sizes, or have an L2-norm of 0.
+   */
+  static cosineSimilarity(u: Embedding, v: Embedding): number {
+    return computeCosineSimilarity(u, v);
+  }
+
   /** Updates the MediaPipe graph configuration. */
   private refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
diff --git a/mediapipe/tasks/web/vision/image_embedder/BUILD b/mediapipe/tasks/web/vision/image_embedder/BUILD
index feb3ae054..2f012dc5e 100644
--- a/mediapipe/tasks/web/vision/image_embedder/BUILD
+++ b/mediapipe/tasks/web/vision/image_embedder/BUILD
@@ -21,6 +21,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/components/containers:embedding_result",
         "//mediapipe/tasks/web/components/processors:embedder_options",
         "//mediapipe/tasks/web/components/processors:embedder_result",
+        "//mediapipe/tasks/web/components/utils:cosine_similarity",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:embedder_options",
         "//mediapipe/tasks/web/vision/core:vision_task_options",
diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
index c60665052..f96f1e961 100644
--- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
+++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts
@@ -19,8 +19,10 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
 import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb';
 import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 import {ImageEmbedderGraphOptions} from '../../../../tasks/cc/vision/image_embedder/proto/image_embedder_graph_options_pb';
+import {Embedding} from '../../../../tasks/web/components/containers/embedding_result';
 import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options';
 import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result';
+import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity';
 import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options';
 import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner';
@@ -157,6 +159,19 @@ export class ImageEmbedder extends VisionTaskRunner<ImageEmbedderResult> {
     return this.processVideoData(imageFrame, timestamp);
   }

+  /**
+   * Utility function to compute cosine similarity[1] between two `Embedding`
+   * objects.
+   *
+   * [1]: https://en.wikipedia.org/wiki/Cosine_similarity
+   *
+   * @throws if the embeddings are of different types (float vs. quantized), have
+   *     different sizes, or have an L2-norm of 0.
+ */ + static cosineSimilarity(u: Embedding, v: Embedding): number { + return computeCosineSimilarity(u, v); + } + /** Runs the embedding extraction and blocks on the response. */ protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { From b5189758f7fc913e050ae0e6d4f7f999365e8118 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 02:03:35 -0800 Subject: [PATCH 101/137] Move ImagePreprocessing to "processors" folder. PiperOrigin-RevId: 490444821 --- mediapipe/tasks/cc/components/BUILD | 45 --- .../tasks/cc/components/processors/BUILD | 33 ++ .../image_preprocessing_graph.cc} | 42 ++- .../image_preprocessing_graph.h} | 26 +- .../image_preprocessing_graph_test.cc | 343 ++++++++++++++++++ .../cc/components/processors/proto/BUILD | 10 + .../image_preprocessing_graph_options.proto} | 6 +- .../tasks/cc/vision/gesture_recognizer/BUILD | 4 - .../gesture_recognizer/gesture_recognizer.cc | 1 - .../hand_gesture_recognizer_graph.cc | 2 - mediapipe/tasks/cc/vision/hand_detector/BUILD | 2 +- .../hand_detector/hand_detector_graph.cc | 20 +- .../tasks/cc/vision/hand_landmarker/BUILD | 3 +- .../vision/hand_landmarker/hand_landmarker.cc | 1 - .../hand_landmarks_detector_graph.cc | 17 +- .../tasks/cc/vision/image_classifier/BUILD | 4 +- .../image_classifier_graph.cc | 19 +- .../tasks/cc/vision/image_embedder/BUILD | 4 +- .../image_embedder/image_embedder_graph.cc | 19 +- .../tasks/cc/vision/image_segmenter/BUILD | 4 +- .../image_segmenter/image_segmenter_graph.cc | 19 +- .../tasks/cc/vision/object_detector/BUILD | 2 +- .../object_detector/object_detector_graph.cc | 17 +- 23 files changed, 493 insertions(+), 150 deletions(-) rename mediapipe/tasks/cc/components/{image_preprocessing.cc => processors/image_preprocessing_graph.cc} (90%) rename mediapipe/tasks/cc/components/{image_preprocessing.h => processors/image_preprocessing_graph.h} (72%) create mode 100644 mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc rename mediapipe/tasks/cc/components/{image_preprocessing_options.proto => processors/proto/image_preprocessing_graph_options.proto} (89%) diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index c90349ab2..54a5207d2 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -12,55 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "image_preprocessing_options_proto", - srcs = ["image_preprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) - -cc_library( - name = "image_preprocessing", - srcs = ["image_preprocessing.cc"], - hdrs = ["image_preprocessing.h"], - deps = [ - ":image_preprocessing_options_cc_proto", - "//mediapipe/calculators/core:pass_through_calculator", - "//mediapipe/calculators/image:image_clone_calculator", - "//mediapipe/calculators/image:image_clone_calculator_cc_proto", - "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator", - "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:image", - "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/framework/formats:tensor", - "//mediapipe/gpu:gpu_origin_cc_proto", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - -# TODO: Enable this test - # TODO: Investigate rewriting the build rule to only link # the Bert Preprocessor if it's needed. 
cc_library( diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 32a628db7..4946683f5 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -100,3 +100,36 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "image_preprocessing_graph", + srcs = ["image_preprocessing_graph.cc"], + hdrs = ["image_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/image:image_clone_calculator", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/gpu:gpu_origin_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", + "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc similarity index 90% rename from mediapipe/tasks/cc/components/image_preprocessing.cc rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc index ef447df97..b24b7f0cb 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include #include @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" @@ -42,6 +42,7 @@ limitations under the License. 
namespace mediapipe {
namespace tasks {
namespace components {
+namespace processors {
namespace {

using ::mediapipe::Tensor;
@@ -144,9 +145,9 @@ bool DetermineImagePreprocessingGpuBackend(
  return acceleration.has_gpu();
}

-absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
-                                         bool use_gpu,
-                                         ImagePreprocessingOptions* options) {
+absl::Status ConfigureImagePreprocessingGraph(
+    const ModelResources& model_resources, bool use_gpu,
+    proto::ImagePreprocessingGraphOptions* options) {
  ASSIGN_OR_RETURN(auto image_tensor_specs,
                   BuildImageTensorSpecs(model_resources));
  MP_RETURN_IF_ERROR(ConfigureImageToTensorCalculator(
@@ -154,9 +155,9 @@ absl::Status ConfigureImagePreprocessing(const ModelResources& model_resources,
  // The GPU backend isn't able to process int data. If the input tensor is
  // quantized, forces the image preprocessing graph to use CPU backend.
  if (use_gpu && image_tensor_specs.tensor_type != tflite::TensorType_UINT8) {
-    options->set_backend(ImagePreprocessingOptions::GPU_BACKEND);
+    options->set_backend(proto::ImagePreprocessingGraphOptions::GPU_BACKEND);
  } else {
-    options->set_backend(ImagePreprocessingOptions::CPU_BACKEND);
+    options->set_backend(proto::ImagePreprocessingGraphOptions::CPU_BACKEND);
  }
  return absl::OkStatus();
}
@@ -170,8 +171,7 @@ Source<Image> AddDataConverter(Source<Image> image_in, Graph& graph,
  return image_converter[Output<Image>("")];
}

-// A "mediapipe.tasks.components.ImagePreprocessingSubgraph" performs image
-// preprocessing.
+// An ImagePreprocessingGraph performs image preprocessing.
// - Accepts CPU input images and outputs CPU tensors.
//
// Inputs:
@@ -192,7 +192,7 @@ Source<Image> AddDataConverter(Source<Image> image_in, Graph& graph,
//   An std::array<float, 4> representing the letterbox padding from the 4
//   sides ([left, top, right, bottom]) of the output image, normalized to
//   [0.f, 1.f] by the output dimensions. The padding values are non-zero only
-//   when the "keep_aspect_ratio" is true in ImagePreprocessingOptions.
+//   when the "keep_aspect_ratio" is true in ImagePreprocessingGraphOptions.
// IMAGE_SIZE - std::pair<int, int> @Optional
//   The size of the original input image as a pair.
// IMAGE - Image @Optional
//   The image that has pixel data stored on the target storage (CPU vs
//   GPU).
//
// The recommended way of using this subgraph is through the GraphBuilder API
-// using the 'ConfigureImagePreprocessing()' function. See header file for more
-// details.
+// using the 'ConfigureImagePreprocessingGraph()' function. See header file for
+// more details.
-class ImagePreprocessingSubgraph : public Subgraph {
+class ImagePreprocessingGraph : public Subgraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    auto output_streams = BuildImagePreprocessing(
-        sc->Options<ImagePreprocessingOptions>(),
+        sc->Options<proto::ImagePreprocessingGraphOptions>(),
        graph[Input<Image>(kImageTag)],
        graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph);
    output_streams.tensors >> graph[Output<std::vector<Tensor>>(kTensorsTag)];
@@ -233,24 +233,25 @@ class ImagePreprocessingSubgraph : public Subgraph {
  //   - the image that has pixel data stored on the target storage
  //     (mediapipe::Image).
  //
-  // options: the mediapipe tasks ImagePreprocessingOptions.
+  // options: the mediapipe tasks ImagePreprocessingGraphOptions.
  // image_in: (mediapipe::Image) stream to preprocess.
  // graph: the mediapipe builder::Graph instance to be updated.
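  //
  // A minimal usage sketch (mirroring GetConfig() above; `inference` is an
  // illustrative downstream node, not something defined in this file):
  //
  //   auto streams = BuildImagePreprocessing(options, image_in, rect_in, graph);
  //   streams.tensors >> inference.In("TENSORS");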
  ImagePreprocessingOutputStreams BuildImagePreprocessing(
-      const ImagePreprocessingOptions& options, Source<Image> image_in,
-      Source<NormalizedRect> norm_rect_in, Graph& graph) {
+      const proto::ImagePreprocessingGraphOptions& options,
+      Source<Image> image_in, Source<NormalizedRect> norm_rect_in,
+      Graph& graph) {
    // Convert image to tensor.
    auto& image_to_tensor = graph.AddNode("ImageToTensorCalculator");
    image_to_tensor.GetOptions<mediapipe::ImageToTensorCalculatorOptions>()
        .CopyFrom(options.image_to_tensor_options());
    switch (options.backend()) {
-      case ImagePreprocessingOptions::CPU_BACKEND: {
+      case proto::ImagePreprocessingGraphOptions::CPU_BACKEND: {
        auto cpu_image =
            AddDataConverter(image_in, graph, /*output_on_gpu=*/false);
        cpu_image >> image_to_tensor.In(kImageTag);
        break;
      }
-      case ImagePreprocessingOptions::GPU_BACKEND: {
+      case proto::ImagePreprocessingGraphOptions::GPU_BACKEND: {
        auto gpu_image =
            AddDataConverter(image_in, graph, /*output_on_gpu=*/true);
        gpu_image >> image_to_tensor.In(kImageTag);
@@ -284,8 +285,9 @@ class ImagePreprocessingSubgraph : public Subgraph {
  }
};
REGISTER_MEDIAPIPE_GRAPH(
-    ::mediapipe::tasks::components::ImagePreprocessingSubgraph);
+    ::mediapipe::tasks::components::processors::ImagePreprocessingGraph);

+}  // namespace processors
}  // namespace components
}  // namespace tasks
}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/components/image_preprocessing.h b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h
similarity index 72%
rename from mediapipe/tasks/cc/components/image_preprocessing.h
rename to mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h
index 6963b6556..455a9b316 100644
--- a/mediapipe/tasks/cc/components/image_preprocessing.h
+++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h
@@ -13,35 +13,36 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_
-#define MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_

#include "absl/status/status.h"
-#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"

namespace mediapipe {
namespace tasks {
namespace components {
+namespace processors {

-// Configures an ImagePreprocessing subgraph using the provided model resources
+// Configures an ImagePreprocessingGraph using the provided model resources.
// When use_gpu is true, use GPU as backend to convert image to tensor.
// - Accepts CPU input images and outputs CPU tensors.
//
// Example usage:
//
//   auto& preprocessing =
-//       graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph");
+//       graph.AddNode("mediapipe.tasks.components.processors.ImagePreprocessingGraph");
//   core::proto::Acceleration acceleration;
//   acceleration.mutable_xnnpack();
//   bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
-//   MP_RETURN_IF_ERROR(ConfigureImagePreprocessing(
+//   MP_RETURN_IF_ERROR(ConfigureImagePreprocessingGraph(
//       model_resources,
//       use_gpu,
-//       &preprocessing.GetOptions<ImagePreprocessingOptions>()));
+//       &preprocessing.GetOptions<proto::ImagePreprocessingGraphOptions>()));
//
-// The resulting ImagePreprocessing subgraph has the following I/O:
+// The resulting ImagePreprocessingGraph has the following I/O:
// Inputs:
//   IMAGE - Image
//     The image to preprocess.
@@ -61,17 +62,18 @@ namespace components {
//   IMAGE - Image @Optional
//     The image that has the pixel data stored on the target storage (CPU vs
//     GPU).
-absl::Status ConfigureImagePreprocessing(
+absl::Status ConfigureImagePreprocessingGraph(
    const core::ModelResources& model_resources, bool use_gpu,
-    ImagePreprocessingOptions* options);
+    proto::ImagePreprocessingGraphOptions* options);

-// Determine if the image preprocessing subgraph should use GPU as the backend
+// Determines whether the image preprocessing graph should use GPU as the backend
// according to the given acceleration setting.
bool DetermineImagePreprocessingGpuBackend(
    const core::proto::Acceleration& acceleration);

+}  // namespace processors
}  // namespace components
}  // namespace tasks
}  // namespace mediapipe

-#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_IMAGE_PREPROCESSING_H_
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_IMAGE_PREPROCESSING_GRAPH_H_
diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc
new file mode 100644
index 000000000..6c094c6bc
--- /dev/null
+++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc
@@ -0,0 +1,343 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
+
+#include
+#include
+
+#include "absl/flags/flag.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/core/model_resources.h"
+#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+#include "mediapipe/tasks/cc/core/task_runner.h"
+#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace components {
+namespace processors {
+namespace {
+
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
+using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::core::ModelResources;
+using ::mediapipe::tasks::core::TaskRunner;
+using ::mediapipe::tasks::vision::DecodeImageFromFile;
+using ::testing::ContainerEq;
+using ::testing::HasSubstr;
+using ::testing::TestParamInfo;
+using ::testing::TestWithParam;
+using ::testing::Values;
+
+constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
+constexpr char kMobileNetFloatWithMetadata[] = "mobilenet_v2_1.0_224.tflite";
+constexpr char kMobileNetFloatWithoutMetadata[] =
+    "mobilenet_v1_0.25_224_1_default_1.tflite";
+constexpr char kMobileNetQuantizedWithMetadata[] =
+    "mobilenet_v1_0.25_224_quant.tflite";
+constexpr char kMobileNetQuantizedWithoutMetadata[] =
+    "mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
+
+constexpr char kTestImage[] = "burger.jpg";
+constexpr int kTestImageWidth = 480;
+constexpr int kTestImageHeight = 325;
+
+constexpr char kTestModelResourcesTag[] = "test_model_resources";
+constexpr std::array<float, 16> kIdentityMatrix = {1, 0, 0, 0, 0, 1, 0, 0,
+                                                   0, 0, 1, 0, 0, 0, 0, 1};
+
+constexpr char kImageTag[] = "IMAGE";
+constexpr char kImageName[] = "image_in";
+constexpr char kMatrixTag[] = "MATRIX";
+constexpr char kMatrixName[] = "matrix_out";
+constexpr char kTensorsTag[] = "TENSORS";
+constexpr char kTensorsName[] = "tensors_out";
+constexpr char kImageSizeTag[] = "IMAGE_SIZE";
+constexpr char kImageSizeName[] = "image_size_out";
+constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING";
+constexpr char kLetterboxPaddingName[] = "letterbox_padding_out";
+
+constexpr float kLetterboxMaxAbsError = 1e-5;
+
+// Helper function to get ModelResources.
+absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
+    absl::string_view model_name) {
+  auto external_file = std::make_unique<core::proto::ExternalFile>();
+  external_file->set_file_name(JoinPath("./", kTestDataDirectory, model_name));
+  return ModelResources::Create(kTestModelResourcesTag,
+                                std::move(external_file));
+}
+
+// Helper function to create a TaskRunner from ModelResources.
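+// The resulting runner wraps a single ImagePreprocessingGraph node, exposing
+// IMAGE as the graph input and TENSORS, MATRIX, IMAGE_SIZE and
+// LETTERBOX_PADDING as graph outputs, so each test can feed one Image packet
+// and inspect all four preprocessing results.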
+absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
+    const ModelResources& model_resources, bool keep_aspect_ratio) {
+  Graph graph;
+
+  auto& preprocessing = graph.AddNode(
+      "mediapipe.tasks.components.processors.ImagePreprocessingGraph");
+  auto& options =
+      preprocessing.GetOptions<proto::ImagePreprocessingGraphOptions>();
+  options.mutable_image_to_tensor_options()->set_keep_aspect_ratio(
+      keep_aspect_ratio);
+  MP_RETURN_IF_ERROR(
+      ConfigureImagePreprocessingGraph(model_resources, false, &options));
+  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
+      preprocessing.In(kImageTag);
+  preprocessing.Out(kTensorsTag).SetName(kTensorsName) >>
+      graph[Output<std::vector<Tensor>>(kTensorsTag)];
+  preprocessing.Out(kMatrixTag).SetName(kMatrixName) >>
+      graph[Output<std::array<float, 16>>(kMatrixTag)];
+  preprocessing.Out(kImageSizeTag).SetName(kImageSizeName) >>
+      graph[Output<std::pair<int, int>>(kImageSizeTag)];
+  preprocessing.Out(kLetterboxPaddingTag).SetName(kLetterboxPaddingName) >>
+      graph[Output<std::array<float, 4>>(kLetterboxPaddingTag)];
+
+  return TaskRunner::Create(graph.GetConfig());
+}
+
+class ConfigureTest : public tflite_shims::testing::Test {};
+
+TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+      CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata));
+
+  proto::ImagePreprocessingGraphOptions options;
+  MP_EXPECT_OK(
+      ConfigureImagePreprocessingGraph(*model_resources, false, &options));
+
+  EXPECT_THAT(options, EqualsProto(
+                           R"pb(image_to_tensor_options {
+                                  output_tensor_width: 224
+                                  output_tensor_height: 224
+                                  output_tensor_uint_range { min: 0 max: 255 }
+                                  gpu_origin: TOP_LEFT
+                                }
+                                backend: CPU_BACKEND)pb"));
+}
+
+TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithoutMetadata) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+      CreateModelResourcesForModel(kMobileNetQuantizedWithoutMetadata));
+
+  proto::ImagePreprocessingGraphOptions options;
+  MP_EXPECT_OK(
+      ConfigureImagePreprocessingGraph(*model_resources, false, &options));
+
+  EXPECT_THAT(options, EqualsProto(
+                           R"pb(image_to_tensor_options {
+                                  output_tensor_width: 192
+                                  output_tensor_height: 192
+                                  output_tensor_uint_range { min: 0 max: 255 }
+                                  gpu_origin: TOP_LEFT
+                                }
+                                backend: CPU_BACKEND)pb"));
+}
+
+TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+      CreateModelResourcesForModel(kMobileNetFloatWithMetadata));
+
+  proto::ImagePreprocessingGraphOptions options;
+  MP_EXPECT_OK(
+      ConfigureImagePreprocessingGraph(*model_resources, false, &options));
+
+  EXPECT_THAT(options, EqualsProto(
+                           R"pb(image_to_tensor_options {
+                                  output_tensor_width: 224
+                                  output_tensor_height: 224
+                                  output_tensor_float_range { min: -1 max: 1 }
+                                  gpu_origin: TOP_LEFT
+                                }
+                                backend: CPU_BACKEND)pb"));
+}
+
+TEST_F(ConfigureTest, SucceedsWithQuantizedModelFallbacksCpuBackend) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+      CreateModelResourcesForModel(kMobileNetQuantizedWithMetadata));
+
+  proto::ImagePreprocessingGraphOptions options;
+  core::proto::Acceleration acceleration;
+  acceleration.mutable_gpu();
+  bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
+  EXPECT_TRUE(use_gpu);
+  MP_EXPECT_OK(
+      ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options));
+
+  EXPECT_THAT(options, EqualsProto(
+                           R"pb(image_to_tensor_options {
+                                  output_tensor_width: 224
+                                  output_tensor_height: 224
+                                  output_tensor_uint_range { min: 0 max: 255 }
+                                  gpu_origin: TOP_LEFT
+                                }
+                                backend: CPU_BACKEND)pb"));
+}
+
+TEST_F(ConfigureTest, SucceedsWithFloatModelGpuBackend) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+
      CreateModelResourcesForModel(kMobileNetFloatWithMetadata));
+
+  proto::ImagePreprocessingGraphOptions options;
+  core::proto::Acceleration acceleration;
+  acceleration.mutable_gpu();
+  bool use_gpu = DetermineImagePreprocessingGpuBackend(acceleration);
+  EXPECT_TRUE(use_gpu);
+  MP_EXPECT_OK(
+      ConfigureImagePreprocessingGraph(*model_resources, use_gpu, &options));
+
+  EXPECT_THAT(options, EqualsProto(
+                           R"pb(image_to_tensor_options {
+                                  output_tensor_width: 224
+                                  output_tensor_height: 224
+                                  output_tensor_float_range { min: -1 max: 1 }
+                                  gpu_origin: TOP_LEFT
+                                }
+                                backend: GPU_BACKEND)pb"));
+}
+
+TEST_F(ConfigureTest, FailsWithFloatModelWithoutMetadata) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+      CreateModelResourcesForModel(kMobileNetFloatWithoutMetadata));
+
+  proto::ImagePreprocessingGraphOptions options;
+  auto status =
+      ConfigureImagePreprocessingGraph(*model_resources, false, &options);
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
+  EXPECT_THAT(status.message(),
+              HasSubstr("requires specifying NormalizationOptions metadata"));
+}
+
+// Struct holding the parameters for parameterized PreprocessingTest class.
+struct PreprocessingParams {
+  // The name of this test, for convenience when displaying test results.
+  std::string test_name;
+  // The filename of the model to test.
+  std::string input_model_name;
+  // If true, keep test image aspect ratio.
+  bool keep_aspect_ratio;
+  // The expected output tensor type.
+  Tensor::ElementType expected_type;
+  // The expected output tensor shape.
+  std::vector<int> expected_shape;
+  // The expected output letterbox padding.
+  std::array<float, 4> expected_letterbox_padding;
+};
+
+class PreprocessingTest : public testing::TestWithParam<PreprocessingParams> {};
+
+TEST_P(PreprocessingTest, Succeeds) {
+  MP_ASSERT_OK_AND_ASSIGN(
+      Image image,
+      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kTestImage)));
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto model_resources,
+      CreateModelResourcesForModel(GetParam().input_model_name));
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto task_runner,
+      CreateTaskRunner(*model_resources, GetParam().keep_aspect_ratio));
+
+  auto output_packets =
+      task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
+  MP_ASSERT_OK(output_packets);
+
+  const std::vector<Tensor>& tensors =
+      (*output_packets)[kTensorsName].Get<std::vector<Tensor>>();
+  EXPECT_EQ(tensors.size(), 1);
+  EXPECT_EQ(tensors[0].element_type(), GetParam().expected_type);
+  EXPECT_THAT(tensors[0].shape().dims, ContainerEq(GetParam().expected_shape));
+  auto& matrix = (*output_packets)[kMatrixName].Get<std::array<float, 16>>();
+  if (!GetParam().keep_aspect_ratio) {
+    for (int i = 0; i < matrix.size(); ++i) {
+      EXPECT_FLOAT_EQ(matrix[i], kIdentityMatrix[i]);
+    }
+  }
+  auto& image_size =
+      (*output_packets)[kImageSizeName].Get<std::pair<int, int>>();
+  EXPECT_EQ(image_size.first, kTestImageWidth);
+  EXPECT_EQ(image_size.second, kTestImageHeight);
+  std::array<float, 4> letterbox_padding =
+      (*output_packets)[kLetterboxPaddingName].Get<std::array<float, 4>>();
+  for (int i = 0; i < letterbox_padding.size(); ++i) {
+    EXPECT_NEAR(letterbox_padding[i], GetParam().expected_letterbox_padding[i],
+                kLetterboxMaxAbsError);
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    PreprocessingTest, PreprocessingTest,
+    Values(
+        PreprocessingParams{.test_name = "kMobileNetQuantizedWithMetadata",
+                            .input_model_name = kMobileNetQuantizedWithMetadata,
+                            .keep_aspect_ratio = false,
+                            .expected_type = Tensor::ElementType::kUInt8,
+                            .expected_shape = {1, 224, 224, 3},
+                            .expected_letterbox_padding = {0, 0, 0, 0}},
+        PreprocessingParams{
+            .test_name = "kMobileNetQuantizedWithoutMetadata",
+
.input_model_name = kMobileNetQuantizedWithoutMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kUInt8, + .expected_shape = {1, 192, 192, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{.test_name = "kMobileNetFloatWithMetadata", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = false, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {0, 0, 0, 0}}, + PreprocessingParams{ + .test_name = "kMobileNetFloatWithMetadataKeepAspectRatio", + .input_model_name = kMobileNetFloatWithMetadata, + .keep_aspect_ratio = true, + .expected_type = Tensor::ElementType::kFloat32, + .expected_shape = {1, 224, 224, 3}, + .expected_letterbox_padding = {/*left*/ 0, + /*top*/ 0.161458, + /*right*/ 0, + /*bottom*/ 0.161458}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace processors +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 23ebbe008..9c58a8585 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -49,3 +49,13 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_proto", ], ) + +mediapipe_proto_library( + name = "image_preprocessing_graph_options_proto", + srcs = ["image_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/calculators/tensor:image_to_tensor_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/image_preprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto similarity index 89% rename from mediapipe/tasks/cc/components/image_preprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto index d1685c319..bf4fc9067 100644 --- a/mediapipe/tasks/cc/components/image_preprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.proto @@ -15,14 +15,14 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/image_to_tensor_calculator.proto"; import "mediapipe/framework/calculator.proto"; -message ImagePreprocessingOptions { +message ImagePreprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ImagePreprocessingOptions ext = 456882436; + optional ImagePreprocessingGraphOptions ext = 456882436; } // Options for the ImageToTensor calculator encapsulated by the diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 7b144e7aa..d473a8dc3 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", @@ -105,10 +104,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", - "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:model_asset_bundle_resources", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_cache", diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index 8d555b12c..e7fcf6fd9 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 7b6a8c79d..d7e983d81 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -29,8 +29,6 @@ limitations under the License. 
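One detail worth noting in the proto rename above: the `CalculatorOptions` extension keeps its field number (456882436), so existing serialized graph configs stay wire-compatible; only the message name and package change. A hypothetical textproto sketch of a graph node referencing the renamed extension (the `backend` value is illustrative, not taken from this patch):

node {
  calculator: "mediapipe.tasks.components.processors.ImagePreprocessingGraph"
  options {
    [mediapipe.tasks.components.processors.proto.ImagePreprocessingGraphOptions.ext] {
      backend: CPU_BACKEND
    }
  }
}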
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" -#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 71cef6270..55162d09b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -46,7 +46,7 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 06bb2e549..c24548c9b 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -226,21 +226,23 @@ class HandDetectorGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio // unchanged. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); auto& image_to_tensor_options = *preprocessing - .GetOptions() + .GetOptions() .mutable_image_to_tensor_options(); image_to_tensor_options.set_keep_aspect_ratio(true); image_to_tensor_options.set_border_mode( mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); image_in >> preprocessing.In("IMAGE"); norm_rect_in >> preprocessing.In("NORM_RECT"); auto preprocessed_tensors = preprocessing.Out("TENSORS"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 3b869eab4..46948ee6c 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -35,7 +35,6 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", @@ -89,7 +88,7 @@ cc_library( "//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/utils:gate", - "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc index 3a9ed5bc2..2b818b2e5 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.cc @@ -22,7 +22,6 @@ limitations under the License. #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 1f127deb8..014830ba2 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -281,14 +281,15 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - subgraph_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In("IMAGE"); hand_rect >> preprocessing.In("NORM_RECT"); auto image_size = preprocessing[Output>("IMAGE_SIZE")]; diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index 2b93aa262..514e601ef 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -59,11 +59,11 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 2fc88bcb6..2d0379c66 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -135,14 +135,15 @@ class ImageClassifierGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index 8fdb97ccd..d729eaf1a 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -57,12 +57,12 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc index bf0dcf3c7..81ccb5361 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_graph.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -130,14 +130,15 @@ class ImageEmbedderGraph : public core::ModelTaskGraph { Source norm_rect_in, Graph& graph) { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 595eef568..2124fe6e0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -56,10 +56,10 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator", "//mediapipe/tasks/cc/components/calculators/tensor:tensors_to_segmentation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:image_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 44742e043..d5eb5af0d 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -27,8 +27,8 @@ limitations under the License. 
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" -#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/image_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -243,14 +243,15 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); diff --git a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index b8002fa96..c2dd9995d 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -71,9 +71,9 @@ cc_library( "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", "//mediapipe/tasks/cc/components/utils:source_or_node_output", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index b149cea0f..f5dc7e061 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" @@ -561,14 +561,15 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { // Adds preprocessing calculators and connects them to the graph input image // stream. 
- auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.ImagePreprocessingSubgraph"); - bool use_gpu = components::DetermineImagePreprocessingGpuBackend( - task_options.base_options().acceleration()); - MP_RETURN_IF_ERROR(ConfigureImagePreprocessing( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( model_resources, use_gpu, - &preprocessing - .GetOptions())); + &preprocessing.GetOptions())); image_in >> preprocessing.In(kImageTag); norm_rect_in >> preprocessing.In(kNormRectTag); From 3c53ec2cdbe5df2aabf6a20f3b6c9b4efa76cb71 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 10:09:42 -0800 Subject: [PATCH 102/137] Do not expose DrishtiGraphGPUData.h in public header This class is an implementation detail. PiperOrigin-RevId: 490530823 --- mediapipe/gpu/BUILD | 7 +------ mediapipe/gpu/MPPMetalHelper.h | 24 +++++++++++------------- mediapipe/gpu/MPPMetalHelper.mm | 6 ++++++ mediapipe/objc/MPPGraph.mm | 1 - 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index f97eed678..42cd9cdc6 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -550,12 +550,7 @@ cc_library( name = "gpu_shared_data_header", textual_hdrs = [ "gpu_shared_data_internal.h", - ] + select({ - "//conditions:default": [], - "//mediapipe:apple": [ - "MPPGraphGPUData.h", - ], - }), + ], visibility = ["//visibility:private"], deps = [ ":gl_base", diff --git a/mediapipe/gpu/MPPMetalHelper.h b/mediapipe/gpu/MPPMetalHelper.h index f3662422e..6ae0f3cf9 100644 --- a/mediapipe/gpu/MPPMetalHelper.h +++ b/mediapipe/gpu/MPPMetalHelper.h @@ -21,37 +21,35 @@ #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_type.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" NS_ASSUME_NONNULL_BEGIN @interface MPPMetalHelper : NSObject { - MPPGraphGPUData* _gpuShared; } - (instancetype)init NS_UNAVAILABLE; /// Initialize. This initializer is recommended for calculators. -- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext*)cc; +- (instancetype)initWithCalculatorContext:(mediapipe::CalculatorContext *)cc; /// Initialize. -- (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources +- (instancetype)initWithGpuResources:(mediapipe::GpuResources *)gpuResources NS_DESIGNATED_INITIALIZER; /// Configures a calculator's contract for accessing GPU resources. /// Calculators should use this in GetContract. -+ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc; ++ (absl::Status)updateContract:(mediapipe::CalculatorContract *)cc; /// Deprecated initializer. -- (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets; +- (instancetype)initWithSidePackets:(const mediapipe::PacketSet &)inputSidePackets; /// Deprecated initializer. -- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData*)gpuShared; +- (instancetype)initWithGpuSharedData:(mediapipe::GpuSharedData *)gpuShared; /// Configures a calculator's side packets for accessing GPU resources. /// Calculators should use this in FillExpectations. 
-+ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; ++ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet *)inputSidePackets; /// Get a metal command buffer. /// Calculators should use this method instead of getting a buffer from the @@ -63,23 +61,23 @@ NS_ASSUME_NONNULL_BEGIN /// Creates a CVMetalTextureRef linked to the provided GpuBuffer. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Creates a CVMetalTextureRef linked to the provided GpuBuffer given a specific plane. /// Ownership follows the copy rule, so the caller is responsible for /// releasing the CVMetalTextureRef. -- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Returns a MTLTexture linked to the provided GpuBuffer. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer; +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer; /// Returns a MTLTexture linked to the provided GpuBuffer given a specific plane. /// A calculator can freely use it as a rendering source, but it should not /// use it as a rendering target if the GpuBuffer was provided as an input. -- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer +- (id)metalTextureWithGpuBuffer:(const mediapipe::GpuBuffer &)gpuBuffer plane:(size_t)plane; /// Obtains a new GpuBuffer to be used as an output destination. @@ -91,7 +89,7 @@ NS_ASSUME_NONNULL_BEGIN format:(mediapipe::GpuBufferFormat)format; /// Convenience method to load a Metal library stored as a bundle resource. -- (id)newLibraryWithResourceName:(NSString*)name error:(NSError* _Nullable*)error; +- (id)newLibraryWithResourceName:(NSString *)name error:(NSError *_Nullable *)error; /// Shared Metal resources. @property(readonly) id mtlDevice; diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index ce6620972..dc1e27a5c 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,11 +14,17 @@ #import "mediapipe/gpu/MPPMetalHelper.h" +#import "mediapipe/gpu/MPPGraphGPUData.h" #import "mediapipe/gpu/graph_support.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" +@interface MPPMetalHelper () { + MPPGraphGPUData* _gpuShared; +} +@end + namespace mediapipe { // Using a C++ class so it can be declared as a friend of LegacyCalculatorSupport. 
diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 080cca20f..1bd177e80 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -24,7 +24,6 @@ #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/graph_service.h" -#include "mediapipe/gpu/MPPGraphGPUData.h" #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/objc/util.h" From 54d1744c8f5ee102679386b84e3e3812e352bc7a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 10:13:48 -0800 Subject: [PATCH 103/137] Remove DrishtiGraphGPUData, add MetalSharedResources This class is unused except by the Metal helper; let's narrow it down and simplify gpu_shared_data. PiperOrigin-RevId: 490531767 --- mediapipe/gpu/BUILD | 50 +++------ mediapipe/gpu/MPPGraphGPUData.h | 71 ------------- mediapipe/gpu/MPPGraphGPUData.mm | 124 ---------------------- mediapipe/gpu/MPPGraphGPUDataTests.mm | 86 --------------- mediapipe/gpu/MPPMetalHelper.mm | 31 +++--- mediapipe/gpu/gpu_shared_data_internal.cc | 13 +-- mediapipe/gpu/gpu_shared_data_internal.h | 18 ++-- mediapipe/objc/BUILD | 2 +- 8 files changed, 46 insertions(+), 349 deletions(-) delete mode 100644 mediapipe/gpu/MPPGraphGPUData.h delete mode 100644 mediapipe/gpu/MPPGraphGPUData.mm delete mode 100644 mediapipe/gpu/MPPGraphGPUDataTests.mm diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 42cd9cdc6..9cc670fb6 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -470,12 +470,9 @@ objc_library( ) objc_library( - name = "MPPGraphGPUData", - srcs = [ - "MPPGraphGPUData.mm", - "gpu_shared_data_internal.cc", - ], - hdrs = ["MPPGraphGPUData.h"], + name = "metal_shared_resources", + srcs = ["metal_shared_resources.mm"], + hdrs = ["metal_shared_resources.h"], copts = [ "-x objective-c++", "-Wno-shorten-64-to-32", @@ -484,25 +481,9 @@ objc_library( sdk_frameworks = [ "CoreVideo", "Metal", - ] + select({ - "//conditions:default": [ - "OpenGLES", - ], - "//mediapipe:macos": [ - "OpenGL", - "AppKit", - ], - }), + ], visibility = ["//visibility:public"], deps = [ - ":gl_base", - ":gl_context", - ":gpu_buffer_multi_pool", - ":gpu_shared_data_header", - ":graph_support", - ":cv_texture_cache_manager", - "//mediapipe/gpu:gl_context_options_cc_proto", - "//mediapipe/framework:calculator_context", "//mediapipe/framework/port:ret_check", "@google_toolbox_for_mac//:GTM_Defines", ] + [ @@ -584,16 +565,19 @@ cc_library( cc_library( name = "gpu_shared_data_internal_actual", - srcs = select({ - "//conditions:default": [ - "gpu_shared_data_internal.cc", - ], - # iOS uses an Objective-C++ version of this, built in MPPGraphGPUData. 
- "//mediapipe:apple": [], - }), + srcs = [ + "gpu_shared_data_internal.cc", + ], hdrs = [ "gpu_shared_data_internal.h", ], + copts = select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + }), visibility = ["//visibility:private"], deps = [ "//mediapipe/gpu:gl_context_options_cc_proto", @@ -610,7 +594,7 @@ cc_library( ] + select({ "//conditions:default": [], "//mediapipe:apple": [ - ":MPPGraphGPUData", + ":metal_shared_resources", ":cv_texture_cache_manager", ], }), @@ -1139,8 +1123,8 @@ objc_library( name = "gl_ios_test_lib", testonly = 1, srcs = [ - "MPPGraphGPUDataTests.mm", "gl_ios_test.mm", + "metal_shared_resources_test.mm", ], copts = [ "-Wno-shorten-64-to-32", @@ -1150,7 +1134,7 @@ objc_library( ], features = ["-layering_check"], deps = [ - ":MPPGraphGPUData", + ":metal_shared_resources", ":gl_scaler_calculator", ":gpu_buffer_to_image_frame_calculator", ":gpu_shared_data_internal", diff --git a/mediapipe/gpu/MPPGraphGPUData.h b/mediapipe/gpu/MPPGraphGPUData.h deleted file mode 100644 index 3d8fc0c94..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ -#define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ - -#import -#import -#import - -#import "mediapipe/gpu/gl_base.h" -#import "mediapipe/gpu/gl_context.h" - -namespace mediapipe { -class GlContext; -class GpuBufferMultiPool; -} // namespace mediapipe - -@interface MPPGraphGPUData : NSObject { - // Shared buffer pool for GPU calculators. - mediapipe::GpuBufferMultiPool* _gpuBufferPool; - mediapipe::GlContext* _glContext; -} - -- (instancetype)init NS_UNAVAILABLE; - -/// Initialize. The provided multipool pointer must remain valid throughout -/// this object's lifetime. -- (instancetype)initWithContext:(mediapipe::GlContext*)context - multiPool:(mediapipe::GpuBufferMultiPool*)pool NS_DESIGNATED_INITIALIZER; - -/// Shared texture pool for GPU calculators. -/// For internal use by GlCalculatorHelper. -@property(readonly) mediapipe::GpuBufferMultiPool* gpuBufferPool; - -/// Shared OpenGL context. -#if TARGET_OS_OSX -@property(readonly) NSOpenGLContext* glContext; -@property(readonly) NSOpenGLPixelFormat* glPixelFormat; -#else -@property(readonly) EAGLContext* glContext; -#endif // TARGET_OS_OSX - -/// Shared texture cache. -#if TARGET_OS_OSX -@property(readonly) CVOpenGLTextureCacheRef textureCache; -#else -@property(readonly) CVOpenGLESTextureCacheRef textureCache; -#endif // TARGET_OS_OSX - -/// Shared Metal resources. 
-@property(readonly) id mtlDevice; -@property(readonly) id mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@property(readonly) CVMetalTextureCacheRef mtlTextureCache; -#endif - -@end - -#endif // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_ diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm deleted file mode 100644 index 8ac1eefa5..000000000 --- a/mediapipe/gpu/MPPGraphGPUData.mm +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import "mediapipe/gpu/MPPGraphGPUData.h" - -#import "GTMDefines.h" - -#include "mediapipe/gpu/gl_context.h" -#include "mediapipe/gpu/gpu_buffer_multi_pool.h" - -#if TARGET_OS_OSX -#import -#else -#import -#endif // TARGET_OS_OSX - -@implementation MPPGraphGPUData - -@synthesize textureCache = _textureCache; -@synthesize mtlDevice = _mtlDevice; -@synthesize mtlCommandQueue = _mtlCommandQueue; -#if COREVIDEO_SUPPORTS_METAL -@synthesize mtlTextureCache = _mtlTextureCache; -#endif - -#if TARGET_OS_OSX -typedef CVOpenGLTextureCacheRef CVTextureCacheType; -#else -typedef CVOpenGLESTextureCacheRef CVTextureCacheType; -#endif // TARGET_OS_OSX - -- (instancetype)initWithContext:(mediapipe::GlContext *)context - multiPool:(mediapipe::GpuBufferMultiPool *)pool { - self = [super init]; - if (self) { - _gpuBufferPool = pool; - _glContext = context; - } - return self; -} - -- (void)dealloc { - if (_textureCache) { - _textureCache = NULL; - } -#if COREVIDEO_SUPPORTS_METAL - if (_mtlTextureCache) { - CFRelease(_mtlTextureCache); - _mtlTextureCache = NULL; - } -#endif -} - -#if TARGET_OS_OSX -- (NSOpenGLContext *)glContext { - return _glContext->nsgl_context(); -} - -- (NSOpenGLPixelFormat *) glPixelFormat { - return _glContext->nsgl_pixel_format(); -} -#else -- (EAGLContext *)glContext { - return _glContext->eagl_context(); -} -#endif // TARGET_OS_OSX - -- (CVTextureCacheType)textureCache { - @synchronized(self) { - if (!_textureCache) { - _textureCache = _glContext->cv_texture_cache(); - } - } - return _textureCache; -} - -- (mediapipe::GpuBufferMultiPool *)gpuBufferPool { - return _gpuBufferPool; -} - -- (id)mtlDevice { - @synchronized(self) { - if (!_mtlDevice) { - _mtlDevice = MTLCreateSystemDefaultDevice(); - } - } - return _mtlDevice; -} - -- (id)mtlCommandQueue { - @synchronized(self) { - if (!_mtlCommandQueue) { - _mtlCommandQueue = [self.mtlDevice newCommandQueue]; - } - } - return _mtlCommandQueue; -} - -#if COREVIDEO_SUPPORTS_METAL -- (CVMetalTextureCacheRef)mtlTextureCache { - @synchronized(self) { - if (!_mtlTextureCache) { - CVReturn __unused err = - CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); - NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); - // TODO: register and flush metal caches too. 
- } - } - return _mtlTextureCache; -} -#endif - -@end diff --git a/mediapipe/gpu/MPPGraphGPUDataTests.mm b/mediapipe/gpu/MPPGraphGPUDataTests.mm deleted file mode 100644 index e8b50845b..000000000 --- a/mediapipe/gpu/MPPGraphGPUDataTests.mm +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#import -#import - -#include - -#include "absl/memory/memory.h" -#include "mediapipe/framework/port/threadpool.h" - -#import "mediapipe/gpu/MPPGraphGPUData.h" -#import "mediapipe/gpu/gpu_shared_data_internal.h" - -@interface MPPGraphGPUDataTests : XCTestCase { -} -@end - -@implementation MPPGraphGPUDataTests - -// This test verifies that the internal Objective-C object is correctly -// released when the C++ wrapper is released. -- (void)testCorrectlyReleased { - __weak id gpuData = nil; - std::weak_ptr gpuRes; - @autoreleasepool { - mediapipe::GpuSharedData gpu_shared; - gpuRes = gpu_shared.gpu_resources; - gpuData = gpu_shared.gpu_resources->ios_gpu_data(); - XCTAssertNotEqual(gpuRes.lock(), nullptr); - XCTAssertNotNil(gpuData); - } - XCTAssertEqual(gpuRes.lock(), nullptr); - XCTAssertNil(gpuData); -} - -// This test verifies that the lazy initialization of the glContext instance -// variable is thread-safe. All threads should read the same value. -- (void)testGlContextThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - EAGLContext* ogl_context[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &ogl_context, i] { - ogl_context[i] = gpu_shared.gpu_resources->ios_gpu_data().glContext; - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(ogl_context[i], ogl_context[i + 1]); - } -} - -// This test verifies that the lazy initialization of the textureCache instance -// variable is thread-safe. All threads should read the same value. 
-- (void)testTextureCacheThreadSafeLazyInitialization { - mediapipe::GpuSharedData gpu_shared; - constexpr int kNumThreads = 10; - CFHolder texture_cache[kNumThreads]; - auto pool = absl::make_unique(kNumThreads); - pool->StartWorkers(); - for (int i = 0; i < kNumThreads; ++i) { - pool->Schedule([&gpu_shared, &texture_cache, i] { - texture_cache[i].reset(gpu_shared.gpu_resources->ios_gpu_data().textureCache); - }); - } - pool.reset(); - for (int i = 0; i < kNumThreads - 1; ++i) { - XCTAssertEqual(*texture_cache[i], *texture_cache[i + 1]); - } -} - -@end diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index dc1e27a5c..1acf7cbfb 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -14,14 +14,15 @@ #import "mediapipe/gpu/MPPMetalHelper.h" -#import "mediapipe/gpu/MPPGraphGPUData.h" +#import "mediapipe/gpu/gpu_buffer.h" #import "mediapipe/gpu/graph_support.h" +#import "mediapipe/gpu/metal_shared_resources.h" #import "GTMDefines.h" #include "mediapipe/framework/port/ret_check.h" @interface MPPMetalHelper () { - MPPGraphGPUData* _gpuShared; + mediapipe::GpuResources* _gpuResources; } @end @@ -46,7 +47,7 @@ class MetalHelperLegacySupport { - (instancetype)initWithGpuResources:(mediapipe::GpuResources*)gpuResources { self = [super init]; if (self) { - _gpuShared = gpuResources->ios_gpu_data(); + _gpuResources = gpuResources; } return self; } @@ -111,19 +112,19 @@ class MetalHelperLegacySupport { } - (id)mtlDevice { - return _gpuShared.mtlDevice; + return _gpuResources->metal_shared().resources().mtlDevice; } - (id)mtlCommandQueue { - return _gpuShared.mtlCommandQueue; + return _gpuResources->metal_shared().resources().mtlCommandQueue; } - (CVMetalTextureCacheRef)mtlTextureCache { - return _gpuShared.mtlTextureCache; + return _gpuResources->metal_shared().resources().mtlTextureCache; } - (id)commandBuffer { - return [_gpuShared.mtlCommandQueue commandBuffer]; + return [_gpuResources->metal_shared().resources().mtlCommandQueue commandBuffer]; } - (CVMetalTextureRef)copyCVMetalTextureWithGpuBuffer:(const mediapipe::GpuBuffer&)gpuBuffer @@ -175,8 +176,9 @@ class MetalHelperLegacySupport { CVMetalTextureRef texture; CVReturn err = CVMetalTextureCacheCreateTextureFromImage( - NULL, _gpuShared.mtlTextureCache, mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, - metalPixelFormat, width, height, plane, &texture); + NULL, _gpuResources->metal_shared().resources().mtlTextureCache, + mediapipe::GetCVPixelBufferRef(gpuBuffer), NULL, metalPixelFormat, width, height, plane, + &texture); CHECK_EQ(err, kCVReturnSuccess); return texture; } @@ -197,19 +199,20 @@ class MetalHelperLegacySupport { } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height { - return _gpuShared.gpuBufferPool->GetBuffer(width, height); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height); } - (mediapipe::GpuBuffer)mediapipeGpuBufferWithWidth:(int)width height:(int)height format:(mediapipe::GpuBufferFormat)format { - return _gpuShared.gpuBufferPool->GetBuffer(width, height, format); + return _gpuResources->gpu_buffer_pool().GetBuffer(width, height, format); } - (id)newLibraryWithResourceName:(NSString*)name error:(NSError * _Nullable *)error { - return [_gpuShared.mtlDevice newLibraryWithFile:[[NSBundle bundleForClass:[self class]] - pathForResource:name ofType:@"metallib"] - error:error]; + return [_gpuResources->metal_shared().resources().mtlDevice + newLibraryWithFile:[[NSBundle bundleForClass:[self class]] pathForResource:name + 
ofType:@"metallib"]
+                   error:error];
 }

 @end
diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc
index 91723a7d1..203a8dfd1 100644
--- a/mediapipe/gpu/gpu_shared_data_internal.cc
+++ b/mediapipe/gpu/gpu_shared_data_internal.cc
@@ -21,7 +21,7 @@
 #include "mediapipe/gpu/graph_support.h"

 #if __APPLE__
-#import "mediapipe/gpu/MPPGraphGPUData.h"
+#include "mediapipe/gpu/metal_shared_resources.h"
 #endif  // __APPLE__

 namespace mediapipe {
@@ -97,15 +97,14 @@ GpuResources::GpuResources(std::shared_ptr<GlContext> gl_context)
 #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
   texture_caches_->RegisterTextureCache(gl_context->cv_texture_cache());
 #endif  // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
-  ios_gpu_data_ = [[MPPGraphGPUData alloc] initWithContext:gl_context.get()
-                                                 multiPool:&gpu_buffer_pool_];
+  metal_shared_ = std::make_unique<MetalSharedResources>();
 #endif  // __APPLE__
 }

 GpuResources::~GpuResources() {
 #if __APPLE__
-  // Note: on Apple platforms, this object contains Objective-C objects. The
-  // destructor will release them, but ARC must be on.
+  // Note: on Apple platforms, this object contains Objective-C objects.
+  // The destructor will release them, but ARC must be on.
 #if !__has_feature(objc_arc)
 #error This file must be built with ARC.
 #endif
@@ -196,10 +195,6 @@ GlContext::StatusOrGlContext GpuResources::GetOrCreateGlContext(

 GpuSharedData::GpuSharedData() : GpuSharedData(kPlatformGlContextNone) {}

-#if __APPLE__
-MPPGraphGPUData* GpuResources::ios_gpu_data() { return ios_gpu_data_; }
-#endif  // __APPLE__
-
 extern const GraphService<GpuResources> kGpuService;

 #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h
index 4fe6ba04e..3f7c67e2e 100644
--- a/mediapipe/gpu/gpu_shared_data_internal.h
+++ b/mediapipe/gpu/gpu_shared_data_internal.h
@@ -31,15 +31,14 @@
 #ifdef __APPLE__
 #include "mediapipe/gpu/cv_texture_cache_manager.h"
-#ifdef __OBJC__
-@class MPPGraphGPUData;
-#else
-struct MPPGraphGPUData;
-#endif  // __OBJC__
 #endif  // defined(__APPLE__)

 namespace mediapipe {

+#ifdef __APPLE__
+class MetalSharedResources;
+#endif  // defined(__APPLE__)
+
 // TODO: rename to GpuService or GpuManager or something.
 class GpuResources {
  public:
@@ -56,9 +55,7 @@ class GpuResources {

   // Shared GL context for calculators.
   // TODO: require passing a context or node identifier.
-  const std::shared_ptr<GlContext>& gl_context() {
-    return gl_context(nullptr);
-  };
+  const std::shared_ptr<GlContext>& gl_context() { return gl_context(nullptr); }

   const std::shared_ptr<GlContext>& gl_context(CalculatorContext* cc);

@@ -66,7 +63,7 @@ class GpuResources {
   GpuBufferMultiPool& gpu_buffer_pool() { return gpu_buffer_pool_; }

 #ifdef __APPLE__
-  MPPGraphGPUData* ios_gpu_data();
+  MetalSharedResources& metal_shared() { return *metal_shared_; }
 #endif  // defined(__APPLE__)

   absl::Status PrepareGpuNode(CalculatorNode* node);

@@ -96,8 +93,7 @@ class GpuResources {
   GpuBufferMultiPool gpu_buffer_pool_;

 #ifdef __APPLE__
-  // Note that this is an Objective-C object.
-  MPPGraphGPUData* ios_gpu_data_;
+  std::unique_ptr<MetalSharedResources> metal_shared_;
 #endif  // defined(__APPLE__)

   std::map<std::string, std::shared_ptr<Executor>> named_executors_;
diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD
index d77692164..fafdfee8a 100644
--- a/mediapipe/objc/BUILD
+++ b/mediapipe/objc/BUILD
@@ -83,11 +83,11 @@ objc_library(
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:statusor",
         "//mediapipe/framework/port:threadpool",
-        "//mediapipe/gpu:MPPGraphGPUData",
         "//mediapipe/gpu:gl_base",
         "//mediapipe/gpu:gpu_buffer",
         "//mediapipe/gpu:gpu_shared_data_internal",
         "//mediapipe/gpu:graph_support",
+        "//mediapipe/gpu:metal_shared_resources",
         "//mediapipe/gpu:pixel_buffer_pool_util",
         "//mediapipe/util:cpu_util",
         "@com_google_absl//absl/base:core_headers",

From bfa57310c4dfb43e9ea3d5b24059b7e042836911 Mon Sep 17 00:00:00 2001
From: MediaPipe Team
Date: Wed, 23 Nov 2022 10:17:46 -0800
Subject: [PATCH 104/137] Move TextPreprocessing to "processors" folder.

PiperOrigin-RevId: 490532670
---
 mediapipe/tasks/cc/components/BUILD           | 43 -------------------
 .../tasks/cc/components/processors/BUILD      | 26 +++++++++++
 .../cc/components/processors/proto/BUILD      |  9 ++++
 .../text_preprocessing_graph_options.proto    |  2 +-
 .../text_preprocessing_graph.cc               | 22 +++++-----
 .../text_preprocessing_graph.h                | 30 +++++++------
 mediapipe/tasks/cc/components/proto/BUILD     |  9 ----
 mediapipe/tasks/cc/text/text_classifier/BUILD |  4 +-
 .../text_classifier/text_classifier_graph.cc  | 12 +++---
 mediapipe/tasks/cc/text/text_embedder/BUILD   |  4 +-
 .../text/text_embedder/text_embedder_graph.cc | 12 +++---
 11 files changed, 80 insertions(+), 93 deletions(-)
 delete mode 100644 mediapipe/tasks/cc/components/BUILD
 rename mediapipe/tasks/cc/components/{ => processors}/proto/text_preprocessing_graph_options.proto (96%)
 rename mediapipe/tasks/cc/components/{ => processors}/text_preprocessing_graph.cc (94%)
 rename mediapipe/tasks/cc/components/{ => processors}/text_preprocessing_graph.h (67%)

diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD
deleted file mode 100644
index 54a5207d2..000000000
--- a/mediapipe/tasks/cc/components/BUILD
+++ /dev/null
@@ -1,43 +0,0 @@
-# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-package(default_visibility = ["//mediapipe/tasks:internal"])
-
-licenses(["notice"])
-
-# TODO: Investigate rewriting the build rule to only link
-# the Bert Preprocessor if it's needed.
-cc_library( - name = "text_preprocessing_graph", - srcs = ["text_preprocessing_graph.cc"], - hdrs = ["text_preprocessing_graph.h"], - deps = [ - "//mediapipe/calculators/tensor:bert_preprocessor_calculator", - "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator", - "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", - "//mediapipe/calculators/tensor:text_to_tensor_calculator", - "//mediapipe/framework:subgraph", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], - alwayslink = 1, -) diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD index 4946683f5..185bf231b 100644 --- a/mediapipe/tasks/cc/components/processors/BUILD +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -133,3 +133,29 @@ cc_library( ) # TODO: Enable this test + +# TODO: Investigate rewriting the build rule to only link +# the Bert Preprocessor if it's needed. +cc_library( + name = "text_preprocessing_graph", + srcs = ["text_preprocessing_graph.cc"], + hdrs = ["text_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:bert_preprocessor_calculator", + "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD index 9c58a8585..f48c4bad8 100644 --- a/mediapipe/tasks/cc/components/processors/proto/BUILD +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -59,3 +59,12 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_proto", ], ) + +mediapipe_proto_library( + name = "text_preprocessing_graph_options_proto", + srcs = ["text_preprocessing_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto similarity index 96% rename from mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto rename to mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto index 926e3d7fb..a67cfd8a9 100644 --- a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto +++ 
b/mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc similarity index 94% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.cc rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc index 6aad8fdd5..de16375bd 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include @@ -25,13 +25,14 @@ limitations under the License. #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/subgraph.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -41,7 +42,8 @@ using ::mediapipe::api2::SideInput; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::SideSource; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::TextPreprocessingGraphOptions; +using ::mediapipe::tasks::components::processors::proto:: + TextPreprocessingGraphOptions; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; @@ -169,7 +171,7 @@ absl::StatusOr GetMaxSeqLen(const tflite::SubGraph& model_graph) { } } // namespace -absl::Status ConfigureTextPreprocessingSubgraph( +absl::Status ConfigureTextPreprocessingGraph( const ModelResources& model_resources, TextPreprocessingGraphOptions& options) { if (model_resources.GetTfLiteModel()->subgraphs()->size() != 1) { @@ -200,8 +202,7 @@ absl::Status ConfigureTextPreprocessingSubgraph( return absl::OkStatus(); } -// A "mediapipe.tasks.components.TextPreprocessingSubgraph" performs text -// preprocessing. +// A TextPreprocessingGraph performs text preprocessing. // - Accepts a std::string input and outputs CPU tensors. // // Inputs: @@ -216,9 +217,9 @@ absl::Status ConfigureTextPreprocessingSubgraph( // Vector containing the preprocessed input tensors for the TFLite model. // // The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureTextPreprocessing()' function. See header file for more -// details. -class TextPreprocessingSubgraph : public mediapipe::Subgraph { +// using the 'ConfigureTextPreprocessingGraph()' function. See header file for +// more details. 
+class TextPreprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -267,8 +268,9 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { } }; REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::TextPreprocessingSubgraph); + ::mediapipe::tasks::components::processors::TextPreprocessingGraph); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.h b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h similarity index 67% rename from mediapipe/tasks/cc/components/text_preprocessing_graph.h rename to mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h index b031a5550..43d57be29 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h @@ -13,26 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ #include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" -// Configures a TextPreprocessing subgraph using the provided `model_resources` +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { + +// Configures a TextPreprocessingGraph using the provided `model_resources` // and TextPreprocessingGraphOptions. // - Accepts a std::string input and outputs CPU tensors. // // Example usage: // // auto& preprocessing = -// graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); +// graph.AddNode("mediapipe.tasks.components.processors.TextPreprocessingSubgraph"); // MP_RETURN_IF_ERROR(ConfigureTextPreprocessingSubgraph( // model_resources, // &preprocessing.GetOptions())); // -// The resulting TextPreprocessing subgraph has the following I/O: +// The resulting TextPreprocessingGraph has the following I/O: // Inputs: // TEXT - std::string // The text to preprocess. @@ -43,16 +48,13 @@ limitations under the License. // Outputs: // TENSORS - std::vector // Vector containing the preprocessed input tensors for the TFLite model. 
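// Sketch (not part of the patch): the documented I/O above corresponds to
// builder code along these lines. The node name and option type are taken
// from this header and from the callers updated later in this patch, while
// `graph`, `model_resources`, and `text_in` are assumed surrounding context.
//
//   auto& preprocessing = graph.AddNode(
//       "mediapipe.tasks.components.processors.TextPreprocessingGraph");
//   MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph(
//       model_resources,
//       preprocessing.GetOptions<
//           components::processors::proto::TextPreprocessingGraphOptions>()));
//   text_in >> preprocessing.In("TEXT");  // kTextTag in the callers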
-namespace mediapipe { -namespace tasks { -namespace components { - -absl::Status ConfigureTextPreprocessingSubgraph( - const tasks::core::ModelResources& model_resources, - tasks::components::proto::TextPreprocessingGraphOptions& options); +absl::Status ConfigureTextPreprocessingGraph( + const core::ModelResources& model_resources, + proto::TextPreprocessingGraphOptions& options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_TEXT_PREPROCESSING_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_TEXT_PREPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 4534a1652..569023753 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -22,12 +22,3 @@ mediapipe_proto_library( name = "segmenter_options_proto", srcs = ["segmenter_options.proto"], ) - -mediapipe_proto_library( - name = "text_preprocessing_graph_options_proto", - srcs = ["text_preprocessing_graph_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - ], -) diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 01adc9fc3..61395cf4e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -52,11 +52,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources_calculator", "//mediapipe/tasks/cc/core:model_task_graph", diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 9a7dce1aa..3be92f309 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -25,8 +25,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -115,12 +115,12 @@ class TextClassifierGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index 27c9cb730..f19af35be 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -54,11 +54,11 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", - "//mediapipe/tasks/cc/components:text_preprocessing_graph", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/processors:embedding_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:text_preprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto", - "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:text_preprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc index c54636ee2..225ef07bd 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_graph.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" -#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/text_preprocessing_graph.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" @@ -107,12 +107,12 @@ class TextEmbedderGraph : public core::ModelTaskGraph { Graph& graph) { // Adds preprocessing calculators and connects them to the text input // stream. - auto& preprocessing = - graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); - MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.TextPreprocessingGraph"); + MP_RETURN_IF_ERROR(components::processors::ConfigureTextPreprocessingGraph( model_resources, preprocessing.GetOptions< - tasks::components::proto::TextPreprocessingGraphOptions>())); + components::processors::proto::TextPreprocessingGraphOptions>())); text_in >> preprocessing.In(kTextTag); // Adds both InferenceCalculator and ModelResourcesCalculator. From 41a7f9d7d6fdc0bfd1c9e7d4cc00532512474de2 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 23 Nov 2022 15:23:02 -0800 Subject: [PATCH 105/137] Internal change PiperOrigin-RevId: 490595529 --- mediapipe/web/graph_runner/graph_runner.ts | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index c4654794c..378bc0a4d 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -176,10 +176,14 @@ export class GraphRunner { if (glCanvas !== undefined) { this.wasmModule.canvas = glCanvas; - } else { + } else if (typeof OffscreenCanvas !== 'undefined') { // If no canvas is provided, assume Chrome/Firefox and just make an // OffscreenCanvas for GPU processing. this.wasmModule.canvas = new OffscreenCanvas(1, 1); + } else { + console.warn('OffscreenCanvas not detected and GraphRunner constructor ' + + 'glCanvas parameter is undefined. Creating backup canvas.'); + this.wasmModule.canvas = document.createElement('canvas'); } } From 0bdb48ceb18a772158b92793daf6ac4bf8ce6f76 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Wed, 23 Nov 2022 16:17:02 -0800 Subject: [PATCH 106/137] Use kUtilityFramebuffer in GlCalculatorHelper All calculators using the same context can share a single framebuffer object. 
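A minimal sketch of the sharing this enables, assuming two lookups against the same GlContext (the `kUtilityFramebuffer.Get()` call and the GLuint handle are taken from the diff below; everything else is assumed context):

    // Both lookups resolve to one GL framebuffer object per context.
    GLuint fb_a = kUtilityFramebuffer.Get(*gl_context);  // created on first use
    GLuint fb_b = kUtilityFramebuffer.Get(*gl_context);  // cached thereafter
    assert(fb_a == fb_b);  // shared FBO, released with the GlContext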
PiperOrigin-RevId: 490605074 --- mediapipe/gpu/gl_calculator_helper.cc | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index 7d317e0f1..9b217ddfd 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -27,19 +27,7 @@ namespace mediapipe { GlCalculatorHelper::GlCalculatorHelper() {} -GlCalculatorHelper::~GlCalculatorHelper() { - if (!Initialized()) return; - RunInGlContext( - [this] { - if (framebuffer_) { - glDeleteFramebuffers(1, &framebuffer_); - framebuffer_ = 0; - } - return absl::OkStatus(); - }, - /*calculator_context=*/nullptr) - .IgnoreError(); -} +GlCalculatorHelper::~GlCalculatorHelper() {} void GlCalculatorHelper::InitializeInternal(CalculatorContext* cc, GpuResources* gpu_resources) { @@ -125,9 +113,9 @@ void GlCalculatorHelper::CreateFramebuffer() { // Our framebuffer will have a color attachment but no depth attachment, // so it's important that the depth test be off. It is disabled by default, // but we wanted to be explicit. - // TODO: move this to glBindFramebuffer? + // TODO: move this to glBindFramebuffer? Or just remove. glDisable(GL_DEPTH_TEST); - glGenFramebuffers(1, &framebuffer_); + framebuffer_ = kUtilityFramebuffer.Get(*gl_context_); } void GlCalculatorHelper::BindFramebuffer(const GlTexture& dst) { From 395d9d8ea21c93bbefb37ad980ad41f66b9a2f9f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Sun, 27 Nov 2022 00:05:08 -0800 Subject: [PATCH 107/137] Instantiate GetDetectionVectorItemCalculator variant of GetVectorItemCalculator<>. PiperOrigin-RevId: 491123314 --- mediapipe/calculators/core/BUILD | 1 + mediapipe/calculators/core/get_vector_item_calculator.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 39837fadb..3b658eb5b 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1299,6 +1299,7 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", diff --git a/mediapipe/calculators/core/get_vector_item_calculator.cc b/mediapipe/calculators/core/get_vector_item_calculator.cc index 51fb46b98..3306e4ff3 100644 --- a/mediapipe/calculators/core/get_vector_item_calculator.cc +++ b/mediapipe/calculators/core/get_vector_item_calculator.cc @@ -15,6 +15,7 @@ #include "mediapipe/calculators/core/get_vector_item_calculator.h" #include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" namespace mediapipe { @@ -32,5 +33,9 @@ using GetClassificationListVectorItemCalculator = GetVectorItemCalculator; REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); +using GetDetectionVectorItemCalculator = + GetVectorItemCalculator; +REGISTER_CALCULATOR(GetDetectionVectorItemCalculator); + } // namespace api2 } // namespace mediapipe From 153edc59a111c12b940169a272b36772fcd519a1 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 28 Nov 2022 09:52:40 -0800 Subject: [PATCH 108/137] Add support for browsers without SIMD PiperOrigin-RevId: 491371277 --- mediapipe/tasks/web/BUILD | 12 ++ mediapipe/tasks/web/audio.ts | 5 +- 
mediapipe/tasks/web/audio/BUILD | 1 + .../tasks/web/audio/audio_classifier/BUILD | 2 +- .../audio_classifier/audio_classifier.ts | 41 ++---- .../audio/audio_embedder/audio_embedder.ts | 28 ++-- mediapipe/tasks/web/audio/index.ts | 1 + mediapipe/tasks/web/core/BUILD | 9 +- mediapipe/tasks/web/core/fileset_resolver.ts | 130 ++++++++++++++++++ mediapipe/tasks/web/core/task_runner.ts | 45 +++++- ..._loader_options.d.ts => wasm_fileset.d.ts} | 4 +- mediapipe/tasks/web/text.ts | 5 +- mediapipe/tasks/web/text/BUILD | 1 + mediapipe/tasks/web/text/index.ts | 1 + .../tasks/web/text/text_classifier/BUILD | 1 - .../text/text_classifier/text_classifier.ts | 39 ++---- mediapipe/tasks/web/text/text_embedder/BUILD | 1 - .../web/text/text_embedder/text_embedder.ts | 42 ++---- mediapipe/tasks/web/vision.ts | 4 +- mediapipe/tasks/web/vision/BUILD | 1 + .../gesture_recognizer/gesture_recognizer.ts | 46 +++---- .../vision/hand_landmarker/hand_landmarker.ts | 46 +++---- .../image_classifier/image_classifier.ts | 41 ++---- .../vision/image_embedder/image_embedder.ts | 40 ++---- mediapipe/tasks/web/vision/index.ts | 1 + .../vision/object_detector/object_detector.ts | 40 ++---- mediapipe/web/graph_runner/graph_runner.ts | 8 +- third_party/wasm_files.bzl | 76 +++++++--- 28 files changed, 410 insertions(+), 261 deletions(-) create mode 100644 mediapipe/tasks/web/core/fileset_resolver.ts rename mediapipe/tasks/web/core/{wasm_loader_options.d.ts => wasm_fileset.d.ts} (88%) diff --git a/mediapipe/tasks/web/BUILD b/mediapipe/tasks/web/BUILD index 7e5d02892..20e717433 100644 --- a/mediapipe/tasks/web/BUILD +++ b/mediapipe/tasks/web/BUILD @@ -13,10 +13,16 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_files(srcs = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", ]) # Audio @@ -57,6 +63,8 @@ pkg_npm( deps = [ "wasm/audio_wasm_internal.js", "wasm/audio_wasm_internal.wasm", + "wasm/audio_wasm_nosimd_internal.js", + "wasm/audio_wasm_nosimd_internal.wasm", ":audio_bundle", ], ) @@ -99,6 +107,8 @@ pkg_npm( deps = [ "wasm/text_wasm_internal.js", "wasm/text_wasm_internal.wasm", + "wasm/text_wasm_nosimd_internal.js", + "wasm/text_wasm_nosimd_internal.wasm", ":text_bundle", ], ) @@ -141,6 +151,8 @@ pkg_npm( deps = [ "wasm/vision_wasm_internal.js", "wasm/vision_wasm_internal.wasm", + "wasm/vision_wasm_nosimd_internal.js", + "wasm/vision_wasm_nosimd_internal.wasm", ":vision_bundle", ], ) diff --git a/mediapipe/tasks/web/audio.ts b/mediapipe/tasks/web/audio.ts index 8c522efcc..2f4fb0315 100644 --- a/mediapipe/tasks/web/audio.ts +++ b/mediapipe/tasks/web/audio.ts @@ -14,11 +14,12 @@ * limitations under the License. */ -import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl} from '../../tasks/web/audio/index'; +import {AudioClassifier as AudioClassifierImpl, AudioEmbedder as AudioEmbedderImpl, FilesetResolver as FilesetResolverImpl} from '../../tasks/web/audio/index'; // Declare the variables locally so that Rollup in OSS includes them explcilty // as exports. 
const AudioClassifier = AudioClassifierImpl; const AudioEmbedder = AudioEmbedderImpl; +const FilesetResolver = FilesetResolverImpl; -export {AudioClassifier, AudioEmbedder}; +export {AudioClassifier, AudioEmbedder, FilesetResolver}; diff --git a/mediapipe/tasks/web/audio/BUILD b/mediapipe/tasks/web/audio/BUILD index acd7494d7..d08602521 100644 --- a/mediapipe/tasks/web/audio/BUILD +++ b/mediapipe/tasks/web/audio/BUILD @@ -10,5 +10,6 @@ mediapipe_ts_library( deps = [ "//mediapipe/tasks/web/audio/audio_classifier", "//mediapipe/tasks/web/audio/audio_embedder", + "//mediapipe/tasks/web/core:fileset_resolver", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index 498b17845..c419d3b98 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/web/graph_runner:graph_runner_ts", + "//mediapipe/tasks/web/core:task_runner", ], ) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 20c745383..e606019f2 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {TaskRunner} from '../../../../tasks/web/core/task_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; @@ -50,28 +50,17 @@ export class AudioClassifier extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioClassifierOptions The options for the audio classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - audioClassifierOptions: AudioClassifierOptions): + wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file loaded with this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - AudioClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await TaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset); await classifier.setOptions(audioClassifierOptions); return classifier; } @@ -79,31 +68,31 @@ export class AudioClassifier extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return AudioClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return AudioClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index 9dce02862..c87aceabe 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -24,7 +24,7 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url @@ -52,25 +52,25 @@ export class AudioEmbedder extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param audioEmbedderOptions The options for the audio embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, audioEmbedderOptions: AudioEmbedderOptions): Promise { // Create a file locator based on the loader options const fileLocator: FileLocator = { locateFile() { // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); + return wasmFileset.wasmBinaryPath.toString(); } }; const embedder = await createMediaPipeLib( - AudioEmbedder, wasmLoaderOptions.wasmLoaderPath, + AudioEmbedder, wasmFileset.wasmLoaderPath, /* assetLoaderScript= */ undefined, /* glCanvas= */ undefined, fileLocator); await embedder.setOptions(audioEmbedderOptions); @@ -80,31 +80,31 @@ export class AudioEmbedder extends AudioTaskRunner { /** * Initializes the Wasm runtime and creates a new audio embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model. 
*/ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { return AudioEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new audio embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return AudioEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/audio/index.ts b/mediapipe/tasks/web/audio/index.ts index 17a908f30..dbad8c617 100644 --- a/mediapipe/tasks/web/audio/index.ts +++ b/mediapipe/tasks/web/audio/index.ts @@ -16,3 +16,4 @@ export * from '../../../tasks/web/audio/audio_classifier/audio_classifier'; export * from '../../../tasks/web/audio/audio_embedder/audio_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index 6eca8bb4a..d709e3409 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -8,7 +8,7 @@ mediapipe_ts_declaration( name = "core", srcs = [ "base_options.d.ts", - "wasm_loader_options.d.ts", + "wasm_fileset.d.ts", ], ) @@ -18,12 +18,19 @@ mediapipe_ts_library( "task_runner.ts", ], deps = [ + ":core", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", ], ) +mediapipe_ts_library( + name = "fileset_resolver", + srcs = ["fileset_resolver.ts"], + deps = [":core"], +) + mediapipe_ts_declaration( name = "classifier_options", srcs = ["classifier_options.d.ts"], diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts new file mode 100644 index 000000000..7d68dbc16 --- /dev/null +++ b/mediapipe/tasks/web/core/fileset_resolver.ts @@ -0,0 +1,130 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Placeholder for internal dependency on trusted resource URL builder + +import {WasmFileset} from './wasm_fileset'; + +let supportsSimd: boolean|undefined; + +/** + * Simple WASM program to test compatibility with the M91 instruction set. 
+ * Compiled from + * https://github.com/GoogleChromeLabs/wasm-feature-detect/blob/main/src/detectors/simd/module.wat + */ +const WASM_SIMD_CHECK = new Uint8Array([ + 0, 97, 115, 109, 1, 0, 0, 0, 1, 5, 1, 96, 0, 1, 123, 3, + 2, 1, 0, 10, 10, 1, 8, 0, 65, 0, 253, 15, 253, 98, 11 +]); + +async function isSimdSupported(): Promise { + if (supportsSimd === undefined) { + try { + await WebAssembly.instantiate(WASM_SIMD_CHECK); + supportsSimd = true; + } catch { + supportsSimd = false; + } + } + + return supportsSimd; +} + +async function createFileset( + taskName: string, basePath: string = '.'): Promise { + if (await isSimdSupported()) { + return { + wasmLoaderPath: + `/${basePath}/${taskName}_wasm_internal.js`, + wasmBinaryPath: + `/${basePath}/${taskName}_wasm_internal.wasm`, + }; + } else { + return { + wasmLoaderPath: + `/${basePath}/${taskName}_wasm_nosimd_internal.js`, + wasmBinaryPath: `/${basePath}/${ + taskName}_wasm_nosimd_internal.wasm`, + }; + } +} + +// tslint:disable:class-as-namespace + +/** + * Resolves the files required for the MediaPipe Task APIs. + * + * This class verifies whether SIMD is supported in the current environment and + * loads the SIMD files only if support is detected. The returned filesets + * require that the Wasm files are published without renaming. If this is not + * possible, you can invoke the MediaPipe Tasks APIs using a manually created + * `WasmFileset`. + */ +export class FilesetResolver { + /** + * Returns whether SIMD is supported in the current environment. + * + * If your environment requires custom locations for the MediaPipe Wasm files, + * you can use `isSimdSupported()` to decide whether to load the SIMD-based + * assets. + * + * @return Whether SIMD support was detected in the current environment. + */ + static isSimdSupported(): Promise { + return isSimdSupported(); + } + + /** + * Creates a fileset for the MediaPipe Audio tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Audio + * tasks. + */ + static forAudioTasks(basePath?: string): Promise { + return createFileset('audio', basePath); + } + + /** + * Creates a fileset for the MediaPipe Text tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Text + * tasks. + */ + static forTextTasks(basePath?: string): Promise { + return createFileset('text', basePath); + } + + /** + * Creates a fileset for the MediaPipe Vision tasks. + * + * @param basePath An optional base path to specify the directory the Wasm + * files should be loaded from. If not specified, the Wasm files are + * loaded from the host's root directory. + * @return A `WasmFileset` that can be used to initialize MediaPipe Vision + * tasks. + */ + static forVisionTasks(basePath?: string): Promise { + return createFileset('vision', basePath); + } +} + + diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 67aa4e4df..4085be697 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,9 +14,14 @@ * limitations under the License. 
*/ -import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; +import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner'; import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib'; -import {GraphRunner, WasmModule} from '../../../web/graph_runner/graph_runner'; +import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service'; + +import {WasmFileset} from './wasm_fileset'; + +// None of the MP Tasks ship bundle assets. +const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing const WasmMediaPipeImageLib = @@ -26,8 +31,40 @@ const WasmMediaPipeImageLib = export abstract class TaskRunner extends WasmMediaPipeImageLib { private processingErrors: Error[] = []; - constructor(wasmModule: WasmModule) { - super(wasmModule); + /** + * Creates a new instance of a MediaPipe Task. Determines if SIMD is + * supported and loads the relevant WASM binary. + * @return A fully instantiated instance of `T`. + */ + protected static async createInstance<T extends TaskRunner>( + type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean, + fileset: WasmFileset): Promise<T> { + const fileLocator: FileLocator = { + locateFile() { + // The only file loaded with this mechanism is the Wasm binary + return fileset.wasmBinaryPath.toString(); + } + }; + + if (initializeCanvas) { + // Fall back to an OffscreenCanvas created by the GraphRunner if + // OffscreenCanvas is available + const canvas = typeof OffscreenCanvas === 'undefined' ? + document.createElement('canvas') : + undefined; + return createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator); + } else { + return createMediaPipeLib( + type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null, + fileLocator); + } + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. diff --git a/mediapipe/tasks/web/core/wasm_loader_options.d.ts b/mediapipe/tasks/web/core/wasm_fileset.d.ts similarity index 88% rename from mediapipe/tasks/web/core/wasm_loader_options.d.ts rename to mediapipe/tasks/web/core/wasm_fileset.d.ts index 74436583d..18227eab9 100644 --- a/mediapipe/tasks/web/core/wasm_loader_options.d.ts +++ b/mediapipe/tasks/web/core/wasm_fileset.d.ts @@ -16,8 +16,8 @@ // Placeholder for internal dependency on trusted resource url -/** An object containing the locations of all Wasm assets */ -export declare interface WasmLoaderOptions { +/** An object containing the locations of the Wasm assets */ +export declare interface WasmFileset { /** The path to the Wasm loader script. */ wasmLoaderPath: string; /** The path to the Wasm binary. */ diff --git a/mediapipe/tasks/web/text.ts b/mediapipe/tasks/web/text.ts index 8f15075c5..0636714b8 100644 --- a/mediapipe/tasks/web/text.ts +++ b/mediapipe/tasks/web/text.ts @@ -14,11 +14,12 @@ * limitations under the License. */ -import {TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index'; +import {FilesetResolver as FilesetResolverImpl, TextClassifier as TextClassifierImpl, TextEmbedder as TextEmbedderImpl} from '../../tasks/web/text/index'; // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports.
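[Editorial sketch, not part of the patch: the `createInstance()` helper added to `TaskRunner` above centralizes the file-locator and canvas setup that each task previously duplicated. The `MyTask` subclass below is hypothetical; the real call sites follow in the per-task diffs.]

class MyTask extends TaskRunner {
  static create(fileset: WasmFileset): Promise<MyTask> {
    // Text/audio tasks run on CPU, so no GL canvas is initialized.
    return TaskRunner.createInstance(
        MyTask, /* initializeCanvas= */ false, fileset);
  }
}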
+const FilesetResolver = FilesetResolverImpl; const TextClassifier = TextClassifierImpl; const TextEmbedder = TextEmbedderImpl; -export {TextClassifier, TextEmbedder}; +export {FilesetResolver, TextClassifier, TextEmbedder}; diff --git a/mediapipe/tasks/web/text/BUILD b/mediapipe/tasks/web/text/BUILD index 4b465b0f5..159db1a0d 100644 --- a/mediapipe/tasks/web/text/BUILD +++ b/mediapipe/tasks/web/text/BUILD @@ -8,6 +8,7 @@ mediapipe_ts_library( name = "text_lib", srcs = ["index.ts"], deps = [ + "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/text/text_classifier", "//mediapipe/tasks/web/text/text_embedder", ], diff --git a/mediapipe/tasks/web/text/index.ts b/mediapipe/tasks/web/text/index.ts index d50db209c..a28e4dd1c 100644 --- a/mediapipe/tasks/web/text/index.ts +++ b/mediapipe/tasks/web/text/index.ts @@ -16,3 +16,4 @@ export * from '../../../tasks/web/text/text_classifier/text_classifier'; export * from '../../../tasks/web/text/text_embedder/text_embedder'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 71ef02c92..f3d272daa 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -26,7 +26,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 04789f5e1..197869a36 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -22,8 +22,7 @@ import {convertBaseOptionsToProto} from '../../../../tasks/web/components/proces import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; @@ -48,27 +47,17 @@ export class TextClassifier extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textClassifierOptions The options for the text classifier. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise<TextClassifier> { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - TextClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset); await classifier.setOptions(textClassifierOptions); return classifier; } @@ -76,31 +65,31 @@ export class TextClassifier extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text classifier based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise<TextClassifier> { return TextClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text classifier based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset.
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise<TextClassifier> { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return TextClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } /** diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index 3f92b8ae1..b858f6b83 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -26,7 +26,6 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", - "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 2042a0985..511fd2411 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -24,8 +24,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; @@ -52,27 +51,17 @@ export class TextEmbedder extends TaskRunner { /** * Initializes the Wasm runtime and creates a new text embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param textEmbedderOptions The options for the text embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, textEmbedderOptions: TextEmbedderOptions): Promise<TextEmbedder> { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - TextEmbedder, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const embedder = await TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset); await embedder.setOptions(textEmbedderOptions); return embedder; } @@ -80,31 +69,31 @@ /** * Initializes the Wasm runtime and creates a new text embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader.
+ * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise<TextEmbedder> { return TextEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new text embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise<TextEmbedder> { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return TextEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } /** @@ -122,14 +111,11 @@ options.baseOptions, this.options.getBaseOptions()); this.options.setBaseOptions(baseOptionsProto); } - this.options.setEmbedderOptions(convertEmbedderOptionsToProto( options, this.options.getEmbedderOptions())); - this.refreshGraph(); } - /** * Performs embedding extraction on the provided text and waits synchronously * for the response. diff --git a/mediapipe/tasks/web/vision.ts b/mediapipe/tasks/web/vision.ts index 74a056464..f1ced59af 100644 --- a/mediapipe/tasks/web/vision.ts +++ b/mediapipe/tasks/web/vision.ts @@ -14,10 +14,11 @@ * limitations under the License. */ -import {GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; +import {FilesetResolver as FilesetResolverImpl, GestureRecognizer as GestureRecognizerImpl, HandLandmarker as HandLandmarkerImpl, ImageClassifier as ImageClassifierImpl, ImageEmbedder as ImageEmbedderImpl, ObjectDetector as ObjectDetectorImpl} from '../../tasks/web/vision/index'; // Declare the variables locally so that Rollup in OSS includes them explicitly // as exports.
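[Editorial sketch, not part of the patch: as the embedder diff above shows, `createFromModelPath()` is only a convenience wrapper. Fetching the model manually and passing the bytes to `createFromModelBuffer()` is equivalent; the model URL below is illustrative.]

const response = await fetch('/models/text_embedder.tflite');
const modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
const embedder = await TextEmbedder.createFromModelBuffer(
    fileset, modelAssetBuffer);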
+const FilesetResolver = FilesetResolverImpl; const GestureRecognizer = GestureRecognizerImpl; const HandLandmarker = HandLandmarkerImpl; const ImageClassifier = ImageClassifierImpl; @@ -25,6 +26,7 @@ const ImageEmbedder = ImageEmbedderImpl; const ObjectDetector = ObjectDetectorImpl; export { + FilesetResolver, GestureRecognizer, HandLandmarker, ImageClassifier, diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index 3c45fbfa6..42bc0a494 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -8,6 +8,7 @@ mediapipe_ts_library( name = "vision_lib", srcs = ["index.ts"], deps = [ + "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index dd050d0f1..7441911c1 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -29,9 +29,9 @@ import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/han import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {GestureRecognizerOptions} from './gesture_recognizer_options'; @@ -82,28 +82,18 @@ export class GestureRecognizer extends /** * Initializes the Wasm runtime and creates a new gesture recognizer from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param gestureRecognizerOptions The options for the gesture recognizer. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). 
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise<GestureRecognizer> { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const recognizer = await createMediaPipeLib( - GestureRecognizer, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const recognizer = await VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); await recognizer.setOptions(gestureRecognizerOptions); return recognizer; } @@ -111,35 +101,37 @@ /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise<GestureRecognizer> { return GestureRecognizer.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new gesture recognizer based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset.
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise<GestureRecognizer> { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return GestureRecognizer.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); this.options = new GestureRecognizerGraphOptions(); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 32b1eed4b..6d69d568c 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -25,9 +25,9 @@ import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landm import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark} from '../../../../tasks/web/components/containers/landmark'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {HandLandmarkerOptions} from './hand_landmarker_options'; @@ -71,27 +71,17 @@ export class HandLandmarker extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new `HandLandmarker` from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param handLandmarkerOptions The options for the HandLandmarker. * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, handLandmarkerOptions: HandLandmarkerOptions): Promise<HandLandmarker> { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load via this mechanism is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const landmarker = await createMediaPipeLib( - HandLandmarker, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const landmarker = await VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset); await landmarker.setOptions(handLandmarkerOptions); return landmarker; } @@ -99,35 +89,37 @@ /** * Initializes the Wasm runtime and creates a new `HandLandmarker` based on * the provided model asset buffer.
- * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise<HandLandmarker> { return HandLandmarker.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new `HandLandmarker` based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise<HandLandmarker> { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return HandLandmarker.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } - constructor(wasmModule: WasmModule) { - super(wasmModule); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); this.options = new HandLandmarkerGraphOptions(); this.handLandmarksDetectorGraphOptions = diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index b59cb6fb1..604795f9f 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -21,9 +21,9 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ImageClassifierGraphOptions} from '../../../../tasks/cc/vision/image_classifier/proto/image_classifier_graph_options_pb'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; @@ -49,28 +49,17 @@ export class ImageClassifier extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image classifier from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param imageClassifierOptions The options for the image classifier. Note * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`).
*/ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, - imageClassifierOptions: ImageClassifierOptions): + wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise<ImageClassifier> { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const classifier = await createMediaPipeLib( - ImageClassifier, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const classifier = await VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset); await classifier.setOptions(imageClassifierOptions); return classifier; } @@ -78,31 +67,31 @@ /** * Initializes the Wasm runtime and creates a new image classifier based on * the provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise<ImageClassifier> { return ImageClassifier.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new image classifier based on * the path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset.
*/ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise<ImageClassifier> { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return ImageClassifier.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index f96f1e961..68068db6d 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -23,9 +23,9 @@ import {Embedding} from '../../../../tasks/web/components/containers/embedding_r import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; @@ -51,27 +51,17 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new image embedder from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param imageEmbedderOptions The options for the image embedder. Note that * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, imageEmbedderOptions: ImageEmbedderOptions): Promise<ImageEmbedder> { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - ImageEmbedder, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const embedder = await VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); await embedder.setOptions(imageEmbedderOptions); return embedder; } @@ -79,31 +69,31 @@ /** * Initializes the Wasm runtime and creates a new image embedder based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the TFLite model.
*/ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise<ImageEmbedder> { return ImageEmbedder.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new image embedder based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise<ImageEmbedder> { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return ImageEmbedder.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index d68c00cc7..0337a0f2f 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -19,3 +19,4 @@ export * from '../../../tasks/web/vision/image_embedder/image_embedder'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; export * from '../../../tasks/web/vision/object_detector/object_detector'; +export * from '../../../tasks/web/core/fileset_resolver'; diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 44046cd1e..0f039acb2 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -19,9 +19,9 @@ import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {Detection as DetectionProto} from '../../../../framework/formats/detection_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; -import {WasmLoaderOptions} from '../../../../tasks/web/core/wasm_loader_options'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; @@ -48,27 +48,17 @@ export class ObjectDetector extends VisionTaskRunner { /** * Initializes the Wasm runtime and creates a new object detector from the * provided options. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param objectDetectorOptions The options for the Object Detector.
Note that * either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ static async createFromOptions( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, objectDetectorOptions: ObjectDetectorOptions): Promise<ObjectDetector> { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmLoaderOptions.wasmBinaryPath.toString(); - } - }; - - const detector = await createMediaPipeLib( - ObjectDetector, wasmLoaderOptions.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); + const detector = await VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset); await detector.setOptions(objectDetectorOptions); return detector; } @@ -76,31 +66,31 @@ /** * Initializes the Wasm runtime and creates a new object detector based on the * provided model asset buffer. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetBuffer A binary representation of the model. */ static createFromModelBuffer( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise<ObjectDetector> { return ObjectDetector.createFromOptions( - wasmLoaderOptions, {baseOptions: {modelAssetBuffer}}); + wasmFileset, {baseOptions: {modelAssetBuffer}}); } /** * Initializes the Wasm runtime and creates a new object detector based on the * path to the model asset. - * @param wasmLoaderOptions A configuration object that provides the location - * of the Wasm binary and its loader. + * @param wasmFileset A configuration object that provides the location of the + * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ static async createFromModelPath( - wasmLoaderOptions: WasmLoaderOptions, + wasmFileset: WasmFileset, modelAssetPath: string): Promise<ObjectDetector> { const response = await fetch(modelAssetPath.toString()); const graphData = await response.arrayBuffer(); return ObjectDetector.createFromModelBuffer( - wasmLoaderOptions, new Uint8Array(graphData)); + wasmFileset, new Uint8Array(graphData)); } protected override get baseOptions(): BaseOptionsProto|undefined { diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 378bc0a4d..9a0f7148c 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -133,9 +133,11 @@ export type ImageSource = /** A listener that will be invoked with an absl::StatusCode and message. */ export type ErrorListener = (code: number, message: string) => void; -// Internal type of constructors used for initializing GraphRunner and -// subclasses. -type WasmMediaPipeConstructor<LibType> = +/** + * Internal type of constructors used for initializing GraphRunner and * subclasses.
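[Editorial sketch, not part of the patch: exporting this constructor type is what lets `TaskRunner.createInstance()` stay generic over every task. The helper below and its `GraphRunner` bound are illustrative assumptions, not code from this patch series.]

function instantiate<T extends GraphRunner>(
    ctor: WasmMediaPipeConstructor<T>, wasmModule: WasmModule): T {
  // Vision tasks pass a canvas here; CPU-only tasks pass null.
  return new ctor(wasmModule, /* canvas= */ null);
}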
+ */ +export type WasmMediaPipeConstructor<LibType> = (new ( module: WasmModule, canvas?: HTMLCanvasElement|OffscreenCanvas|null) => LibType); diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 6bfde21ba..504f8567a 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,36 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"], - ) - - http_file( - name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"], - ) - - http_file( - name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"], + sha256 = "42d2d0ade6e2e8b81425b23686be93eb1423b7777f043eb8f18ad671e2ca803f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1669173769507080"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"], + sha256 = "20200ee9b0866d5176f633a9b375e8a44e53204c01ea2e159e2f9245afb00e80", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1669173772528997"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", + sha256 = "11bbf73d48723b19a5a6a13ec296ecdb2aa178cdc3db9d7bc54265a7d4b94c6a", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1669173774625527"], + ) + + http_file( + name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", + sha256 = "d4528972219033996a83a62798952b6ee8b6b396bcffd96fd5bda5458d57d3a3", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1669173777474822"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_internal_js", + sha256 = "29e72e177122f92bda6a3ecd463ebacf30b920559b06c97068112a22eeea4d0e", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1669173779706893"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"], + sha256 = "84e5f5ac70f7718baeaa09a89b155abbea67386e7d50663301b3af7ef0941e74", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1669173782728605"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", + sha256 = "36f247673124e32535f217265b96508c1badee8fe2458c11c1efa95b6bec5daa", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1669173785027190"], + ) + + http_file( + name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", + sha256 = 
"cc74d90a8aaf6d006ec24048cc80c33f96baeeb0075a6c6739f30d41da54e450", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1669173787903754"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_internal_js", + sha256 = "c3451423186766b08008e07ef6d52f628fcc0aca75beedd9bb4d87d380f29edd", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1669173790070986"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"], + sha256 = "d1e8ad748913e3f190bfd3f72e0e8a4a308f78b918d54c79cec60a2cf30a49f0", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1669173792993881"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", + sha256 = "e5f1b5e8264ff9a90371653cb0fdbf9ce3b30b712acbd72068af18ebca2293ac", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1669173794969702"], + ) + + http_file( + name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", + sha256 = "24351fe580e88f2065b1978b8b3c0f3ad7b90f1c95805aafa07971ce422b5854", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1669173797596874"], ) From c48ca1f674e2fef6b23a28100fd092ebe656e96a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 13:29:35 -0800 Subject: [PATCH 109/137] internal change PiperOrigin-RevId: 491429214 --- .../tasks/cc/components/containers/BUILD | 5 --- .../tasks/cc/vision/hand_landmarker/BUILD | 6 +++ .../hand_landmarker/hand_landmark.h} | 10 ++--- .../tasks/components/containers/BUILD | 12 ------ .../com/google/mediapipe/tasks/vision/BUILD | 2 + .../handlandmarker}/HandLandmark.java | 2 +- .../python/components/containers/landmark.py | 26 ------------ .../tasks/python/vision/hand_landmarker.py | 26 ++++++++++++ .../web/components/containers/landmark.d.ts | 25 ----------- .../tasks/web/vision/hand_landmarker/BUILD | 1 + .../vision/hand_landmarker/hand_landmark.d.ts | 41 +++++++++++++++++++ 11 files changed, 82 insertions(+), 74 deletions(-) rename mediapipe/tasks/cc/{components/containers/landmark.h => vision/hand_landmarker/hand_landmark.h} (78%) rename mediapipe/tasks/java/com/google/mediapipe/tasks/{components/containers => vision/handlandmarker}/HandLandmark.java (97%) create mode 100644 mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index dec977fb8..35d3f4785 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -49,8 +49,3 @@ cc_library( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) - -cc_library( - name = "landmark", - hdrs = ["landmark.h"], -) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 46948ee6c..03ec45f7d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -54,6 +54,12 @@ cc_library( ], ) +cc_library( + name = "hand_landmark", + hdrs = ["hand_landmark.h"], + visibility = ["//visibility:public"], +) + cc_library( name = 
"hand_landmarks_detector_graph", srcs = ["hand_landmarks_detector_graph.cc"], diff --git a/mediapipe/tasks/cc/components/containers/landmark.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h similarity index 78% rename from mediapipe/tasks/cc/components/containers/landmark.h rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h index 6fdd294ae..c8dbc9254 100644 --- a/mediapipe/tasks/cc/components/containers/landmark.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmark.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ -namespace mediapipe::tasks::components::containers { +namespace mediapipe::tasks::vision::hand_landmarker { // The 21 hand landmarks. enum HandLandmark { @@ -43,6 +43,6 @@ enum HandLandmark { PINKY_TIP = 20 }; -} // namespace mediapipe::tasks::components::containers +} // namespace mediapipe::tasks::vision::hand_landmarker -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARK_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARK_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 869157295..d6e6ac740 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -74,18 +74,6 @@ android_library( ], ) -android_library( - name = "handlandmark", - srcs = ["HandLandmark.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "@maven//:androidx_annotation_annotation", - "@maven//:com_google_guava_guava", - ], -) - android_library( name = "landmark", srcs = ["Landmark.java"], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 72cee133f..b7febb118 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -145,6 +145,7 @@ android_library( android_library( name = "handlandmarker", srcs = [ + "handlandmarker/HandLandmark.java", "handlandmarker/HandLandmarker.java", "handlandmarker/HandLandmarkerResult.java", ], @@ -168,6 +169,7 @@ android_library( "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java similarity index 97% rename from mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java rename to mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java index da7c4e0ca..7b21ebddf 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/HandLandmark.java +++ 
b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmark.java @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package com.google.mediapipe.tasks.components.containers; +package com.google.mediapipe.tasks.vision.handlandmarker; import androidx.annotation.IntDef; diff --git a/mediapipe/tasks/python/components/containers/landmark.py b/mediapipe/tasks/python/components/containers/landmark.py index 81b2943dc..dee2a16ad 100644 --- a/mediapipe/tasks/python/components/containers/landmark.py +++ b/mediapipe/tasks/python/components/containers/landmark.py @@ -14,7 +14,6 @@ """Landmark data class.""" import dataclasses -import enum from typing import Optional from mediapipe.framework.formats import landmark_pb2 @@ -121,28 +120,3 @@ class NormalizedLandmark: z=pb2_obj.z, visibility=pb2_obj.visibility, presence=pb2_obj.presence) - - -class HandLandmark(enum.IntEnum): - """The 21 hand landmarks.""" - WRIST = 0 - THUMB_CMC = 1 - THUMB_MCP = 2 - THUMB_IP = 3 - THUMB_TIP = 4 - INDEX_FINGER_MCP = 5 - INDEX_FINGER_PIP = 6 - INDEX_FINGER_DIP = 7 - INDEX_FINGER_TIP = 8 - MIDDLE_FINGER_MCP = 9 - MIDDLE_FINGER_PIP = 10 - MIDDLE_FINGER_DIP = 11 - MIDDLE_FINGER_TIP = 12 - RING_FINGER_MCP = 13 - RING_FINGER_PIP = 14 - RING_FINGER_DIP = 15 - RING_FINGER_TIP = 16 - PINKY_MCP = 17 - PINKY_PIP = 18 - PINKY_DIP = 19 - PINKY_TIP = 20 diff --git a/mediapipe/tasks/python/vision/hand_landmarker.py b/mediapipe/tasks/python/vision/hand_landmarker.py index 3367f1da7..a0cd99a83 100644 --- a/mediapipe/tasks/python/vision/hand_landmarker.py +++ b/mediapipe/tasks/python/vision/hand_landmarker.py @@ -14,6 +14,7 @@ """MediaPipe hand landmarker task.""" import dataclasses +import enum from typing import Callable, Mapping, Optional, List from mediapipe.framework.formats import classification_pb2 @@ -53,6 +54,31 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +class HandLandmark(enum.IntEnum): + """The 21 hand landmarks.""" + WRIST = 0 + THUMB_CMC = 1 + THUMB_MCP = 2 + THUMB_IP = 3 + THUMB_TIP = 4 + INDEX_FINGER_MCP = 5 + INDEX_FINGER_PIP = 6 + INDEX_FINGER_DIP = 7 + INDEX_FINGER_TIP = 8 + MIDDLE_FINGER_MCP = 9 + MIDDLE_FINGER_PIP = 10 + MIDDLE_FINGER_DIP = 11 + MIDDLE_FINGER_TIP = 12 + RING_FINGER_MCP = 13 + RING_FINGER_PIP = 14 + RING_FINGER_DIP = 15 + RING_FINGER_TIP = 16 + PINKY_MCP = 17 + PINKY_PIP = 18 + PINKY_DIP = 19 + PINKY_TIP = 20 + + @dataclasses.dataclass class HandLandmarkerResult: """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image. diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index 352717a2f..c887303d0 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -33,28 +33,3 @@ export declare interface Landmark { /** Whether this landmark is normalized with respect to the image size. */ normalized: boolean; } - -/** The 21 hand landmarks. 
*/ -export const enum HandLandmark { - WRIST = 0, - THUMB_CMC = 1, - THUMB_MCP = 2, - THUMB_IP = 3, - THUMB_TIP = 4, - INDEX_FINGER_MCP = 5, - INDEX_FINGER_PIP = 6, - INDEX_FINGER_DIP = 7, - INDEX_FINGER_TIP = 8, - MIDDLE_FINGER_MCP = 9, - MIDDLE_FINGER_PIP = 10, - MIDDLE_FINGER_DIP = 11, - MIDDLE_FINGER_TIP = 12, - RING_FINGER_MCP = 13, - RING_FINGER_PIP = 14, - RING_FINGER_DIP = 15, - RING_FINGER_TIP = 16, - PINKY_MCP = 17, - PINKY_PIP = 18, - PINKY_DIP = 19, - PINKY_TIP = 20 -} diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index 1849687c5..fc3e6ef1f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -34,6 +34,7 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "hand_landmarker_types", srcs = [ + "hand_landmark.d.ts", "hand_landmarker_options.d.ts", "hand_landmarker_result.d.ts", ], diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts new file mode 100644 index 000000000..ca2543f78 --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmark.d.ts @@ -0,0 +1,41 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** The 21 hand landmarks. */ +export const enum HandLandmark { + WRIST = 0, + THUMB_CMC = 1, + THUMB_MCP = 2, + THUMB_IP = 3, + THUMB_TIP = 4, + INDEX_FINGER_MCP = 5, + INDEX_FINGER_PIP = 6, + INDEX_FINGER_DIP = 7, + INDEX_FINGER_TIP = 8, + MIDDLE_FINGER_MCP = 9, + MIDDLE_FINGER_PIP = 10, + MIDDLE_FINGER_DIP = 11, + MIDDLE_FINGER_TIP = 12, + RING_FINGER_MCP = 13, + RING_FINGER_PIP = 14, + RING_FINGER_DIP = 15, + RING_FINGER_TIP = 16, + PINKY_MCP = 17, + PINKY_PIP = 18, + PINKY_DIP = 19, + PINKY_TIP = 20 +} From 342f95fa2044c4957ea7cb65352268a868e3d680 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Mon, 28 Nov 2022 13:51:59 -0800 Subject: [PATCH 110/137] Typo fix PiperOrigin-RevId: 491434987 --- mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h | 2 +- mediapipe/tasks/python/vision/image_segmenter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index 43bf5b7e6..511d3b9c1 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -98,7 +98,7 @@ struct ImageSegmenterOptions { // - list of segmented masks. // - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. // - if `output_type` is CONFIDENCE_MASK, float32 Image list of size -// `cahnnels`. +// `channels`. 
// - batch is always 1 // An example of such model can be found at: // https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index 9ef911f75..62fc8bb7c 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -110,7 +110,7 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): - list of segmented masks. - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. - if `output_type` is CONFIDENCE_MASK, float32 Image list of size - `cahnnels`. + `channels`. - batch is always 1 An example of such model can be found at: From b65c40b302ccf397d6da3c27ab2795335e5c63cd Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 14:15:16 -0800 Subject: [PATCH 111/137] Internal change PiperOrigin-RevId: 491441446 --- mediapipe/objc/MPPLayerRenderer.m | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mediapipe/objc/MPPLayerRenderer.m b/mediapipe/objc/MPPLayerRenderer.m index 7c3027fb6..edd2216ee 100644 --- a/mediapipe/objc/MPPLayerRenderer.m +++ b/mediapipe/objc/MPPLayerRenderer.m @@ -54,10 +54,11 @@ glGenRenderbuffers(1, &renderbuffer_); glBindRenderbuffer(GL_RENDERBUFFER, renderbuffer_); glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_RENDERBUFFER, renderbuffer_); - BOOL success = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER fromDrawable:_layer]; + BOOL success __unused = [_glRenderer.glContext renderbufferStorage:GL_RENDERBUFFER + fromDrawable:_layer]; NSAssert(success, @"could not create renderbuffer storage for layer with bounds %@", NSStringFromCGRect(_layer.bounds)); - GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER); + GLenum status __unused = glCheckFramebufferStatus(GL_FRAMEBUFFER); NSAssert(status == GL_FRAMEBUFFER_COMPLETE, @"failed to make complete framebuffer object %x", status); } From 26a7ca5c64cd885978677931a7218d33cd7d1dec Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 15:02:55 -0800 Subject: [PATCH 112/137] fix typo and minor formatting issues PiperOrigin-RevId: 491453662 --- mediapipe/python/solutions/drawing_utils.py | 42 ++++++++++----------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index bebcbe97c..1b8b173f7 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MediaPipe solution drawing utils.""" import math @@ -135,15 +134,14 @@ def draw_landmarks( the image. connections: A list of landmark index tuples that specifies how landmarks to be connected in the drawing. - landmark_drawing_spec: Either a DrawingSpec object or a mapping from - hand landmarks to the DrawingSpecs that specifies the landmarks' drawing - settings such as color, line thickness, and circle radius. - If this argument is explicitly set to None, no landmarks will be drawn. - connection_drawing_spec: Either a DrawingSpec object or a mapping from - hand connections to the DrawingSpecs that specifies the - connections' drawing settings such as color and line thickness. - If this argument is explicitly set to None, no landmark connections will - be drawn. 
+ landmark_drawing_spec: Either a DrawingSpec object or a mapping from hand + landmarks to the DrawingSpecs that specifies the landmarks' drawing + settings such as color, line thickness, and circle radius. If this + argument is explicitly set to None, no landmarks will be drawn. + connection_drawing_spec: Either a DrawingSpec object or a mapping from hand + connections to the DrawingSpecs that specifies the connections' drawing + settings such as color and line thickness. If this argument is explicitly + set to None, no landmark connections will be drawn. Raises: ValueError: If one of the followings: @@ -197,14 +195,13 @@ def draw_landmarks( drawing_spec.color, drawing_spec.thickness) -def draw_axis( - image: np.ndarray, - rotation: np.ndarray, - translation: np.ndarray, - focal_length: Tuple[float, float] = (1.0, 1.0), - principal_point: Tuple[float, float] = (0.0, 0.0), - axis_length: float = 0.1, - axis_drawing_spec: DrawingSpec = DrawingSpec()): +def draw_axis(image: np.ndarray, + rotation: np.ndarray, + translation: np.ndarray, + focal_length: Tuple[float, float] = (1.0, 1.0), + principal_point: Tuple[float, float] = (0.0, 0.0), + axis_length: float = 0.1, + axis_drawing_spec: DrawingSpec = DrawingSpec()): """Draws the 3D axis on the image. Args: @@ -214,8 +211,8 @@ def draw_axis( focal_length: camera focal length along x and y directions. principal_point: camera principal point in x and y. axis_length: length of the axis in the drawing. - axis_drawing_spec: A DrawingSpec object that specifies the xyz axis - drawing settings such as line thickness. + axis_drawing_spec: A DrawingSpec object that specifies the xyz axis drawing + settings such as line thickness. Raises: ValueError: If one of the followings: @@ -226,7 +223,7 @@ def draw_axis( image_rows, image_cols, _ = image.shape # Create axis points in camera coordinate frame. axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) - axis_cam = np.matmul(rotation, axis_length*axis_world.T).T + translation + axis_cam = np.matmul(rotation, axis_length * axis_world.T).T + translation x = axis_cam[..., 0] y = axis_cam[..., 1] z = axis_cam[..., 2] @@ -274,8 +271,9 @@ def plot_landmarks(landmark_list: landmark_pb2.NormalizedLandmarkList, connections' drawing settings such as color and line thickness. elevation: The elevation from which to view the plot. azimuth: the azimuth angle to rotate the plot. + Raises: - ValueError: If any connetions contain invalid landmark index. + ValueError: If any connection contains an invalid landmark index. 
""" if not landmark_list: return From 7b74fd53f592ab115f60180278952eafeeb61634 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 15:46:30 -0800 Subject: [PATCH 113/137] Verify that kernel cache is only used when OpenCL is active PiperOrigin-RevId: 491463306 --- .../calculators/tensor/inference_calculator_gl_advanced.cc | 6 +++--- mediapipe/calculators/tflite/tflite_inference_calculator.cc | 6 +++--- mediapipe/util/tflite/tflite_gpu_runner.h | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index c2c723402..b226dbbd8 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -258,9 +258,9 @@ InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::SaveGpuCaches( tflite::gpu::TFLiteGPURunner* gpu_runner) const { if (use_kernel_caching_) { // Save kernel file. - auto kernel_cache = absl::make_unique>( - gpu_runner->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector kernel_cache, + gpu_runner->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index afdc9ed6f..0f7fa933e 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -485,9 +485,9 @@ absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Save kernel file. 
- auto kernel_cache = absl::make_unique<std::vector<uint8_t>>( - tflite_gpu_runner_->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + ASSIGN_OR_RETURN(std::vector<uint8_t> kernel_cache, + tflite_gpu_runner_->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache.begin(), kernel_cache.end()); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } diff --git a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h index dfbc8d659..5eeaa230f 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.h +++ b/mediapipe/util/tflite/tflite_gpu_runner.h @@ -21,6 +21,7 @@ #include "absl/status/status.h" #include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -89,7 +90,8 @@ class TFLiteGPURunner { serialized_binary_cache_ = std::move(cache); } - std::vector<uint8_t> GetSerializedBinaryCache() { + absl::StatusOr<std::vector<uint8_t>> GetSerializedBinaryCache() { + RET_CHECK(cl_environment_) << "CL environment is not initialized."; return cl_environment_->GetSerializedBinaryCache(); } From e987b69f397af3d7bb4976d4e77029dacaae999a Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 28 Nov 2022 16:48:17 -0800 Subject: [PATCH 114/137] Add alternative method to determine unique kernel cache path PiperOrigin-RevId: 491476293 --- .../tensor/inference_calculator_gl_advanced.cc | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index b226dbbd8..8fd55efa7 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -236,14 +236,21 @@ absl::Status InferenceCalculatorGlAdvancedImpl::OnDiskCacheHelper::Init( const mediapipe::InferenceCalculatorOptions& options, const mediapipe::InferenceCalculatorOptions::Delegate::Gpu& gpu_delegate_options) { - use_kernel_caching_ = gpu_delegate_options.has_cached_kernel_path(); + // The kernel cache needs a unique filename based on either model_path or the + // model token, to prevent the cache from being overwritten if the graph has + // more than one model. + use_kernel_caching_ = + gpu_delegate_options.has_cached_kernel_path() && + (options.has_model_path() || gpu_delegate_options.has_model_token()); use_serialized_model_ = gpu_delegate_options.has_serialized_model_dir() && gpu_delegate_options.has_model_token(); if (use_kernel_caching_) { + std::string basename = options.has_model_path() + ?
mediapipe::File::Basename(options.model_path()) + : gpu_delegate_options.model_token(); cached_kernel_filename_ = mediapipe::file::JoinPath( - gpu_delegate_options.cached_kernel_path(), - mediapipe::File::Basename(options.model_path()) + ".ker"); + gpu_delegate_options.cached_kernel_path(), basename + ".ker"); } if (use_serialized_model_) { serialized_model_path_ = From fc526374abac9e1080e06470004ab292fe0c162a Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Mon, 28 Nov 2022 17:48:37 -0800 Subject: [PATCH 115/137] Use GpuResources in GpuTestBase and update GpuBufferMultiPoolTest PiperOrigin-RevId: 491486495 --- mediapipe/gpu/gpu_test_base.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mediapipe/gpu/gpu_test_base.h b/mediapipe/gpu/gpu_test_base.h index e9fd64725..6ec53603b 100644 --- a/mediapipe/gpu/gpu_test_base.h +++ b/mediapipe/gpu/gpu_test_base.h @@ -24,13 +24,14 @@ namespace mediapipe { class GpuTestBase : public ::testing::Test { protected: - GpuTestBase() { helper_.InitializeForTest(&gpu_shared_); } + GpuTestBase() { helper_.InitializeForTest(gpu_resources_.get()); } void RunInGlContext(std::function<void(void)> gl_func) { helper_.RunInGlContext(std::move(gl_func)); } GpuSharedData gpu_shared_; + std::shared_ptr<GpuResources> gpu_resources_ = gpu_shared_.gpu_resources; GlCalculatorHelper helper_; }; From cc11b4522837ce2f3763831fca0447e3b7cef495 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Mon, 28 Nov 2022 17:52:35 -0800 Subject: [PATCH 116/137] Remove unneeded GPU_SHARED side packet for GlSurfaceSink PiperOrigin-RevId: 491487092 --- mediapipe/gpu/gl_surface_sink_calculator.cc | 1 - mediapipe/java/com/google/mediapipe/framework/jni/graph.cc | 2 -- 2 files changed, 3 deletions(-) diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc index 31500ed9a..ad867c2be 100644 --- a/mediapipe/gpu/gl_surface_sink_calculator.cc +++ b/mediapipe/gpu/gl_surface_sink_calculator.cc @@ -37,7 +37,6 @@ enum { kAttribVertex, kAttribTexturePosition, kNumberOfAttributes }; // VIDEO or index 0: GpuBuffers to be rendered. // Side inputs: // SURFACE: unique_ptr to an EglSurfaceHolder to draw to. -// GPU_SHARED: shared GPU resources. // // See GlSurfaceSinkCalculatorOptions for options. class GlSurfaceSinkCalculator : public Node { diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index 6a67c01cb..23bd553af 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -231,8 +231,6 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { *graph_config(), absl::StrCat("egl_surface_sink_", output_stream_name))); sink_node->set_calculator("GlSurfaceSinkCalculator"); sink_node->add_input_stream(output_stream_name); - sink_node->add_input_side_packet( - absl::StrCat(kGpuSharedTagName, ":", kGpuSharedSidePacketName)); const std::string input_side_packet_name = mediapipe::tool::GetUnusedSidePacketName( From c8a413bb4e5da6b977695987809a27b8f044f15a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Tue, 29 Nov 2022 10:17:21 -0800 Subject: [PATCH 117/137] Open up mediapipe framework's visibility.
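The patch below makes each package's default visibility public and deletes the per-target visibility attributes that the new default makes redundant. For readers less familiar with Bazel: package(default_visibility = ...) applies to every target in a BUILD file that does not set its own visibility attribute, and a target-level attribute always overrides the package default. A minimal sketch of the pattern, with hypothetical target names not taken from this patch:

    package(default_visibility = ["//visibility:public"])

    cc_library(
        name = "foo",
        srcs = ["foo.cc"],
        # Public via the package default; no visibility attribute needed.
    )

    cc_library(
        name = "foo_internal",
        srcs = ["foo_internal.cc"],
        # A target can still opt out by overriding the package default.
        visibility = ["//visibility:private"],
    )

The diff that follows applies exactly this simplification across the calculator, framework, and util BUILD files.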
PiperOrigin-RevId: 491672877 --- mediapipe/calculators/image/BUILD | 41 +-------- mediapipe/calculators/tensorflow/BUILD | 70 +--------------- mediapipe/calculators/tflite/BUILD | 20 +---- mediapipe/calculators/util/BUILD | 83 ------------------- mediapipe/calculators/video/BUILD | 29 +------ mediapipe/examples/desktop/hello_world/BUILD | 3 +- mediapipe/framework/BUILD | 2 +- mediapipe/framework/formats/BUILD | 28 +------ mediapipe/framework/formats/annotation/BUILD | 4 +- mediapipe/framework/formats/motion/BUILD | 7 +- .../framework/formats/object_detection/BUILD | 4 +- mediapipe/framework/stream_handler/BUILD | 19 +---- .../holistic_landmark/calculators/BUILD | 3 - mediapipe/util/tracking/BUILD | 17 ---- 14 files changed, 11 insertions(+), 319 deletions(-) diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index c78bc5cf7..530dd3d4a 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -16,12 +16,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "opencv_image_encoder_calculator_proto", srcs = ["opencv_image_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -31,7 +30,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "scale_image_calculator_proto", srcs = ["scale_image_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -42,7 +40,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "set_alpha_calculator_proto", srcs = ["set_alpha_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -52,7 +49,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "image_cropping_calculator_proto", srcs = ["image_cropping_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -62,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "bilateral_filter_calculator_proto", srcs = ["bilateral_filter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -72,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "recolor_calculator_proto", srcs = ["recolor_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "segmentation_smoothing_calculator_proto", srcs = ["segmentation_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -112,7 +104,6 @@ cc_library( cc_library( name = 
"opencv_encoded_image_to_image_frame_calculator", srcs = ["opencv_encoded_image_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_encoded_image_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -127,7 +118,6 @@ cc_library( cc_library( name = "opencv_image_encoder_calculator", srcs = ["opencv_image_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -142,7 +132,6 @@ cc_library( cc_library( name = "opencv_put_text_calculator", srcs = ["opencv_put_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame_opencv", @@ -156,7 +145,6 @@ cc_library( cc_library( name = "set_alpha_calculator", srcs = ["set_alpha_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":set_alpha_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -183,7 +171,6 @@ cc_library( cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -212,13 +199,11 @@ cc_library( mediapipe_proto_library( name = "rotation_mode_proto", srcs = ["rotation_mode.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ ":rotation_mode_proto", "//mediapipe/framework:calculator_options_proto", @@ -243,7 +228,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":rotation_mode_cc_proto", ":image_transformation_calculator_cc_proto", @@ -287,7 +271,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":image_cropping_calculator_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", @@ -330,7 +313,6 @@ cc_test( cc_library( name = "luminance_calculator", srcs = ["luminance_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -344,7 +326,6 @@ cc_library( cc_library( name = "sobel_edges_calculator", srcs = ["sobel_edges_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", @@ -358,7 +339,6 @@ cc_library( cc_library( name = "recolor_calculator", srcs = ["recolor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":recolor_calculator_cc_proto", "//mediapipe/util:color_cc_proto", @@ -385,9 +365,6 @@ cc_library( name = "scale_image_utils", srcs = ["scale_image_utils.cc"], hdrs = ["scale_image_utils.h"], - visibility = [ - "//mediapipe:__subpackages__", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:logging", @@ -400,9 +377,6 @@ cc_library( cc_library( name = "scale_image_calculator", srcs = ["scale_image_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ ":scale_image_utils", "//mediapipe/calculators/image:scale_image_calculator_cc_proto", @@ -429,7 +403,6 @@ cc_library( mediapipe_proto_library( name = "image_clone_calculator_proto", srcs = ["image_clone_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -439,7 +412,6 @@ mediapipe_proto_library( cc_library( name = "image_clone_calculator", srcs = ["image_clone_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_clone_calculator_cc_proto", "//mediapipe/framework/api2:node", @@ -459,7 +431,6 @@ cc_library( cc_library( name = "image_properties_calculator", srcs = ["image_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", @@ -524,7 +495,6 @@ cc_test( mediapipe_proto_library( name = "mask_overlay_calculator_proto", srcs = ["mask_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -534,7 +504,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "opencv_encoded_image_to_image_frame_calculator_proto", srcs = ["opencv_encoded_image_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -544,7 +513,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "feature_detector_calculator_proto", srcs = ["feature_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -554,7 +522,6 @@ mediapipe_proto_library( cc_library( name = "mask_overlay_calculator", srcs = ["mask_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":mask_overlay_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -570,7 +537,6 @@ cc_library( cc_library( name = "feature_detector_calculator", srcs = ["feature_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":feature_detector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -597,7 +563,6 @@ cc_library( cc_library( name = "image_file_properties_calculator", srcs = ["image_file_properties_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_file_properties_cc_proto", @@ -627,7 +592,6 @@ cc_test( cc_library( name = "segmentation_smoothing_calculator", srcs = ["segmentation_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":segmentation_smoothing_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -724,7 +688,6 @@ cc_library( mediapipe_proto_library( name = "warp_affine_calculator_proto", srcs = ["warp_affine_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -736,7 +699,6 @@ cc_library( name = "warp_affine_calculator", srcs = ["warp_affine_calculator.cc"], hdrs = ["warp_affine_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":affine_transformation", ":warp_affine_calculator_cc_proto", @@ -817,7 +779,6 @@ cc_test( cc_library( name = "yuv_to_image_calculator", srcs = ["yuv_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 45f64f4f7..0f8f8706a 100644 --- 
a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -17,12 +17,11 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "graph_tensors_packet_generator_proto", srcs = ["graph_tensors_packet_generator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework:packet_generator_proto", @@ -32,49 +31,42 @@ proto_library( proto_library( name = "matrix_to_tensor_calculator_options_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "lapped_tensor_buffer_calculator_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "object_detection_tensors_to_detections_calculator_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensorflow_inference_calculator_proto", srcs = ["tensorflow_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_squeeze_dimensions_calculator_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_image_frame_calculator_proto", srcs = ["tensor_to_image_frame_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_matrix_calculator_proto", srcs = ["tensor_to_matrix_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:time_series_header_proto", @@ -84,30 +76,24 @@ proto_library( proto_library( name = "tensor_to_vector_float_calculator_options_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_int_calculator_options_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "tensor_to_vector_string_calculator_options_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) mediapipe_proto_library( name = "unpack_media_sequence_calculator_proto", srcs = ["unpack_media_sequence_calculator.proto"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_proto", "//mediapipe/framework:calculator_proto", @@ -118,14 +104,12 @@ mediapipe_proto_library( proto_library( name = "vector_float_to_tensor_calculator_options_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "vector_string_to_tensor_calculator_options_proto", srcs = 
["vector_string_to_tensor_calculator_options.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -136,7 +120,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:packet_generator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":graph_tensors_packet_generator_proto"], ) @@ -147,7 +130,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":image_frame_to_tensor_calculator_proto"], ) @@ -155,7 +137,6 @@ mediapipe_cc_proto_library( name = "matrix_to_tensor_calculator_options_cc_proto", srcs = ["matrix_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":matrix_to_tensor_calculator_options_proto"], ) @@ -163,7 +144,6 @@ mediapipe_cc_proto_library( name = "lapped_tensor_buffer_calculator_cc_proto", srcs = ["lapped_tensor_buffer_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":lapped_tensor_buffer_calculator_proto"], ) @@ -171,7 +151,6 @@ mediapipe_cc_proto_library( name = "object_detection_tensors_to_detections_calculator_cc_proto", srcs = ["object_detection_tensors_to_detections_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":object_detection_tensors_to_detections_calculator_proto"], ) @@ -179,7 +158,6 @@ mediapipe_cc_proto_library( name = "tensorflow_inference_calculator_cc_proto", srcs = ["tensorflow_inference_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensorflow_inference_calculator_proto"], ) @@ -190,7 +168,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_generator_proto"], ) @@ -201,7 +178,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_frozen_graph_calculator_proto"], ) @@ -212,7 +188,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:packet_generator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_generator_proto"], ) @@ -223,7 +198,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":tensorflow_session_from_saved_model_calculator_proto"], ) @@ -231,7 +205,6 @@ mediapipe_cc_proto_library( name = "tensor_squeeze_dimensions_calculator_cc_proto", srcs = ["tensor_squeeze_dimensions_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_squeeze_dimensions_calculator_proto"], ) @@ -239,7 +212,6 @@ mediapipe_cc_proto_library( name = "tensor_to_image_frame_calculator_cc_proto", srcs = ["tensor_to_image_frame_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_image_frame_calculator_proto"], ) @@ -250,7 +222,6 @@ 
mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tensor_to_matrix_calculator_proto"], ) @@ -258,7 +229,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_float_calculator_options_cc_proto", srcs = ["tensor_to_vector_float_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_float_calculator_options_proto"], ) @@ -266,7 +236,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_int_calculator_options_cc_proto", srcs = ["tensor_to_vector_int_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_int_calculator_options_proto"], ) @@ -274,7 +243,6 @@ mediapipe_cc_proto_library( name = "tensor_to_vector_string_calculator_options_cc_proto", srcs = ["tensor_to_vector_string_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":tensor_to_vector_string_calculator_options_proto"], ) @@ -285,7 +253,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "@org_tensorflow//tensorflow/core:protos_all_cc", ], - visibility = ["//visibility:public"], deps = [":vector_int_to_tensor_calculator_options_proto"], ) @@ -293,7 +260,6 @@ mediapipe_cc_proto_library( name = "vector_float_to_tensor_calculator_options_cc_proto", srcs = ["vector_float_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_float_to_tensor_calculator_options_proto"], ) @@ -301,14 +267,12 @@ mediapipe_cc_proto_library( name = "vector_string_to_tensor_calculator_options_cc_proto", srcs = ["vector_string_to_tensor_calculator_options.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":vector_string_to_tensor_calculator_options_proto"], ) cc_library( name = "graph_tensors_packet_generator", srcs = ["graph_tensors_packet_generator.cc"], - visibility = ["//visibility:public"], deps = [ ":graph_tensors_packet_generator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -323,7 +287,6 @@ cc_library( cc_library( name = "image_frame_to_tensor_calculator", srcs = ["image_frame_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":image_frame_to_tensor_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -344,7 +307,6 @@ cc_library( cc_library( name = "matrix_to_tensor_calculator", srcs = ["matrix_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":matrix_to_tensor_calculator_options_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", @@ -366,7 +328,6 @@ cc_library( cc_library( name = "lapped_tensor_buffer_calculator", srcs = ["lapped_tensor_buffer_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -388,9 +349,6 @@ cc_library( # Layering check doesn't play nicely with portable proto wrappers. 
"no_layering_check", ], - visibility = [ - "//visibility:public", - ], deps = [ ":object_detection_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,9 +365,6 @@ cc_library( cc_library( name = "pack_media_sequence_calculator", srcs = ["pack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", @@ -432,9 +387,6 @@ cc_library( cc_library( name = "string_to_sequence_example_calculator", srcs = ["string_to_sequence_example_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", @@ -449,7 +401,6 @@ cc_library( cc_library( name = "tensorflow_inference_calculator", srcs = ["tensorflow_inference_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_inference_calculator_cc_proto", ":tensorflow_session", @@ -487,7 +438,6 @@ cc_library( "tensorflow_session.h", ], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:core", @@ -505,7 +455,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_calculator", srcs = ["tensorflow_session_from_frozen_graph_calculator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", @@ -537,7 +486,6 @@ cc_library( name = "tensorflow_session_from_frozen_graph_generator", srcs = ["tensorflow_session_from_frozen_graph_generator.cc"], features = ["no_layering_check"], - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_frozen_graph_generator_cc_proto", @@ -572,7 +520,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_calculator_cc_proto", @@ -611,7 +558,6 @@ cc_library( "//mediapipe:android": ["__ANDROID__"], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tensorflow_session", ":tensorflow_session_from_saved_model_generator_cc_proto", @@ -637,7 +583,6 @@ cc_library( cc_library( name = "tensor_squeeze_dimensions_calculator", srcs = ["tensor_squeeze_dimensions_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_squeeze_dimensions_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -651,7 +596,6 @@ cc_library( cc_library( name = "tensor_to_image_frame_calculator", srcs = ["tensor_to_image_frame_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_image_frame_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -666,7 +610,6 @@ cc_library( cc_library( name = "tensor_to_matrix_calculator", srcs = ["tensor_to_matrix_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_matrix_calculator_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto", @@ -688,7 +631,6 @@ cc_library( cc_library( name = "tfrecord_reader_calculator", srcs = ["tfrecord_reader_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:integral_types", @@ -704,7 +646,6 @@ 
cc_library( cc_library( name = "tensor_to_vector_float_calculator", srcs = ["tensor_to_vector_float_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_float_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -724,7 +665,6 @@ cc_library( cc_library( name = "tensor_to_vector_int_calculator", srcs = ["tensor_to_vector_int_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tensor_to_vector_int_calculator_options_cc_proto", "@com_google_absl//absl/base:core_headers", @@ -746,7 +686,6 @@ cc_library( cc_library( name = "tensor_to_vector_string_calculator", srcs = ["tensor_to_vector_string_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:status", @@ -766,9 +705,6 @@ cc_library( cc_library( name = "unpack_media_sequence_calculator", srcs = ["unpack_media_sequence_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/calculators/tensorflow:unpack_media_sequence_calculator_cc_proto", @@ -786,7 +722,6 @@ cc_library( cc_library( name = "vector_int_to_tensor_calculator", srcs = ["vector_int_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_int_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -800,7 +735,6 @@ cc_library( cc_library( name = "vector_float_to_tensor_calculator", srcs = ["vector_float_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_float_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -814,7 +748,6 @@ cc_library( cc_library( name = "vector_string_to_tensor_calculator", srcs = ["vector_string_to_tensor_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":vector_string_to_tensor_calculator_options_cc_proto", "//mediapipe/framework:calculator_framework", @@ -828,7 +761,6 @@ cc_library( cc_library( name = "unpack_yt8m_sequence_example_calculator", srcs = ["unpack_yt8m_sequence_example_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":lapped_tensor_buffer_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 8edaeee02..db2a27630 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -18,12 +18,11 @@ load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) mediapipe_proto_library( name = "ssd_anchors_calculator_proto", srcs = ["ssd_anchors_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -33,7 +32,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_custom_op_resolver_calculator_proto", srcs = ["tflite_custom_op_resolver_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -43,7 +41,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_inference_calculator_proto", srcs = ["tflite_inference_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
"//mediapipe/framework:calculator_proto", @@ -53,7 +50,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_converter_calculator_proto", srcs = ["tflite_converter_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -63,7 +59,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_segmentation_calculator_proto", srcs = ["tflite_tensors_to_segmentation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -73,7 +68,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_detections_calculator_proto", srcs = ["tflite_tensors_to_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -83,7 +77,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_classification_calculator_proto", srcs = ["tflite_tensors_to_classification_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -93,7 +86,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "tflite_tensors_to_landmarks_calculator_proto", srcs = ["tflite_tensors_to_landmarks_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -103,7 +95,6 @@ mediapipe_proto_library( cc_library( name = "ssd_anchors_calculator", srcs = ["ssd_anchors_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":ssd_anchors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -117,7 +108,6 @@ cc_library( cc_library( name = "tflite_custom_op_resolver_calculator", srcs = ["tflite_custom_op_resolver_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_custom_op_resolver_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -208,7 +198,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_inference_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -287,7 +276,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_converter_calculator_cc_proto", "//mediapipe/util/tflite:config", @@ -326,7 +314,6 @@ cc_library( cc_library( name = "tflite_model_calculator", srcs = ["tflite_model_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", @@ -340,7 +327,6 @@ cc_library( cc_library( name = "tflite_tensors_to_segmentation_calculator", srcs = ["tflite_tensors_to_segmentation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_segmentation_calculator_cc_proto", "@com_google_absl//absl/strings:str_format", @@ -408,7 +394,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_detections_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -444,7 +429,6 @@ cc_library( cc_library( name = "tflite_tensors_to_classification_calculator", srcs = ["tflite_tensors_to_classification_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
":tflite_tensors_to_classification_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -476,7 +460,6 @@ cc_library( cc_library( name = "tflite_tensors_to_landmarks_calculator", srcs = ["tflite_tensors_to_landmarks_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tflite_tensors_to_landmarks_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -490,7 +473,6 @@ cc_library( cc_library( name = "tflite_tensors_to_floats_calculator", srcs = ["tflite_tensors_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:ret_check", diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 24e976a73..43eadd53b 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "alignment_points_to_rects_calculator", srcs = ["alignment_points_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:detections_to_rects_calculator", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "annotation_overlay_calculator_proto", srcs = ["annotation_overlay_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -50,7 +48,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detection_label_id_to_text_calculator_proto", srcs = ["detection_label_id_to_text_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -61,7 +58,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "filter_detections_calculator_proto", srcs = ["filter_detections_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -71,7 +67,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_id_to_label_calculator_proto", srcs = ["timed_box_list_id_to_label_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -81,13 +76,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "latency_proto", srcs = ["latency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "non_max_suppression_calculator_proto", srcs = ["non_max_suppression_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -97,13 +90,11 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_frequency_proto", srcs = ["packet_frequency.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "packet_frequency_calculator_proto", srcs = ["packet_frequency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -113,7 +104,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "packet_latency_calculator_proto", srcs = ["packet_latency_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", 
"//mediapipe/framework:calculator_proto", @@ -123,7 +113,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "collection_has_min_size_calculator_proto", srcs = ["collection_has_min_size_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -133,7 +122,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "association_calculator_proto", srcs = ["association_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -143,7 +131,6 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -188,7 +175,6 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", @@ -228,9 +214,6 @@ cc_test( cc_library( name = "clock_timestamp_calculator", srcs = ["clock_timestamp_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -246,9 +229,6 @@ cc_library( cc_library( name = "clock_latency_calculator", srcs = ["clock_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -263,7 +243,6 @@ cc_library( cc_library( name = "annotation_overlay_calculator", srcs = ["annotation_overlay_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":annotation_overlay_calculator_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto", @@ -296,7 +275,6 @@ cc_library( cc_library( name = "detection_label_id_to_text_calculator", srcs = ["detection_label_id_to_text_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detection_label_id_to_text_calculator_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", @@ -328,7 +306,6 @@ cc_library( cc_library( name = "timed_box_list_id_to_label_calculator", srcs = ["timed_box_list_id_to_label_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_id_to_label_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -357,7 +334,6 @@ cc_library( cc_library( name = "detection_transformation_calculator", srcs = ["detection_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", @@ -391,7 +367,6 @@ cc_test( cc_library( name = "non_max_suppression_calculator", srcs = ["non_max_suppression_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":non_max_suppression_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -408,7 +383,6 @@ cc_library( cc_library( name = "thresholding_calculator", srcs = ["thresholding_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":thresholding_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -421,7 +395,6 @@ cc_library( cc_library( name = "detection_to_landmarks_calculator", srcs = ["detection_to_landmarks_calculator.cc"], - 
visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -436,7 +409,6 @@ cc_library( cc_library( name = "filter_detections_calculator", srcs = ["filter_detections_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":filter_detections_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -450,7 +422,6 @@ cc_library( cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_detection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -471,7 +442,6 @@ cc_library( hdrs = [ "detections_to_rects_calculator.h", ], - visibility = ["//visibility:public"], deps = [ ":detections_to_rects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -489,7 +459,6 @@ cc_library( cc_library( name = "rect_transformation_calculator", srcs = ["rect_transformation_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_transformation_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -504,7 +473,6 @@ cc_library( cc_library( name = "rect_projection_calculator", srcs = ["rect_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:rect_cc_proto", @@ -535,7 +503,6 @@ cc_test( mediapipe_proto_library( name = "rect_to_render_data_calculator_proto", srcs = ["rect_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -547,7 +514,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_to_render_scale_calculator_proto", srcs = ["rect_to_render_scale_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -557,7 +523,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_render_data_calculator_proto", srcs = ["detections_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -569,7 +534,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_render_data_calculator_proto", srcs = ["landmarks_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -581,7 +545,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "timed_box_list_to_render_data_calculator_proto", srcs = ["timed_box_list_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -593,7 +556,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "labels_to_render_data_calculator_proto", srcs = ["labels_to_render_data_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -605,7 +567,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "thresholding_calculator_proto", srcs = ["thresholding_calculator.proto"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -617,7 +578,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "detections_to_rects_calculator_proto", srcs = ["detections_to_rects_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -627,7 +587,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmark_projection_calculator_proto", srcs = ["landmark_projection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -637,7 +596,6 @@ mediapipe_proto_library( cc_library( name = "landmark_visibility_calculator", srcs = ["landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -649,7 +607,6 @@ cc_library( cc_library( name = "set_landmark_visibility_calculator", srcs = ["set_landmark_visibility_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -661,7 +618,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_to_floats_calculator_proto", srcs = ["landmarks_to_floats_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -671,7 +627,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "rect_transformation_calculator_proto", srcs = ["rect_transformation_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -681,7 +636,6 @@ mediapipe_proto_library( mediapipe_proto_library( name = "landmarks_to_detection_calculator_proto", srcs = ["landmarks_to_detection_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -693,7 +647,6 @@ mediapipe_proto_library( cc_library( name = "detections_to_render_data_calculator", srcs = ["detections_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":detections_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -713,7 +666,6 @@ cc_library( name = "landmarks_to_render_data_calculator", srcs = ["landmarks_to_render_data_calculator.cc"], hdrs = ["landmarks_to_render_data_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -732,7 +684,6 @@ cc_library( cc_library( name = "timed_box_list_to_render_data_calculator", srcs = ["timed_box_list_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":timed_box_list_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -751,7 +702,6 @@ cc_library( cc_library( name = "labels_to_render_data_calculator", srcs = ["labels_to_render_data_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":labels_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -770,7 +720,6 @@ cc_library( cc_library( name = "rect_to_render_data_calculator", srcs = ["rect_to_render_data_calculator.cc"], - visibility = 
["//visibility:public"], deps = [ ":rect_to_render_data_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -785,7 +734,6 @@ cc_library( cc_library( name = "rect_to_render_scale_calculator", srcs = ["rect_to_render_scale_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":rect_to_render_scale_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -820,7 +768,6 @@ cc_test( cc_library( name = "detection_letterbox_removal_calculator", srcs = ["detection_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -834,7 +781,6 @@ cc_library( cc_library( name = "detection_projection_calculator", srcs = ["detection_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -867,7 +813,6 @@ cc_test( cc_library( name = "landmark_letterbox_removal_calculator", srcs = ["landmark_letterbox_removal_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -881,7 +826,6 @@ cc_library( cc_library( name = "landmark_projection_calculator", srcs = ["landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmark_projection_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -914,7 +858,6 @@ cc_test( cc_library( name = "world_landmark_projection_calculator", srcs = ["world_landmark_projection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", @@ -928,7 +871,6 @@ cc_library( mediapipe_proto_library( name = "landmarks_smoothing_calculator_proto", srcs = ["landmarks_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -938,7 +880,6 @@ mediapipe_proto_library( cc_library( name = "landmarks_smoothing_calculator", srcs = ["landmarks_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -956,7 +897,6 @@ cc_library( mediapipe_proto_library( name = "visibility_smoothing_calculator_proto", srcs = ["visibility_smoothing_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -966,7 +906,6 @@ mediapipe_proto_library( cc_library( name = "visibility_smoothing_calculator", srcs = ["visibility_smoothing_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_smoothing_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -982,7 +921,6 @@ cc_library( mediapipe_proto_library( name = "visibility_copy_calculator_proto", srcs = ["visibility_copy_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -992,7 +930,6 @@ mediapipe_proto_library( cc_library( name = "visibility_copy_calculator", srcs = ["visibility_copy_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":visibility_copy_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1007,7 +944,6 @@ 
cc_library( cc_library( name = "landmarks_to_floats_calculator", srcs = ["landmarks_to_floats_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":landmarks_to_floats_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1054,7 +990,6 @@ cc_test( mediapipe_proto_library( name = "top_k_scores_calculator_proto", srcs = ["top_k_scores_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1064,7 +999,6 @@ mediapipe_proto_library( cc_library( name = "top_k_scores_calculator", srcs = ["top_k_scores_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":top_k_scores_calculator_cc_proto", "@com_google_absl//absl/container:node_hash_map", @@ -1108,7 +1042,6 @@ cc_test( mediapipe_proto_library( name = "local_file_contents_calculator_proto", srcs = ["local_file_contents_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1118,7 +1051,6 @@ mediapipe_proto_library( cc_library( name = "local_file_contents_calculator", srcs = ["local_file_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":local_file_contents_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1132,7 +1064,6 @@ cc_library( cc_library( name = "local_file_pattern_contents_calculator", srcs = ["local_file_pattern_contents_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/port:file_helpers", @@ -1146,7 +1077,6 @@ cc_library( name = "filter_collection_calculator", srcs = ["filter_collection_calculator.cc"], hdrs = ["filter_collection_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:classification_cc_proto", @@ -1164,7 +1094,6 @@ cc_library( name = "collection_has_min_size_calculator", srcs = ["collection_has_min_size_calculator.cc"], hdrs = ["collection_has_min_size_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":collection_has_min_size_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1192,7 +1121,6 @@ cc_test( cc_library( name = "association_calculator", hdrs = ["association_calculator.h"], - visibility = ["//visibility:public"], deps = [ ":association_calculator_cc_proto", "//mediapipe/framework:calculator_context", @@ -1209,7 +1137,6 @@ cc_library( cc_library( name = "association_norm_rect_calculator", srcs = ["association_norm_rect_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1224,7 +1151,6 @@ cc_library( cc_library( name = "association_detection_calculator", srcs = ["association_detection_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":association_calculator", "//mediapipe/framework:calculator_context", @@ -1259,7 +1185,6 @@ cc_test( cc_library( name = "detections_to_timed_box_list_calculator", srcs = ["detections_to_timed_box_list_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1274,7 +1199,6 @@ cc_library( cc_library( name = "detection_unique_id_calculator", srcs = ["detection_unique_id_calculator.cc"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:detection_cc_proto", @@ -1287,7 +1211,6 @@ cc_library( mediapipe_proto_library( name = "logic_calculator_proto", srcs = ["logic_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1297,7 +1220,6 @@ mediapipe_proto_library( cc_library( name = "logic_calculator", srcs = ["logic_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":logic_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1310,7 +1232,6 @@ cc_library( cc_library( name = "to_image_calculator", srcs = ["to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -1333,7 +1254,6 @@ cc_library( cc_library( name = "from_image_calculator", srcs = ["from_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -1385,7 +1305,6 @@ cc_test( mediapipe_proto_library( name = "refine_landmarks_from_heatmap_calculator_proto", srcs = ["refine_landmarks_from_heatmap_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -1403,7 +1322,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ ":refine_landmarks_from_heatmap_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -1454,7 +1372,6 @@ cc_library( name = "inverse_matrix_calculator", srcs = ["inverse_matrix_calculator.cc"], hdrs = ["inverse_matrix_calculator.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:node", diff --git a/mediapipe/calculators/video/BUILD b/mediapipe/calculators/video/BUILD index 2db3ed252..f2b8135f2 100644 --- a/mediapipe/calculators/video/BUILD +++ b/mediapipe/calculators/video/BUILD @@ -21,19 +21,17 @@ load( licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "flow_to_image_calculator_proto", srcs = ["flow_to_image_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) proto_library( name = "opencv_video_encoder_calculator_proto", srcs = ["opencv_video_encoder_calculator.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:calculator_proto"], ) @@ -58,7 +56,6 @@ proto_library( proto_library( name = "box_tracker_calculator_proto", srcs = ["box_tracker_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:box_tracker_proto", @@ -68,7 +65,6 @@ proto_library( proto_library( name = "tracked_detection_manager_calculator_proto", srcs = ["tracked_detection_manager_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_proto", @@ -78,7 +74,6 @@ proto_library( proto_library( name = "box_detector_calculator_proto", srcs = ["box_detector_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", 
"//mediapipe/util/tracking:box_detector_proto", @@ -88,7 +83,6 @@ proto_library( proto_library( name = "video_pre_stream_calculator_proto", srcs = ["video_pre_stream_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_proto", ], @@ -101,7 +95,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:motion_analysis_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_calculator_proto"], ) @@ -112,7 +105,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:flow_packager_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_calculator_proto"], ) @@ -123,7 +115,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_tracker_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_tracker_calculator_proto"], ) @@ -134,7 +125,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:tracked_detection_manager_config_cc_proto", ], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_calculator_proto"], ) @@ -145,7 +135,6 @@ mediapipe_cc_proto_library( "//mediapipe/framework:calculator_cc_proto", "//mediapipe/util/tracking:box_detector_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_calculator_proto"], ) @@ -155,7 +144,6 @@ mediapipe_cc_proto_library( cc_deps = [ "//mediapipe/framework:calculator_cc_proto", ], - visibility = ["//visibility:public"], deps = [":video_pre_stream_calculator_proto"], ) @@ -163,7 +151,6 @@ mediapipe_cc_proto_library( name = "flow_to_image_calculator_cc_proto", srcs = ["flow_to_image_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":flow_to_image_calculator_proto"], ) @@ -171,14 +158,12 @@ mediapipe_cc_proto_library( name = "opencv_video_encoder_calculator_cc_proto", srcs = ["opencv_video_encoder_calculator.proto"], cc_deps = ["//mediapipe/framework:calculator_cc_proto"], - visibility = ["//visibility:public"], deps = [":opencv_video_encoder_calculator_proto"], ) cc_library( name = "flow_to_image_calculator", srcs = ["flow_to_image_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_to_image_calculator_cc_proto", "//mediapipe/calculators/video/tool:flow_quantizer_model", @@ -198,7 +183,6 @@ cc_library( cc_library( name = "opencv_video_decoder_calculator", srcs = ["opencv_video_decoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_format_cc_proto", @@ -217,7 +201,6 @@ cc_library( cc_library( name = "opencv_video_encoder_calculator", srcs = ["opencv_video_encoder_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":opencv_video_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -240,7 +223,6 @@ cc_library( cc_library( name = "tvl1_optical_flow_calculator", srcs = ["tvl1_optical_flow_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", @@ -256,7 +238,6 @@ cc_library( cc_library( name = "motion_analysis_calculator", srcs = ["motion_analysis_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":motion_analysis_calculator_cc_proto", 
"//mediapipe/framework:calculator_framework", @@ -282,7 +263,6 @@ cc_library( cc_library( name = "flow_packager_calculator", srcs = ["flow_packager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":flow_packager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -300,7 +280,6 @@ cc_library( cc_library( name = "box_tracker_calculator", srcs = ["box_tracker_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -327,7 +306,6 @@ cc_library( cc_library( name = "box_detector_calculator", srcs = ["box_detector_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":box_detector_calculator_cc_proto", "@com_google_absl//absl/memory", @@ -369,7 +347,6 @@ cc_library( cc_library( name = "tracked_detection_manager_calculator", srcs = ["tracked_detection_manager_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":tracked_detection_manager_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -390,7 +367,6 @@ cc_library( cc_library( name = "video_pre_stream_calculator", srcs = ["video_pre_stream_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":video_pre_stream_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -407,7 +383,6 @@ filegroup( "testdata/format_MKV_VP8_VORBIS.video", "testdata/format_MP4_AVC720P_AAC.video", ], - visibility = ["//visibility:public"], ) cc_test( @@ -480,7 +455,6 @@ mediapipe_binary_graph( name = "parallel_tracker_binarypb", graph = "testdata/parallel_tracker_graph.pbtxt", output_name = "testdata/parallel_tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", @@ -494,7 +468,6 @@ mediapipe_binary_graph( name = "tracker_binarypb", graph = "testdata/tracker_graph.pbtxt", output_name = "testdata/tracker.binarypb", - visibility = ["//visibility:public"], deps = [ ":box_tracker_calculator", ":flow_packager_calculator", diff --git a/mediapipe/examples/desktop/hello_world/BUILD b/mediapipe/examples/desktop/hello_world/BUILD index edf98bf13..27aa088e7 100644 --- a/mediapipe/examples/desktop/hello_world/BUILD +++ b/mediapipe/examples/desktop/hello_world/BUILD @@ -14,12 +14,11 @@ licenses(["notice"]) -package(default_visibility = ["//mediapipe/examples:__subpackages__"]) +package(default_visibility = ["//visibility:public"]) cc_binary( name = "hello_world", srcs = ["hello_world.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/framework:calculator_graph", diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index e3429f1e9..3cc72b4f1 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -139,7 +139,7 @@ mediapipe_proto_library( name = "test_calculators_proto", testonly = 1, srcs = ["test_calculators.proto"], - visibility = ["//visibility:public"], + visibility = [":mediapipe_internal"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 4276ffc3a..fdb698c48 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -17,7 +17,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") package( - default_visibility = 
["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) @@ -26,7 +26,6 @@ licenses(["notice"]) mediapipe_proto_library( name = "detection_proto", srcs = ["detection.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats:location_data_proto"], ) @@ -45,7 +44,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "classification_proto", srcs = ["classification.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -64,46 +62,39 @@ mediapipe_register_type( mediapipe_proto_library( name = "image_format_proto", srcs = ["image_format.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "matrix_data_proto", srcs = ["matrix_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "location_data_proto", srcs = ["location_data.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "affine_transform_data_proto", srcs = ["affine_transform_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "time_series_header_proto", srcs = ["time_series_header.proto"], - visibility = ["//visibility:public"], ) mediapipe_proto_library( name = "image_file_properties_proto", srcs = ["image_file_properties.proto"], - visibility = ["//visibility:public"], ) cc_library( name = "deleting_file", srcs = ["deleting_file.cc"], hdrs = ["deleting_file.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", ], @@ -113,7 +104,6 @@ cc_library( name = "matrix", srcs = ["matrix.cc"], hdrs = ["matrix.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", "//mediapipe/framework/formats:matrix_data_cc_proto", @@ -129,9 +119,6 @@ cc_library( name = "affine_transform", srcs = ["affine_transform.cc"], hdrs = ["affine_transform.h"], - visibility = [ - "//visibility:public", - ], deps = [ ":affine_transform_data_cc_proto", "//mediapipe/framework:port", @@ -154,7 +141,6 @@ cc_library( name = "image_frame", srcs = ["image_frame.cc"], hdrs = ["image_frame.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/base", @@ -179,7 +165,6 @@ cc_library( name = "image_frame_opencv", srcs = ["image_frame_opencv.cc"], hdrs = ["image_frame_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "//mediapipe/framework/formats:image_format_cc_proto", @@ -206,7 +191,6 @@ cc_library( name = "location", srcs = ["location.cc"], hdrs = ["location.h"], - visibility = ["//visibility:public"], deps = [ "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats/annotation:locus_cc_proto", @@ -238,7 +222,6 @@ cc_library( name = "location_opencv", srcs = ["location_opencv.cc"], hdrs = ["location_opencv.h"], - visibility = ["//visibility:public"], deps = [ ":location", "//mediapipe/framework/formats/annotation:rasterization_cc_proto", @@ -261,7 +244,6 @@ cc_test( cc_library( name = "video_stream_header", hdrs = ["video_stream_header.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", ], @@ -270,7 +252,6 @@ cc_library( cc_library( name = "yuv_image", hdrs = ["yuv_image.h"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework/port:integral_types", "@libyuv", @@ -294,7 +275,6 @@ cc_test( mediapipe_proto_library( name = "rect_proto", srcs = ["rect.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -312,7 +292,6 @@ mediapipe_register_type( mediapipe_proto_library( name = "landmark_proto", srcs = ["landmark.proto"], - visibility = ["//visibility:public"], ) mediapipe_register_type( @@ -344,7 +323,6 @@ cc_library( ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", @@ -374,7 +352,6 @@ cc_library( name = "image_multi_pool", srcs = ["image_multi_pool.cc"], hdrs = ["image_multi_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_frame_pool", @@ -411,7 +388,6 @@ cc_library( hdrs = [ "image_opencv.h", ], - visibility = ["//visibility:public"], deps = [ ":image", "//mediapipe/framework/formats:image_format_cc_proto", @@ -425,7 +401,6 @@ cc_library( name = "image_frame_pool", srcs = ["image_frame_pool.cc"], hdrs = ["image_frame_pool.h"], - visibility = ["//visibility:public"], deps = [ ":image_frame", "@com_google_absl//absl/memory", @@ -476,7 +451,6 @@ cc_library( "-landroid", ], }), - visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/memory", "@com_google_absl//absl/synchronization", diff --git a/mediapipe/framework/formats/annotation/BUILD b/mediapipe/framework/formats/annotation/BUILD index 328001e85..9bcb7bccd 100644 --- a/mediapipe/framework/formats/annotation/BUILD +++ b/mediapipe/framework/formats/annotation/BUILD @@ -16,7 +16,7 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) licenses(["notice"]) @@ -24,12 +24,10 @@ mediapipe_proto_library( name = "locus_proto", srcs = ["locus.proto"], portable_deps = ["//mediapipe/framework/formats/annotation:rasterization_cc_proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) mediapipe_proto_library( name = "rasterization_proto", srcs = ["rasterization.proto"], - visibility = ["//visibility:public"], ) diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index 9819d262c..f1bbc0289 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -20,18 +20,16 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "optical_flow_field_data_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "optical_flow_field_data_cc_proto", srcs = ["optical_flow_field_data.proto"], - visibility = ["//visibility:public"], deps = [":optical_flow_field_data_proto"], ) @@ -39,9 +37,6 @@ cc_library( name = "optical_flow_field", srcs = ["optical_flow_field.cc"], hdrs = ["optical_flow_field.h"], - visibility = [ - "//visibility:public", - ], deps = [ "//mediapipe/framework:type_map", "//mediapipe/framework/deps:mathutil", diff --git a/mediapipe/framework/formats/object_detection/BUILD b/mediapipe/framework/formats/object_detection/BUILD index 39940acdc..35292e1cc 100644 --- 
a/mediapipe/framework/formats/object_detection/BUILD +++ b/mediapipe/framework/formats/object_detection/BUILD @@ -19,17 +19,15 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package(default_visibility = ["//visibility:public"]) proto_library( name = "anchor_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], ) mediapipe_cc_proto_library( name = "anchor_cc_proto", srcs = ["anchor.proto"], - visibility = ["//visibility:public"], deps = [":anchor_proto"], ) diff --git a/mediapipe/framework/stream_handler/BUILD b/mediapipe/framework/stream_handler/BUILD index 866a5120e..01ef6ee86 100644 --- a/mediapipe/framework/stream_handler/BUILD +++ b/mediapipe/framework/stream_handler/BUILD @@ -18,35 +18,31 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library" licenses(["notice"]) package( - default_visibility = ["//visibility:private"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], ) proto_library( name = "default_input_stream_handler_proto", srcs = ["default_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "fixed_size_input_stream_handler_proto", srcs = ["fixed_size_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "sync_set_input_stream_handler_proto", srcs = ["sync_set_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) proto_library( name = "timestamp_align_input_stream_handler_proto", srcs = ["timestamp_align_input_stream_handler.proto"], - visibility = ["//visibility:public"], deps = ["//mediapipe/framework:mediapipe_options_proto"], ) @@ -54,7 +50,6 @@ mediapipe_cc_proto_library( name = "default_input_stream_handler_cc_proto", srcs = ["default_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":default_input_stream_handler_proto"], ) @@ -62,7 +57,6 @@ mediapipe_cc_proto_library( name = "fixed_size_input_stream_handler_cc_proto", srcs = ["fixed_size_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":fixed_size_input_stream_handler_proto"], ) @@ -70,7 +64,6 @@ mediapipe_cc_proto_library( name = "sync_set_input_stream_handler_cc_proto", srcs = ["sync_set_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":sync_set_input_stream_handler_proto"], ) @@ -78,14 +71,12 @@ mediapipe_cc_proto_library( name = "timestamp_align_input_stream_handler_cc_proto", srcs = ["timestamp_align_input_stream_handler.proto"], cc_deps = ["//mediapipe/framework:mediapipe_options_cc_proto"], - visibility = ["//visibility:public"], deps = [":timestamp_align_input_stream_handler_proto"], ) cc_library( name = "barrier_input_stream_handler", srcs = ["barrier_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -96,7 +87,6 @@ cc_library( name = "default_input_stream_handler", srcs = ["default_input_stream_handler.cc"], hdrs = ["default_input_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ 
"//mediapipe/framework:input_stream_handler", "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto", @@ -108,7 +98,6 @@ cc_library( cc_library( name = "early_close_input_stream_handler", srcs = ["early_close_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "@com_google_absl//absl/strings", @@ -119,7 +108,6 @@ cc_library( cc_library( name = "fixed_size_input_stream_handler", srcs = ["fixed_size_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ ":default_input_stream_handler", "//mediapipe/framework:input_stream_handler", @@ -131,7 +119,6 @@ cc_library( cc_library( name = "immediate_input_stream_handler", srcs = ["immediate_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", ], @@ -142,7 +129,6 @@ cc_library( name = "in_order_output_stream_handler", srcs = ["in_order_output_stream_handler.cc"], hdrs = ["in_order_output_stream_handler.h"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -160,7 +146,6 @@ cc_library( cc_library( name = "mux_input_stream_handler", srcs = ["mux_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:input_stream_handler", "//mediapipe/framework/port:logging", @@ -173,7 +158,6 @@ cc_library( cc_library( name = "sync_set_input_stream_handler", srcs = ["sync_set_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection", "//mediapipe/framework:collection_item_id", @@ -192,7 +176,6 @@ cc_library( cc_library( name = "timestamp_align_input_stream_handler", srcs = ["timestamp_align_input_stream_handler.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:collection_item_id", "//mediapipe/framework:input_stream_handler", diff --git a/mediapipe/modules/holistic_landmark/calculators/BUILD b/mediapipe/modules/holistic_landmark/calculators/BUILD index c3c091924..bc00b697c 100644 --- a/mediapipe/modules/holistic_landmark/calculators/BUILD +++ b/mediapipe/modules/holistic_landmark/calculators/BUILD @@ -21,7 +21,6 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "hand_detections_from_pose_to_rects_calculator", srcs = ["hand_detections_from_pose_to_rects_calculator.cc"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", @@ -39,7 +38,6 @@ cc_library( mediapipe_proto_library( name = "roi_tracking_calculator_proto", srcs = ["roi_tracking_calculator.proto"], - visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -49,7 +47,6 @@ mediapipe_proto_library( cc_library( name = "roi_tracking_calculator", srcs = ["roi_tracking_calculator.cc"], - visibility = ["//visibility:public"], deps = [ ":roi_tracking_calculator_cc_proto", "//mediapipe/framework:calculator_framework", diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index 3f1ebb353..6bca24446 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -134,7 +134,6 @@ proto_library( mediapipe_cc_proto_library( name = "tone_models_cc_proto", srcs = ["tone_models.proto"], - visibility = ["//visibility:public"], deps = [":tone_models_proto"], ) @@ -142,7 
+141,6 @@ mediapipe_cc_proto_library( name = "tone_estimation_cc_proto", srcs = ["tone_estimation.proto"], cc_deps = [":tone_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tone_estimation_proto"], ) @@ -153,21 +151,18 @@ mediapipe_cc_proto_library( ":tone_estimation_cc_proto", ":tone_models_cc_proto", ], - visibility = ["//visibility:public"], deps = [":region_flow_computation_proto"], ) mediapipe_cc_proto_library( name = "motion_saliency_cc_proto", srcs = ["motion_saliency.proto"], - visibility = ["//visibility:public"], deps = [":motion_saliency_proto"], ) mediapipe_cc_proto_library( name = "motion_estimation_cc_proto", srcs = ["motion_estimation.proto"], - visibility = ["//visibility:public"], deps = [":motion_estimation_proto"], ) @@ -179,7 +174,6 @@ mediapipe_cc_proto_library( ":motion_saliency_cc_proto", ":region_flow_computation_cc_proto", ], - visibility = ["//visibility:public"], deps = [":motion_analysis_proto"], ) @@ -187,14 +181,12 @@ mediapipe_cc_proto_library( name = "region_flow_cc_proto", srcs = ["region_flow.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":region_flow_proto"], ) mediapipe_cc_proto_library( name = "motion_models_cc_proto", srcs = ["motion_models.proto"], - visibility = ["//visibility:public"], deps = [":motion_models_proto"], ) @@ -202,21 +194,18 @@ mediapipe_cc_proto_library( name = "camera_motion_cc_proto", srcs = ["camera_motion.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":camera_motion_proto"], ) mediapipe_cc_proto_library( name = "push_pull_filtering_cc_proto", srcs = ["push_pull_filtering.proto"], - visibility = ["//visibility:public"], deps = [":push_pull_filtering_proto"], ) mediapipe_cc_proto_library( name = "frame_selection_solution_evaluator_cc_proto", srcs = ["frame_selection_solution_evaluator.proto"], - visibility = ["//visibility:public"], deps = [":frame_selection_solution_evaluator_proto"], ) @@ -228,7 +217,6 @@ mediapipe_cc_proto_library( ":frame_selection_solution_evaluator_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":frame_selection_proto"], ) @@ -239,7 +227,6 @@ mediapipe_cc_proto_library( ":motion_models_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":flow_packager_proto"], ) @@ -247,7 +234,6 @@ mediapipe_cc_proto_library( name = "tracking_cc_proto", srcs = ["tracking.proto"], cc_deps = [":motion_models_cc_proto"], - visibility = ["//visibility:public"], deps = [":tracking_proto"], ) @@ -255,14 +241,12 @@ mediapipe_cc_proto_library( name = "box_tracker_cc_proto", srcs = ["box_tracker.proto"], cc_deps = [":tracking_cc_proto"], - visibility = ["//visibility:public"], deps = [":box_tracker_proto"], ) mediapipe_cc_proto_library( name = "tracked_detection_manager_config_cc_proto", srcs = ["tracked_detection_manager_config.proto"], - visibility = ["//visibility:public"], deps = [":tracked_detection_manager_config_proto"], ) @@ -273,7 +257,6 @@ mediapipe_cc_proto_library( ":box_tracker_cc_proto", ":region_flow_cc_proto", ], - visibility = ["//visibility:public"], deps = [":box_detector_proto"], ) From 09740130e874560957b154bbb51ae4c90dcd64ca Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Tue, 29 Nov 2022 11:32:44 -0800 Subject: [PATCH 118/137] Use naturalWidth and naturalHeight for image data PiperOrigin-RevId: 491694147 --- mediapipe/web/graph_runner/graph_runner.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 9a0f7148c..9a8101659 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -325,6 +325,10 @@ export class GraphRunner { if ((imageSource as HTMLVideoElement).videoWidth) { width = (imageSource as HTMLVideoElement).videoWidth; height = (imageSource as HTMLVideoElement).videoHeight; + } else if ((imageSource as HTMLImageElement).naturalWidth) { + // TODO: Ensure this works with SVG images + width = (imageSource as HTMLImageElement).naturalWidth; + height = (imageSource as HTMLImageElement).naturalHeight; } else { width = imageSource.width; height = imageSource.height; From 88173948eed970b3cc5c215ec3541fcc08b1723c Mon Sep 17 00:00:00 2001 From: Michael Hays Date: Tue, 29 Nov 2022 13:37:18 -0800 Subject: [PATCH 119/137] Internal change PiperOrigin-RevId: 491724816 --- mediapipe/web/graph_runner/graph_runner.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 9a8101659..a9bb979af 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -1085,8 +1085,8 @@ async function runScript(scriptUrl: string) { */ export async function createMediaPipeLib( constructorFcn: WasmMediaPipeConstructor, - wasmLoaderScript?: string, - assetLoaderScript?: string, + wasmLoaderScript?: string|null, + assetLoaderScript?: string|null, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null, fileLocator?: FileLocator): Promise { const scripts = []; From fcd2d2c5af18dc4ebf16116a4f472b4bdb5e52a0 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Tue, 29 Nov 2022 14:12:14 -0800 Subject: [PATCH 120/137] Internal change PiperOrigin-RevId: 491733850 --- mediapipe/gpu/BUILD | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9cc670fb6..7a8aa6557 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -176,6 +176,16 @@ cc_library( "-fobjc-arc", # enable reference-counting ], }), + linkopts = select({ + "//conditions:default": [], + "//mediapipe:ios": [ + "-framework OpenGLES", + ], + "//mediapipe:macos": [ + "-framework OpenGL", + "-framework AppKit", + ], + }), visibility = ["//visibility:public"], deps = [ ":attachments", @@ -204,8 +214,10 @@ cc_library( }) + select({ "//conditions:default": [ ], - "//mediapipe:ios": [], - "//mediapipe:macos": [], + "//mediapipe:ios": [ + ], + "//mediapipe:macos": [ + ], }), ) From 460aee7933f255c749bda69673174ec91a9be017 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Wed, 30 Nov 2022 20:40:00 -0800 Subject: [PATCH 121/137] Make mediapipe_tasks_aar's android_library depend on "//third_party:androidx_annotation". 
PiperOrigin-RevId: 492092487 --- .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 762184842..6ca67c096 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -289,6 +289,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:androidx_annotation", "//third_party:autovalue", "@maven//:com_google_guava_guava", ] + select({ From 29c7702984fd0309fbadf64347fdd7cb5604b52f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 05:50:46 -0800 Subject: [PATCH 122/137] Inline formerly nested 'ClassifierOptions' in Java classifier APIs. PiperOrigin-RevId: 492173060 --- .../com/google/mediapipe/tasks/audio/BUILD | 2 +- .../audioclassifier/AudioClassifier.java | 84 ++++++++++++++--- .../com/google/mediapipe/tasks/text/BUILD | 2 +- .../text/textclassifier/TextClassifier.java | 90 ++++++++++++++++--- .../com/google/mediapipe/tasks/vision/BUILD | 2 +- .../imageclassifier/ImageClassifier.java | 82 ++++++++++++++--- .../textclassifier/TextClassifierTest.java | 31 +++++++ .../imageclassifier/ImageClassifierTest.java | 81 +++++++++++------ 8 files changed, 305 insertions(+), 69 deletions(-) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 6771335ad..2afc75ec0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -66,10 +66,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java index 0f3374175..d78685fe3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi; import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import 
com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -266,7 +266,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /* * Sends audio data (a block in a continuous audio stream) to perform audio classification, and - * the results will be available via the {@link ResultListener} provided in the + * the results will be available via the {@link ResultListener} provided in the * {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with * the audio stream mode. * @@ -320,10 +320,42 @@ public final class AudioClassifier extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
<p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
<p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
<p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -340,9 +372,7 @@ public final class AudioClassifier extends BaseAudioTaskApi { /** * Validates and builds the {@link AudioClassifierOptions} instance. * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the audio classifier - * is in the audio stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final AudioClassifierOptions build() { AudioClassifierOptions options = autoBuild(); @@ -357,6 +387,13 @@ public final class AudioClassifier extends BaseAudioTaskApi { "The audio classifier is in the audio clips mode, a user-defined result listener" + " shouldn't be provided in AudioClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -365,7 +402,15 @@ public final class AudioClassifier extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -373,7 +418,9 @@ public final class AudioClassifier extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -385,12 +432,21 @@ public final class AudioClassifier extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder = AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 023a1f286..f9c8e7c76 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -49,10 +49,10 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java index 341d6bf91..0ea91a9f8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -24,7 +24,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; import com.google.mediapipe.tasks.core.TaskInfo; @@ -216,20 +216,79 @@ public final class TextClassifier implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); - public abstract TextClassifierOptions build(); + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
<p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
<p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
<p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); + + abstract TextClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link TextClassifierOptions} instance. + * + * @throws IllegalArgumentException if any of the set options are invalid. + */ + public final TextClassifierOptions build() { + TextClassifierOptions options = autoBuild(); + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } + return options; + } } abstract BaseOptions baseOptions(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); public static Builder builder() { - return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); + return new AutoValue_TextClassifier_TextClassifierOptions.Builder() + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -238,12 +297,21 @@ public final class TextClassifier implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index b7febb118..2d130ff05 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -98,10 +98,10 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + 
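  // Illustrative caller-side usage of the inlined options above; this sketch
  // is not part of the patch, and the model path and values are placeholders.
  // It assumes imports of com.google.mediapipe.tasks.core.BaseOptions and
  // java.util.Arrays. (Before this patch, the same tuning went through the
  // now-removed setClassifierOptions(ClassifierOptions) method.)
  //
  //   TextClassifierOptions options =
  //       TextClassifierOptions.builder()
  //           .setBaseOptions(
  //               BaseOptions.builder().setModelAssetPath("classifier.tflite").build())
  //           .setMaxResults(3)          // build() rejects values <= 0
  //           .setScoreThreshold(0.5f)   // overrides any threshold in the model metadata
  //           .setCategoryAllowlist(Arrays.asList("positive", "negative"))
  //           .build();
  //
  // Calling setCategoryDenylist(...) on the same builder would make build()
  // throw IllegalArgumentException, since the two category lists are mutually
  // exclusive. This patch adds the same five setters and the same build()
  // checks to AudioClassifierOptions and ImageClassifierOptions as well.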
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 5e278804b..8990f46fd 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.ClassificationResult; import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -376,10 +376,42 @@ public final class ImageClassifier extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link ClassifierOptions} controling classification behavior, such as - * score threshold, number of results, etc. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ - public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *
<p>
If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *
<p>
Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *
<p>
If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *
<p>
If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); /** * Sets the {@link ResultListener} to receive the classification results asynchronously when @@ -396,9 +428,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { /** * Validates and builds the {@link ImageClassifierOptions} instance. * * - * @throws IllegalArgumentException if the result listener and the running mode are not - * properly configured. The result listener should only be set when the image classifier - * is in the live stream mode. + * @throws IllegalArgumentException if any of the set options are invalid. */ public final ImageClassifierOptions build() { ImageClassifierOptions options = autoBuild(); @@ -413,6 +443,13 @@ public final class ImageClassifier extends BaseVisionTaskApi { "The image classifier is in the image or video mode, a user-defined result listener" + " shouldn't be provided in ImageClassifierOptions."); } + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0."); + } + if (!options.categoryAllowlist().isEmpty() && !options.categoryDenylist().isEmpty()) { + throw new IllegalArgumentException( + "Category allowlist and denylist are mutually exclusive."); + } return options; } } @@ -421,7 +458,15 @@ public final class ImageClassifier extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional classifierOptions(); + abstract Optional displayNamesLocale(); + + abstract Optional maxResults(); + + abstract Optional scoreThreshold(); + + abstract List categoryAllowlist(); + + abstract List categoryDenylist(); abstract Optional> resultListener(); @@ -429,7 +474,9 @@ public final class ImageClassifier extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageClassifier_ImageClassifierOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); } /** @@ -441,12 +488,21 @@ public final class ImageClassifier extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + ClassifierOptionsProto.ClassifierOptions.Builder classifierOptionsBuilder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(classifierOptionsBuilder::setDisplayNamesLocale); + maxResults().ifPresent(classifierOptionsBuilder::setMaxResults); + scoreThreshold().ifPresent(classifierOptionsBuilder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + classifierOptionsBuilder.addAllCategoryDenylist(categoryDenylist()); + } ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.Builder taskOptionsBuilder = ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (classifierOptions().isPresent()) { - taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + 
.setClassifierOptions(classifierOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageClassifierGraphOptionsProto.ImageClassifierGraphOptions.ext, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java index 5e03d2a4c..5ed413f6a 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -40,6 +40,37 @@ public class TextClassifierTest { private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(BERT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java index 69820ce2d..dac11bf02 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifierTest.java @@ -26,7 +26,6 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -55,6 +54,37 @@ public class ImageClassifierTest { @RunWith(AndroidJUnit4.class) public static final class General extends ImageClassifierTest { + @Test + public void options_failsWithNegativeMaxResults() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setMaxResults(-1) + .build()); + assertThat(exception).hasMessageThat().contains("If specified, maxResults must be > 0"); + } + + @Test + public void options_failsWithBothAllowlistAndDenylist() throws Exception 
{ + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + ImageClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) + .setCategoryAllowlist(Arrays.asList("foo")) + .setCategoryDenylist(Arrays.asList("bar")) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("Category allowlist and denylist are mutually exclusive"); + } + @Test public void create_failsWithMissingModel() throws Exception { String nonExistentFile = "/path/to/non/existent/file"; @@ -105,7 +135,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -125,7 +155,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(QUANTIZED_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -141,7 +171,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setScoreThreshold(0.02f).build()) + .setScoreThreshold(0.02f) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -160,10 +190,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) - .build()) + .setCategoryAllowlist(Arrays.asList("cheeseburger", "guacamole", "meat loaf")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -183,11 +210,8 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions( - ClassifierOptions.builder() - .setMaxResults(3) - .setCategoryDenylist(Arrays.asList("bagel")) - .build()) + .setMaxResults(3) + .setCategoryDenylist(Arrays.asList("bagel")) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -207,7 +231,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -228,7 +252,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = 
ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(3).build()) + .setMaxResults(3) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -251,7 +275,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -322,14 +346,14 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -353,7 +377,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyAsync( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -379,7 +403,7 @@ public class ImageClassifierTest { MediaPipeException.class, () -> imageClassifier.classifyForVideo( - getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -388,7 +412,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -405,13 +429,14 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.VIDEO) .build(); ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options); for (int i = 0; i < 3; i++) { - ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i); + ImageClassifierResult results = + imageClassifier.classifyForVideo(image, /* timestampMs= */ i); assertHasOneHead(results); assertCategoriesAre( results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", ""))); @@ -424,7 +449,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - 
.setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -436,11 +461,11 @@ public class ImageClassifierTest { .build(); try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageClassifier.classifyAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageClassifier.classifyAsync(image, /*timestampMs=*/ 0)); + () -> imageClassifier.classifyAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -453,7 +478,7 @@ public class ImageClassifierTest { ImageClassifierOptions options = ImageClassifierOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(FLOAT_MODEL_FILE).build()) - .setClassifierOptions(ClassifierOptions.builder().setMaxResults(1).build()) + .setMaxResults(1) .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageClassificationResult, inputImage) -> { @@ -466,7 +491,7 @@ public class ImageClassifierTest { try (ImageClassifier imageClassifier = ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageClassifier.classifyAsync(image, /*timestampMs=*/ i); + imageClassifier.classifyAsync(image, /* timestampMs= */ i); } } } From 01010fa24887e50f1bb851e9758847f6f340bea3 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 07:15:52 -0800 Subject: [PATCH 123/137] Internal change PiperOrigin-RevId: 492188196 --- .../com/google/mediapipe/tasks/audio/BUILD | 2 +- .../audio/audioembedder/AudioEmbedder.java | 40 ++++++++--- .../tasks/components/processors/BUILD | 13 ---- .../processors/EmbedderOptions.java | 68 ------------------ .../com/google/mediapipe/tasks/text/BUILD | 2 +- .../tasks/text/textembedder/TextEmbedder.java | 41 ++++++++--- .../com/google/mediapipe/tasks/vision/BUILD | 2 +- .../vision/imageembedder/ImageEmbedder.java | 40 ++++++++--- .../imageembedder/ImageEmbedderTest.java | 69 +++++++++---------- 9 files changed, 126 insertions(+), 151 deletions(-) delete mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD index 2afc75ec0..2d29ccf23 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -92,12 +92,12 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", 
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java index c0bc04a4e..4bc505d84 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioembedder/AudioEmbedder.java @@ -28,7 +28,7 @@ import com.google.mediapipe.tasks.audio.core.RunningMode; import com.google.mediapipe.tasks.components.containers.AudioData; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -309,10 +309,24 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link EmbedderOptions} controling embedding behavior, such as score - * threshold, number of results, etc. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *

False by default.
       */
-      public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
+      public abstract Builder setL2Normalize(boolean l2Normalize);
+
+      /**
+       * Sets whether the returned embedding should be quantized to bytes via scalar
+       * quantization. Embeddings are implicitly assumed to be unit-norm and therefore any
+       * dimension is guaranteed to have a value in [-1.0, 1.0]. Use
+       * {@link #setL2Normalize(boolean)} if this is not the case.
+       *
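+       * <p>Concretely, scalar quantization stores each dimension of the (nominally unit-norm)
+       * embedding as one signed byte instead of one float, trading a little precision for a
+       * roughly 4x smaller embedding.
+       *
+       * <p>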

False by default. + */ + public abstract Builder setQuantize(boolean quantize); /** * Sets the {@link ResultListener} to receive the embedding results asynchronously when the @@ -354,7 +368,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { abstract RunningMode runningMode(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); abstract Optional> resultListener(); @@ -362,7 +378,9 @@ public final class AudioEmbedder extends BaseAudioTaskApi { public static Builder builder() { return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder() - .setRunningMode(RunningMode.AUDIO_CLIPS); + .setRunningMode(RunningMode.AUDIO_CLIPS) + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -372,12 +390,14 @@ public final class AudioEmbedder extends BaseAudioTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder = AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD index e61e59390..1f99f1612 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -29,19 +29,6 @@ android_library( ], ) -android_library( - name = "embedderoptions", - srcs = ["EmbedderOptions.java"], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - deps = [ - "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", - "//third_party:autovalue", - "@maven//:com_google_guava_guava", - ], -) - # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java deleted file mode 100644 index 3cd197234..000000000 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/EmbedderOptions.java +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2022 The MediaPipe Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.mediapipe.tasks.components.processors; - -import com.google.auto.value.AutoValue; -import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; - -/** Embedder options shared across MediaPipe Java embedding tasks. */ -@AutoValue -public abstract class EmbedderOptions { - - /** Builder for {@link EmbedderOptions} */ - @AutoValue.Builder - public abstract static class Builder { - /** - * Sets whether L2 normalization should be performed on the returned embeddings. Use this option - * only if the model does not already contain a native L2_NORMALIZATION TF Lite Op. - * In most cases, this is already the case and L2 norm is thus achieved through TF Lite - * inference. - * - *

False by default. - */ - public abstract Builder setL2Normalize(boolean l2Normalize); - - /** - * Sets whether the returned embedding should be quantized to bytes via scalar quantization. - * Embeddings are implicitly assumed to be unit-norm and therefore any dimensions is guaranteed - * to have value in [-1.0, 1.0]. Use {@link #setL2Normalize(boolean)} if this is - * not the case. - * - *

False by default. - */ - public abstract Builder setQuantize(boolean quantize); - - public abstract EmbedderOptions build(); - } - - public abstract boolean l2Normalize(); - - public abstract boolean quantize(); - - public static Builder builder() { - return new AutoValue_EmbedderOptions.Builder().setL2Normalize(false).setQuantize(false); - } - - /** - * Converts an {@link EmbedderOptions} object to an {@link EmbedderOptionsProto.EmbedderOptions} - * protobuf message. - */ - public EmbedderOptionsProto.EmbedderOptions convertToProto() { - return EmbedderOptionsProto.EmbedderOptions.newBuilder() - .setL2Normalize(l2Normalize()) - .setQuantize(quantize()) - .build(); - } -} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index f9c8e7c76..5b10e9aab 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -74,11 +74,11 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java index 95fa1f087..9b464d0e8 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.java @@ -25,7 +25,7 @@ import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.OutputHandler; @@ -41,7 +41,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; /** * Performs embedding extraction on text. @@ -218,20 +217,38 @@ public final class TextEmbedder implements AutoCloseable { public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as - * L2-normalization and scalar quantization. + * Sets whether L2 normalization should be performed on the returned embeddings. 
Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *

False by default.
       */
-      public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
+      public abstract Builder setL2Normalize(boolean l2Normalize);
+
+      /**
+       * Sets whether the returned embedding should be quantized to bytes via scalar
+       * quantization. Embeddings are implicitly assumed to be unit-norm and therefore any
+       * dimension is guaranteed to have a value in [-1.0, 1.0]. Use
+       * {@link #setL2Normalize(boolean)} if this is not the case.
+       *
+       * <p>

False by default. + */ + public abstract Builder setQuantize(boolean quantize); public abstract TextEmbedderOptions build(); } abstract BaseOptions baseOptions(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); public static Builder builder() { - return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder(); + return new AutoValue_TextEmbedder_TextEmbedderOptions.Builder() + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link TextEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -240,12 +257,14 @@ public final class TextEmbedder implements AutoCloseable { BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.Builder taskOptionsBuilder = TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( TextEmbedderGraphOptionsProto.TextEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 2d130ff05..b61c174fe 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -190,11 +190,11 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite", + "//mediapipe/tasks/cc/components/processors/proto:embedder_options_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java index 0d8ecd5c3..af053d860 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedder.java @@ -28,7 +28,7 @@ import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Embedding; import com.google.mediapipe.tasks.components.containers.EmbeddingResult; import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto; -import 
com.google.mediapipe.tasks.components.processors.EmbedderOptions; +import com.google.mediapipe.tasks.components.processors.proto.EmbedderOptionsProto; import com.google.mediapipe.tasks.components.utils.CosineSimilarity; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; @@ -369,10 +369,24 @@ public final class ImageEmbedder extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode runningMode); /** - * Sets the optional {@link EmbedderOptions} controling embedder behavior, such as - * L2-normalization and scalar quantization. + * Sets whether L2 normalization should be performed on the returned embeddings. Use this + * option only if the model does not already contain a native L2_NORMALIZATION TF + * Lite Op. In most cases, this is already the case and L2 norm is thus achieved through TF + * Lite inference. + * + *

False by default.
       */
-      public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
+      public abstract Builder setL2Normalize(boolean l2Normalize);
+
+      /**
+       * Sets whether the returned embedding should be quantized to bytes via scalar
+       * quantization. Embeddings are implicitly assumed to be unit-norm and therefore any
+       * dimension is guaranteed to have a value in [-1.0, 1.0]. Use
+       * {@link #setL2Normalize(boolean)} if this is not the case.
+       *
+       * <p>

False by default. + */ + public abstract Builder setQuantize(boolean quantize); /** * Sets the {@link ResultListener} to receive the embedding results asynchronously when the @@ -414,7 +428,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi { abstract RunningMode runningMode(); - abstract Optional embedderOptions(); + abstract boolean l2Normalize(); + + abstract boolean quantize(); abstract Optional> resultListener(); @@ -422,7 +438,9 @@ public final class ImageEmbedder extends BaseVisionTaskApi { public static Builder builder() { return new AutoValue_ImageEmbedder_ImageEmbedderOptions.Builder() - .setRunningMode(RunningMode.IMAGE); + .setRunningMode(RunningMode.IMAGE) + .setL2Normalize(false) + .setQuantize(false); } /** Converts a {@link ImageEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */ @@ -432,12 +450,14 @@ public final class ImageEmbedder extends BaseVisionTaskApi { BaseOptionsProto.BaseOptions.newBuilder(); baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + EmbedderOptionsProto.EmbedderOptions.Builder embedderOptionsBuilder = + EmbedderOptionsProto.EmbedderOptions.newBuilder(); + embedderOptionsBuilder.setL2Normalize(l2Normalize()); + embedderOptionsBuilder.setQuantize(quantize()); ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.Builder taskOptionsBuilder = ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.newBuilder() - .setBaseOptions(baseOptionsBuilder); - if (embedderOptions().isPresent()) { - taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto()); - } + .setBaseOptions(baseOptionsBuilder) + .setEmbedderOptions(embedderOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( ImageEmbedderGraphOptionsProto.ImageEmbedderGraphOptions.ext, diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java index 56249ead9..8dec6f80b 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imageembedder/ImageEmbedderTest.java @@ -25,7 +25,6 @@ import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; -import com.google.mediapipe.tasks.components.processors.EmbedderOptions; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.TestUtils; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -92,8 +91,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. 
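      // Cosine similarity is dot(u, v) / (||u|| * ||v||): it is 1.0 for embeddings pointing in
      // the same direction and drops toward 0.0 as they diverge, so the burger image and its
      // crop are expected to score high here.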
double similarity = ImageEmbedder.cosineSimilarity( @@ -105,12 +104,8 @@ public class ImageEmbedderTest { @Test public void embed_succeedsWithL2Normalization() throws Exception { BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); - EmbedderOptions embedderOptions = EmbedderOptions.builder().setL2Normalize(true).build(); ImageEmbedderOptions options = - ImageEmbedderOptions.builder() - .setBaseOptions(baseOptions) - .setEmbedderOptions(embedderOptions) - .build(); + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setL2Normalize(true).build(); ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -118,8 +113,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -131,12 +126,8 @@ public class ImageEmbedderTest { @Test public void embed_succeedsWithQuantization() throws Exception { BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(MOBILENET_EMBEDDER).build(); - EmbedderOptions embedderOptions = EmbedderOptions.builder().setQuantize(true).build(); ImageEmbedderOptions options = - ImageEmbedderOptions.builder() - .setBaseOptions(baseOptions) - .setEmbedderOptions(embedderOptions) - .build(); + ImageEmbedderOptions.builder().setBaseOptions(baseOptions).setQuantize(true).build(); ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options); @@ -144,8 +135,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ true); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ true); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ true); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ true); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -168,8 +159,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(resultRoi, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRoi, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -190,8 +181,8 @@ public class ImageEmbedderTest { imageEmbedder.embed(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultRotated, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultRotated, /* quantized= */ false); // Check similarity. 
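      // The rotation passed via ImageProcessingOptions is undone during preprocessing, so the
      // embedder effectively sees an upright frame here.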
double similarity = ImageEmbedder.cosineSimilarity( @@ -214,8 +205,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(resultRoiRotated, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(resultRoiRotated, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. double similarity = ImageEmbedder.cosineSimilarity( @@ -277,12 +268,14 @@ public class ImageEmbedderTest { assertThrows( MediaPipeException.class, () -> - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -303,7 +296,8 @@ public class ImageEmbedderTest { exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + () -> + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); } @@ -327,7 +321,8 @@ public class ImageEmbedderTest { assertThrows( MediaPipeException.class, () -> - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 0)); + imageEmbedder.embedForVideo( + getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 0)); assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); } @@ -340,8 +335,8 @@ public class ImageEmbedderTest { ImageEmbedderResult resultCrop = imageEmbedder.embed(getImageFromAsset(BURGER_CROP_IMAGE)); // Check results. - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); - assertHasOneHeadAndCorrectDimension(resultCrop, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); + assertHasOneHeadAndCorrectDimension(resultCrop, /* quantized= */ false); // Check similarity. 
double similarity = ImageEmbedder.cosineSimilarity( @@ -363,8 +358,8 @@ public class ImageEmbedderTest { for (int i = 0; i < 3; ++i) { ImageEmbedderResult result = - imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ i); - assertHasOneHeadAndCorrectDimension(result, /*quantized=*/ false); + imageEmbedder.embedForVideo(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ i); + assertHasOneHeadAndCorrectDimension(result, /* quantized= */ false); } } @@ -378,17 +373,18 @@ public class ImageEmbedderTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageEmbedderResult, inputImage) -> { - assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); assertImageSizeIsExpected(inputImage); }) .build(); try (ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { - imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /*timestampMs=*/ 1); + imageEmbedder.embedAsync(getImageFromAsset(BURGER_IMAGE), /* timestampMs= */ 1); MediaPipeException exception = assertThrows( MediaPipeException.class, - () -> imageEmbedder.embedAsync(image, /*timestampMs=*/ 0)); + () -> imageEmbedder.embedAsync(image, /* timestampMs= */ 0)); assertThat(exception) .hasMessageThat() .contains("having a smaller timestamp than the processed timestamp"); @@ -405,14 +401,15 @@ public class ImageEmbedderTest { .setRunningMode(RunningMode.LIVE_STREAM) .setResultListener( (imageEmbedderResult, inputImage) -> { - assertHasOneHeadAndCorrectDimension(imageEmbedderResult, /*quantized=*/ false); + assertHasOneHeadAndCorrectDimension( + imageEmbedderResult, /* quantized= */ false); assertImageSizeIsExpected(inputImage); }) .build(); try (ImageEmbedder imageEmbedder = ImageEmbedder.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { for (int i = 0; i < 3; ++i) { - imageEmbedder.embedAsync(image, /*timestampMs=*/ i); + imageEmbedder.embedAsync(image, /* timestampMs= */ i); } } } From a430939fe4b333ddb31a254f6a08b072f7dfff57 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 1 Dec 2022 07:42:55 -0800 Subject: [PATCH 124/137] Document RunningMode PiperOrigin-RevId: 492193299 --- .../vision/gesture_recognizer/gesture_recognizer.ts | 8 ++++++-- .../web/vision/hand_landmarker/hand_landmarker.ts | 8 ++++++-- .../web/vision/image_classifier/image_classifier.ts | 6 ++++-- .../tasks/web/vision/image_embedder/image_embedder.ts | 11 ++++------- .../web/vision/object_detector/object_detector.ts | 8 ++++++-- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 7441911c1..9ec63b07a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -225,7 +225,9 @@ export class GestureRecognizer extends /** * Performs gesture recognition on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `image`. + * * @param image A single image to process. * @return The detected gestures. 
*/ @@ -235,7 +237,9 @@ export class GestureRecognizer extends /** * Performs gesture recognition on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * GestureRecognizer is created with running mode `video`. + * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The detected gestures. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 6d69d568c..290f49455 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -177,7 +177,9 @@ export class HandLandmarker extends VisionTaskRunner { /** * Performs hand landmarks detection on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `image`. + * * @param image An image to process. * @return The detected hand landmarks. */ @@ -187,7 +189,9 @@ export class HandLandmarker extends VisionTaskRunner { /** * Performs hand landmarks detection on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * HandLandmarker is created with running mode `video`. + * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @return The detected hand landmarks. diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 604795f9f..185ddf9ea 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -120,7 +120,8 @@ export class ImageClassifier extends VisionTaskRunner { /** * Performs image classification on the provided single image and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `image`. * * @param image An image to process. * @return The classification result of the image @@ -131,7 +132,8 @@ export class ImageClassifier extends VisionTaskRunner { /** * Performs image classification on the provided video frame and waits - * synchronously for the response. + * synchronously for the response. Only use this method when the + * ImageClassifier is created with running mode `video`. * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 68068db6d..91352e934 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -122,10 +122,8 @@ export class ImageEmbedder extends VisionTaskRunner { /** * Performs embedding extraction on the provided single image and waits - * synchronously for the response. - * - * Only use this method when the `useStreamMode` option is not set or - * expliclity set to `false`. + * synchronously for the response. Only use this method when the + * ImageEmbedder is created with running mode `image`. * * @param image The image to process. 
* @return The embedding result of the image
   */
  embed(image: ImageSource): ImageEmbedderResult {
    return this.processImageData(image);
  }

  /**
   * Performs embedding extraction on the provided video frame and waits
-   * synchronously for the response.
-   *
-   * Only use this method when the `useStreamMode` option is set to `true`.
+   * synchronously for the response. Only use this method when the
+   * ImageEmbedder is created with running mode `video`.
   *
   * @param imageFrame The image frame to process.
   * @param timestamp The timestamp of the current frame, in ms.
diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
index 0f039acb2..7711c39e9 100644
--- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts
+++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts
@@ -151,7 +151,9 @@ export class ObjectDetector extends VisionTaskRunner {
   /**
    * Performs object detection on the provided single image and waits
-   * synchronously for the response.
+   * synchronously for the response. Only use this method when the
+   * ObjectDetector is created with running mode `image`.
+   *
    * @param image An image to process.
    * @return The list of detected objects
    */
@@ -161,7 +163,9 @@ export class ObjectDetector extends VisionTaskRunner {
   /**
    * Performs object detection on the provided video frame and waits
-   * synchronously for the response.
+   * synchronously for the response. Only use this method when the
+   * ObjectDetector is created with running mode `video`.
+   *
   * @param videoFrame A video frame to process.
   * @param timestamp The timestamp of the current frame, in ms.
   * @return The list of detected objects

From e7eee27c1c78649e126d197ec338b779ff72d356 Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Thu, 1 Dec 2022 08:14:53 -0800
Subject: [PATCH 125/137] Remove the deleted library
 "mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions"
 from mediapipe_tasks_aar's android_library deps list.

PiperOrigin-RevId: 492200061
---
 .../java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
index 6ca67c096..d91c03cc2 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl
@@ -286,7 +286,6 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
-        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
         "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
         "//third_party:androidx_annotation",

From 3ee37800e2d63092d8f8ded69619380eb55ad9ea Mon Sep 17 00:00:00 2001
From: Jiuqiang Tang
Date: Thu, 1 Dec 2022 08:41:33 -0800
Subject: [PATCH 126/137] Depend on "inference_calculator_cpu" when the
 MediaPipe tasks only support CPU inference.
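
With the shared EmbedderOptions class removed in the patches above, each Java embedder now
takes these settings directly on its own options builder. A minimal usage sketch (not part
of this patch), assuming a placeholder model file "embedder.tflite" and an Android Context
named "context":

    TextEmbedderOptions options =
        TextEmbedderOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("embedder.tflite").build())
            .setL2Normalize(true) // only if the model lacks a built-in L2_NORMALIZATION op
            .setQuantize(false) // both flags default to false
            .build();
    TextEmbedder textEmbedder = TextEmbedder.createFromOptions(context, options);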
PiperOrigin-RevId: 492205954 --- mediapipe/tasks/cc/audio/audio_classifier/BUILD | 2 +- mediapipe/tasks/cc/audio/audio_embedder/BUILD | 2 +- mediapipe/tasks/cc/text/text_classifier/BUILD | 2 +- mediapipe/tasks/cc/text/text_embedder/BUILD | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index a817bcc3b..f61472413 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -55,7 +55,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/audio/audio_embedder/BUILD b/mediapipe/tasks/cc/audio/audio_embedder/BUILD index adba28e6a..6a0f627b2 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/cc/audio/audio_embedder/BUILD @@ -56,7 +56,7 @@ cc_library( "//mediapipe/calculators/core:side_packet_to_stream_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator", "//mediapipe/calculators/tensor:audio_to_tensor_calculator_cc_proto", - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 61395cf4e..3c9c3fc0e 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -47,7 +47,7 @@ cc_library( name = "text_classifier_graph", srcs = ["text_classifier_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index f19af35be..4c970159e 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -48,8 +48,8 @@ cc_library( name = "text_embedder_graph", srcs = ["text_embedder_graph.cc"], deps = [ - "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator_cpu", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:builder", From e685ac93446e22d31a6bc269416ff13dece6edbe Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Thu, 1 Dec 2022 08:45:47 -0800 Subject: [PATCH 127/137] Re-use classifier options for ObjectDetector PiperOrigin-RevId: 492206856 --- .../web/components/utils/cosine_similarity.ts | 1 + .../tasks/web/vision/object_detector/BUILD | 1 + .../object_detector_options.d.ts | 33 ++----------------- 3 files changed, 5 insertions(+), 30 deletions(-) diff --git a/mediapipe/tasks/web/components/utils/cosine_similarity.ts 
b/mediapipe/tasks/web/components/utils/cosine_similarity.ts index fb1d0c185..1f483b9b6 100644 --- a/mediapipe/tasks/web/components/utils/cosine_similarity.ts +++ b/mediapipe/tasks/web/components/utils/cosine_similarity.ts @@ -36,6 +36,7 @@ export function computeCosineSimilarity(u: Embedding, v: Embedding): number { throw new Error( 'Cannot compute cosine similarity between quantized and float embeddings.'); } + function convertToBytes(data: Uint8Array): number[] { return Array.from(data, v => v - 128); } diff --git a/mediapipe/tasks/web/vision/object_detector/BUILD b/mediapipe/tasks/web/vision/object_detector/BUILD index b6bef6bfa..198585258 100644 --- a/mediapipe/tasks/web/vision/object_detector/BUILD +++ b/mediapipe/tasks/web/vision/object_detector/BUILD @@ -35,6 +35,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:vision_task_options", ], ) diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts index 1d20ce1e2..7564e7760 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector_options.d.ts @@ -14,36 +14,9 @@ * limitations under the License. */ +import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; /** Options to configure the MediaPipe Object Detector Task */ -export interface ObjectDetectorOptions extends VisionTaskOptions { - /** - * The locale to use for display names specified through the TFLite Model - * Metadata, if any. Defaults to English. - */ - displayNamesLocale?: string|undefined; - - /** The maximum number of top-scored detection results to return. */ - maxResults?: number|undefined; - - /** - * Overrides the value provided in the model metadata. Results below this - * value are rejected. - */ - scoreThreshold?: number|undefined; - - /** - * Allowlist of category names. If non-empty, detection results whose category - * name is not in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryDenylist`. - */ - categoryAllowlist?: string[]|undefined; - - /** - * Denylist of category names. If non-empty, detection results whose category - * name is in this set will be filtered out. Duplicate or unknown category - * names are ignored. Mutually exclusive with `categoryAllowlist`. 
- */ - categoryDenylist?: string[]|undefined; -} +export interface ObjectDetectorOptions extends VisionTaskOptions, + ClassifierOptions {} From 02aa162c9e953b05153f68d13e55a06b34571a0f Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 11:09:26 -0800 Subject: [PATCH 128/137] Rename gesture_recognizer test_data to testdata to be consistent with rest of model_maker PiperOrigin-RevId: 492246728 --- .../python/vision/gesture_recognizer/BUILD | 12 ++++++------ .../gesture_recognizer/gesture_recognizer_demo.py | 2 +- .../gesture_recognizer/gesture_recognizer_test.py | 2 +- .../gesture_recognizer/metadata_writer_test.py | 2 +- .../metadata/custom_gesture_classifier.tflite | Bin .../metadata/custom_gesture_classifier_meta.json | 0 .../call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg | Bin .../call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg | Bin .../call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg | Bin .../call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg | Bin .../call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg | Bin .../call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg | Bin .../call/17d804b5-7118-462d-8191-58d764f591b8.jpg | Bin .../call/1d65a858-623a-4984-9420-958c7e870c3e.jpg | Bin .../call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg | Bin .../call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg | Bin .../four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg | Bin .../four/077fa4bf-a99e-496b-b895-709afc614eec.jpg | Bin .../four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg | Bin .../four/07fdea90-1102-4419-a3af-b394cb29531b.jpg | Bin .../four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg | Bin .../four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg | Bin .../four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg | Bin .../four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg | Bin .../four/249c5023-6106-447a-84ac-17eb4713731b.jpg | Bin .../four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg | Bin .../none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg | Bin .../none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg | Bin .../none/00c84257-800d-4032-9e64-e47eb97005f5.jpg | Bin .../none/0a038096-c14f-46ac-9155-980161ebc440.jpg | Bin .../none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg | Bin .../none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg | Bin .../none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg | Bin .../none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg | Bin .../none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg | Bin .../none/0a787971-9377-4888-803f-aef21863ef7d.jpg | Bin .../rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg | Bin .../rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg | Bin .../rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg | Bin .../rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg | Bin .../rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg | Bin .../rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg | Bin .../rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg | Bin .../rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg | Bin .../rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg | Bin .../rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg | Bin 46 files changed, 9 insertions(+), 9 deletions(-) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/metadata/custom_gesture_classifier.tflite (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/metadata/custom_gesture_classifier_meta.json (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => 
testdata}/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => 
testdata}/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg (100%) rename mediapipe/model_maker/python/vision/gesture_recognizer/{test_data => testdata}/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg (100%) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD index b9425a181..256447a8d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/BUILD @@ -24,9 +24,9 @@ package( # TODO: Remove the unnecessary test data once the demo data are moved to an open-sourced # directory.
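# (The filegroup below is consumed as test data by the py_test and py_binary targets later in this file.)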
filegroup( - name = "test_data", + name = "testdata", srcs = glob([ - "test_data/**", + "testdata/**", ]), ) @@ -53,7 +53,7 @@ py_test( name = "dataset_test", srcs = ["dataset_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], deps = [ @@ -136,7 +136,7 @@ py_test( size = "large", srcs = ["gesture_recognizer_test.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], shard_count = 2, @@ -151,7 +151,7 @@ py_test( name = "metadata_writer_test", srcs = ["metadata_writer_test.py"], data = [ - ":test_data", + ":testdata", ], deps = [ ":metadata_writer", @@ -164,7 +164,7 @@ py_binary( name = "gesture_recognizer_demo", srcs = ["gesture_recognizer_demo.py"], data = [ - ":test_data", + ":testdata", "//mediapipe/model_maker/models/gesture_recognizer:models", ], python_version = "PY3", diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py index 06075fbc6..1cf9f0619 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_demo.py @@ -31,7 +31,7 @@ FLAGS = flags.FLAGS # TODO: Move hand gesture recognizer demo dataset to an # open-sourced directory. -TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data' +TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data' def define_flags(): diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index 9cee88362..280fc6a82 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -25,7 +25,7 @@ from mediapipe.model_maker.python.core.utils import test_util from mediapipe.model_maker.python.vision import gesture_recognizer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/test_data' +_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata' tf.keras.backend.experimental.enable_tf_random_generator() diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py index e1101e066..83998141d 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/metadata_writer_test.py @@ -23,7 +23,7 @@ from mediapipe.model_maker.python.vision.gesture_recognizer import metadata_writ from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer as base_metadata_writer from mediapipe.tasks.python.test import test_utils -_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata" +_TEST_DATA_DIR = "mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata" _EXPECTED_JSON = test_utils.get_test_data_path( os.path.join(_TEST_DATA_DIR, "custom_gesture_classifier_meta.json")) diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite 
b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier.tflite rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier.tflite diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/metadata/custom_gesture_classifier_meta.json rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/metadata/custom_gesture_classifier_meta.json diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0413d5c5-f5ba-476f-a921-ea5e967692a9.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/044768ad-1709-44ba-b041-c2f8cbe4c166.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/0e022ee9-74fd-44fe-adad-60c11835e44f.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/143f8b21-1dc3-4383-bf36-0a54244dfbc0.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/172ba7f6-c6ba-4398-89a2-25375dccfefa.jpg diff 
--git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17b3aa02-dc4d-448d-8601-e2b67193d436.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/17d804b5-7118-462d-8191-58d764f591b8.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1d65a858-623a-4984-9420-958c7e870c3e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/1f5fb137-c7a9-435b-85dd-6d7b63ea233a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/call/21de0cfe-af9f-42c2-95d4-aa3d852e7dad.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/06aa70cc-a12a-4b1e-85cf-e54d44c19a3a.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/077fa4bf-a99e-496b-b895-709afc614eec.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07a5a144-c635-4441-aedb-5c8e9da79aac.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/07fdea90-1102-4419-a3af-b394cb29531b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/0c960166-75b0-4c1b-a3cc-2ddbd5a21703.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/105f8f8e-ccd6-45a0-b22a-e314930bc13e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/116292ef-5947-4d6c-a479-630ebb8a1050.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/15a73593-b13e-4a1b-99bb-51775cfdfc42.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/249c5023-6106-447a-84ac-17eb4713731b.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/four/25bb4c45-e40b-482c-b588-04db60b7e450.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00af1db1-7c86-4e9b-9383-1fbd06c3492d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00b85ea4-8c5d-4302-b847-0a5de1d7dab2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/00c84257-800d-4032-9e64-e47eb97005f5.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a038096-c14f-46ac-9155-980161ebc440.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a0ef3d2-2560-4a93-904d-437189fffbf2.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a272153-56c7-42d5-a17d-cd307a1cd6d4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4a8907-1950-4e43-9a03-1740e78224ef.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a4bc2da-f5b3-48cd-8f0d-c61dbd08ba53.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a71a6e8-bb06-4ed0-a60b-c2a602fce261.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/none/0a787971-9377-4888-803f-aef21863ef7d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/026fd791-8f64-4fae-8cb0-0e01dc4362ce.jpg diff --git 
a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/055f8be9-f7fd-4c7f-ad3f-7b404b6489c3.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/09a619ab-cdf7-4a66-911f-347113f050f1.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0c6628ea-4a8c-49c9-b7cf-c30aef18dc3d.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/0cc7ad09-ae5f-45a8-b264-4216176369b6.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/10eacf4b-8aaf-46d9-be21-7fb8d8353005.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/15cb4e8b-ba1d-46f1-8456-247016a599a4.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg similarity index 100% rename from 
mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/18e20af8-8fe1-48d4-bd0e-83fa9e2db88e.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/1bed937b-7ae4-4070-891c-daf69415da41.jpg diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg b/mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg similarity index 100% rename from mediapipe/model_maker/python/vision/gesture_recognizer/test_data/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg rename to mediapipe/model_maker/python/vision/gesture_recognizer/testdata/raw_data/rock/20e2164d-3473-4d42-8755-22cdbd4417ba.jpg From 1e2cb2b35968100e6ec6cd974c2ec01e7bf6be9e Mon Sep 17 00:00:00 2001 From: Hadon Nash Date: Thu, 1 Dec 2022 11:33:15 -0800 Subject: [PATCH 129/137] Internal change PiperOrigin-RevId: 492253867 --- mediapipe/framework/input_stream_handler.cc | 4 +- .../immediate_input_stream_handler_test.cc | 37 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index d1dffa414..a7bd9ef43 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc @@ -354,7 +354,9 @@ NodeReadiness SyncSet::GetReadiness(Timestamp* min_stream_timestamp) { } } *min_stream_timestamp = std::min(min_packet, min_bound); - if (*min_stream_timestamp == Timestamp::Done()) { + if (*min_stream_timestamp >= Timestamp::OneOverPostStream()) { + // Either OneOverPostStream or Done indicates no more packets. + *min_stream_timestamp = Timestamp::Done(); last_processed_ts_ = Timestamp::Done().PreviousAllowedInStream(); return NodeReadiness::kReadyForClose; } diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc index e721afb02..e5de7f0c9 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc @@ -230,6 +230,43 @@ TEST_F(ImmediateInputStreamHandlerTest, StreamDoneReady) { input_stream_handler_->ClearCurrentInputs(cc_); } +// This test checks that the state is ReadyForClose after all streams reach +// Timestamp::Max. +TEST_F(ImmediateInputStreamHandlerTest, ReadyForCloseAfterTimestampMax) { + Timestamp min_stream_timestamp; + std::list<Packet> packets; + + // One packet arrives, ready for process.
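+ // (Each ScheduleInvocations() call below schedules at most one invocation, + // per /*max_allowance=*/1, and reports the selected input timestamp through + // cc_->InputTimestamp().)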
+ packets.push_back(Adopt(new std::string("packet 1")).At(Timestamp(10))); + input_stream_handler_->AddPackets(name_to_id_["input_a"], packets); + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp(10), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); + + // No packets arrive, not ready. + EXPECT_FALSE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Unset(), cc_->InputTimestamp()); + + // Timestamp::Max arrives, ready for close. + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_a"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_b"], Timestamp::Max().NextAllowedInStream()); + input_stream_handler_->SetNextTimestampBound( + name_to_id_["input_c"], Timestamp::Max().NextAllowedInStream()); + + EXPECT_TRUE(input_stream_handler_->ScheduleInvocations( + /*max_allowance=*/1, &min_stream_timestamp)); + EXPECT_EQ(Timestamp::Done(), cc_->InputTimestamp()); + input_stream_handler_->FinalizeInputSet(cc_->InputTimestamp(), + &cc_->Inputs()); + input_stream_handler_->ClearCurrentInputs(cc_); +} + // This test checks that when any stream is done, the state is ready to close. TEST_F(ImmediateInputStreamHandlerTest, ReadyForClose) { Timestamp min_stream_timestamp; From 40eb0e63858bd6c8746f4d5127a76ebef1f71cf7 Mon Sep 17 00:00:00 2001 From: Camillo Lugaresi Date: Thu, 1 Dec 2022 12:57:07 -0800 Subject: [PATCH 130/137] Internal change PiperOrigin-RevId: 492276913 --- mediapipe/gpu/multi_pool.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mediapipe/gpu/multi_pool.h b/mediapipe/gpu/multi_pool.h index 8a3cf6be0..e677c3bbf 100644 --- a/mediapipe/gpu/multi_pool.h +++ b/mediapipe/gpu/multi_pool.h @@ -59,6 +59,8 @@ class MultiPool { MultiPool(SimplePoolFactory factory = DefaultMakeSimplePool, MultiPoolOptions options = kDefaultMultiPoolOptions) : create_simple_pool_(factory), options_(options) {} + explicit MultiPool(MultiPoolOptions options) + : MultiPool(DefaultMakeSimplePool, options) {} // Obtains an item. May either be reused or created anew. Item Get(const Spec& spec); From fd79f18aeb41d78966a91dbd38107534c3fb29e8 Mon Sep 17 00:00:00 2001 From: Khanh LeViet Date: Thu, 1 Dec 2022 14:13:01 -0800 Subject: [PATCH 131/137] Make BaseOptions pass the absolute path to the C++ layer.
PiperOrigin-RevId: 492296573 --- mediapipe/tasks/python/core/base_options.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mediapipe/tasks/python/core/base_options.py b/mediapipe/tasks/python/core/base_options.py index 122dc620f..b48fa2ccc 100644 --- a/mediapipe/tasks/python/core/base_options.py +++ b/mediapipe/tasks/python/core/base_options.py @@ -14,6 +14,7 @@ """Base options for MediaPipe Task APIs.""" import dataclasses +import os from typing import Any, Optional from mediapipe.tasks.cc.core.proto import base_options_pb2 @@ -49,10 +50,14 @@ class BaseOptions: @doc_controls.do_not_generate_docs def to_pb2(self) -> _BaseOptionsProto: """Generates a BaseOptions protobuf object.""" + if self.model_asset_path is not None: + full_path = os.path.abspath(self.model_asset_path) + else: + full_path = None + return _BaseOptionsProto( model_asset=_ExternalFileProto( - file_name=self.model_asset_path, - file_content=self.model_asset_buffer)) + file_name=full_path, file_content=self.model_asset_buffer)) @classmethod @doc_controls.do_not_generate_docs From af990c3da1633f164ccf8f75edb0683079b0c005 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 14:58:30 -0800 Subject: [PATCH 132/137] Open up the visibility of "//mediapipe/java/com/google/mediapipe/framework/image:image". PiperOrigin-RevId: 492308109 --- mediapipe/java/com/google/mediapipe/framework/image/BUILD | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index bb3be318d..d9508c1f7 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -20,9 +20,7 @@ android_library( name = "image", srcs = glob(["*.java"]), manifest = "AndroidManifest.xml", - visibility = [ - "//mediapipe:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ "//third_party:androidx_legacy_support_v4", "//third_party:autovalue", From ead41132a856379a9a7d22f29abe471dc11f2b4a Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 1 Dec 2022 15:00:00 -0800 Subject: [PATCH 133/137] Load the model file content from the model file path with the help of GetResourceContents in browsers. This handles model files that are provided via a custom ResourceProviderFn. PiperOrigin-RevId: 492308453 --- mediapipe/tasks/cc/core/model_resources.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index 618761f32..d5c12ee95 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -99,11 +99,21 @@ const tflite::Model* ModelResources::GetTfLiteModel() const { absl::Status ModelResources::BuildModelFromExternalFileProto() { if (model_file_->has_file_name()) { +#ifdef __EMSCRIPTEN__ + // In browsers, the model file may require a custom ResourceProviderFn to + // provide the model content. The open() method may not work in this case. + // Thus, load the model content from the model file path in advance with + // the help of GetResourceContents. + MP_RETURN_IF_ERROR(mediapipe::GetResourceContents( + model_file_->file_name(), model_file_->mutable_file_content())); + model_file_->clear_file_name(); +#else // If the model file name is a relative path, searches the file in a // platform-specific location and returns the absolute path on success.
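// (The resolved absolute path is then used to open the model through the external file handler created below.)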
ASSIGN_OR_RETURN(std::string path_to_resource, mediapipe::PathToResourceAsFile(model_file_->file_name())); model_file_->set_file_name(path_to_resource); +#endif // __EMSCRIPTEN__ } ASSIGN_OR_RETURN( model_file_handler_, From 768d2dc548f123246d34fe258d6ab75d05c51d3e Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Thu, 1 Dec 2022 16:47:05 -0800 Subject: [PATCH 134/137] Separate the web and Java API landmark and world landmark into two classes. This makes the platform interfaces consistent. PiperOrigin-RevId: 492332990 --- .../tasks/components/containers/BUILD | 9 +++ .../tasks/components/containers/Landmark.java | 20 +++--- .../containers/NormalizedLandmark.java | 63 +++++++++++++++++++ .../com/google/mediapipe/tasks/vision/BUILD | 2 + .../GestureRecognizerResult.java | 45 ++++++------- .../handlandmarker/HandLandmarkerResult.java | 52 +++++++-------- .../GestureRecognizerTest.java | 4 +- .../handlandmarker/HandLandmarkerTest.java | 4 +- .../web/components/containers/landmark.d.ts | 28 ++++++--- .../gesture_recognizer/gesture_recognizer.ts | 12 ++-- .../gesture_recognizer_result.d.ts | 4 +- .../vision/hand_landmarker/hand_landmarker.ts | 10 ++- .../hand_landmarker_result.d.ts | 4 +- 13 files changed, 161 insertions(+), 96 deletions(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index d6e6ac740..ad17d5552 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -83,6 +83,15 @@ android_library( ], ) +android_library( + name = "normalized_landmark", + srcs = ["NormalizedLandmark.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + # Expose the java source files for building mediapipe tasks core AAR. filegroup( name = "java_src", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index e45866190..7fb1b99d0 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -18,16 +18,16 @@ import com.google.auto.value.AutoValue; import java.util.Objects; /** - * Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the - * landmark coordinates is normalized respect to the dimension of image, and the coordinates values - * are in the range of [0,1]. Otherwise, it represenet a point in world coordinates. + * Landmark represents a point in 3D space with x, y, z coordinates. The landmark coordinates are in + * meters. z represents the landmark depth, and the smaller the value the closer the world landmark + * is to the camera. */ @AutoValue public abstract class Landmark { private static final float TOLERANCE = 1e-6f; - public static Landmark create(float x, float y, float z, boolean normalized) { - return new AutoValue_Landmark(x, y, z, normalized); + public static Landmark create(float x, float y, float z) { + return new AutoValue_Landmark(x, y, z); } // The x coordinates of the landmark. @@ -39,28 +39,24 @@ public abstract class Landmark { // The z coordinates of the landmark.
public abstract float z(); - // Whether this landmark is normalized with respect to the image size. - public abstract boolean normalized(); @Override public final boolean equals(Object o) { if (!(o instanceof Landmark)) { return false; } Landmark other = (Landmark) o; - return other.normalized() == this.normalized() - && Math.abs(other.x() - this.x()) < TOLERANCE + return Math.abs(other.x() - this.x()) < TOLERANCE && Math.abs(other.y() - this.y()) < TOLERANCE && Math.abs(other.z() - this.z()) < TOLERANCE; } @Override public final int hashCode() { - return Objects.hash(x(), y(), z(), normalized()); + return Objects.hash(x(), y(), z()); } @Override public final String toString() { - return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + " normalized=" + normalized() + ")>"; + return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + ")>"; } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java new file mode 100644 index 000000000..e77f3c3d4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/NormalizedLandmark.java @@ -0,0 +1,63 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Objects; + +/** + * Normalized Landmark represents a point in 3D space with x, y, z coordinates. x and y are + * normalized to [0.0, 1.0] by the image width and height respectively. z represents the landmark + * depth, and the smaller the value the closer the landmark is to the camera. The magnitude of z + * uses roughly the same scale as x. + */ +@AutoValue +public abstract class NormalizedLandmark { + private static final float TOLERANCE = 1e-6f; + + public static NormalizedLandmark create(float x, float y, float z) { + return new AutoValue_NormalizedLandmark(x, y, z); + } + + // The x coordinates of the normalized landmark. + public abstract float x(); + + // The y coordinates of the normalized landmark. + public abstract float y(); + + // The z coordinates of the normalized landmark.
+ public abstract float z(); + + @Override + public final boolean equals(Object o) { + if (!(o instanceof NormalizedLandmark)) { + return false; + } + NormalizedLandmark other = (NormalizedLandmark) o; + return Math.abs(other.x() - this.x()) < TOLERANCE + && Math.abs(other.y() - this.y()) < TOLERANCE + && Math.abs(other.z() - this.z()) < TOLERANCE; + } + + @Override + public final int hashCode() { + return Objects.hash(x(), y(), z()); + } + + @Override + public final String toString() { + return "<Normalized Landmark (x=" + x() + " y=" + y() + " z=" + z() + ")>"; + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index b61c174fe..6161fe032 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -135,6 +135,7 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", @@ -167,6 +168,7 @@ android_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", "@maven//:androidx_annotation_annotation", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java index ef76bf226..90b92175d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerResult.java @@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.gesturerecognizer; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -43,41 +42,36 @@ public abstract class GestureRecognizerResult implements TaskResult { * @param gesturesProto a List of {@link ClassificationList} */ static GestureRecognizerResult
create( - List<NormalizedLandmarkList> landmarksProto, - List<LandmarkList> worldLandmarksProto, + List<LandmarkProto.NormalizedLandmarkList> landmarksProto, + List<LandmarkProto.LandmarkList> worldLandmarksProto, List<ClassificationList> handednessesProto, List<ClassificationList> gesturesProto, long timestampMs) { - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks = - new ArrayList<>(); - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks = - new ArrayList<>(); + List<List<NormalizedLandmark>> multiHandLandmarks = new ArrayList<>(); + List<List<Landmark>> multiHandWorldLandmarks = new ArrayList<>(); List<List<Category>> multiHandHandednesses = new ArrayList<>(); List<List<Category>> multiHandGestures = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List<NormalizedLandmark> handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + com.google.mediapipe.tasks.components.containers.NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) { @@ -118,11 +112,10 @@ public abstract class GestureRecognizerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks(); + public abstract List<List<NormalizedLandmark>> landmarks(); /** Hand landmarks in world coordinates of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> - worldLandmarks(); + public abstract List<List<Landmark>> worldLandmarks(); /** Handedness of detected hands.
*/ public abstract List<List<Category>> handednesses(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java index 2889b0e0b..9092c0a2d 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerResult.java @@ -15,13 +15,12 @@ package com.google.mediapipe.tasks.vision.handlandmarker; import com.google.auto.value.AutoValue; -import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; -import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; -import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto; import com.google.mediapipe.formats.proto.ClassificationProto.Classification; import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.core.TaskResult; import java.util.ArrayList; import java.util.Collections; @@ -32,47 +31,41 @@ import java.util.List; public abstract class HandLandmarkerResult implements TaskResult { /** - * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and - * handedness protobuf messages. + * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness + * protobuf messages.
* * @param landmarksProto a List of {@link NormalizedLandmarkList} * @param worldLandmarksProto a List of {@link LandmarkList} * @param handednessesProto a List of {@link ClassificationList} */ static HandLandmarkerResult create( - List<NormalizedLandmarkList> landmarksProto, - List<LandmarkList> worldLandmarksProto, + List<LandmarkProto.NormalizedLandmarkList> landmarksProto, + List<LandmarkProto.LandmarkList> worldLandmarksProto, List<ClassificationList> handednessesProto, long timestampMs) { - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks = - new ArrayList<>(); - List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks = - new ArrayList<>(); + List<List<NormalizedLandmark>> multiHandLandmarks = new ArrayList<>(); + List<List<Landmark>> multiHandWorldLandmarks = new ArrayList<>(); List<List<Category>> multiHandHandednesses = new ArrayList<>(); - for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { - List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks = - new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List<NormalizedLandmark> handLandmarks = new ArrayList<>(); multiHandLandmarks.add(handLandmarks); - for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + for (LandmarkProto.NormalizedLandmark handLandmarkProto : + handLandmarksProto.getLandmarkList()) { handLandmarks.add( - com.google.mediapipe.tasks.components.containers.Landmark.create( - handLandmarkProto.getX(), - handLandmarkProto.getY(), - handLandmarkProto.getZ(), - true)); + NormalizedLandmark.create( + handLandmarkProto.getX(), handLandmarkProto.getY(), handLandmarkProto.getZ())); } } - for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { - List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks = - new ArrayList<>(); + for (LandmarkProto.LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + List<Landmark> handWorldLandmarks = new ArrayList<>(); multiHandWorldLandmarks.add(handWorldLandmarks); - for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + for (LandmarkProto.Landmark handWorldLandmarkProto : + handWorldLandmarksProto.getLandmarkList()) { handWorldLandmarks.add( com.google.mediapipe.tasks.components.containers.Landmark.create( handWorldLandmarkProto.getX(), handWorldLandmarkProto.getY(), - handWorldLandmarkProto.getZ(), - false)); + handWorldLandmarkProto.getZ())); } } for (ClassificationList handednessProto : handednessesProto) { @@ -98,11 +91,10 @@ public abstract class HandLandmarkerResult implements TaskResult { public abstract long timestampMs(); /** Hand landmarks of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks(); + public abstract List<List<NormalizedLandmark>> landmarks(); /** Hand landmarks in world coordinates of detected hands. */ - public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> - worldLandmarks(); + public abstract List<List<Landmark>> worldLandmarks(); /** Handedness of detected hands.
*/ public abstract List<List<Category>> handednesses(); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java index c0be4cffe..5821b36cc 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizerTest.java @@ -28,7 +28,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.components.processors.ClassifierOptions; import com.google.mediapipe.tasks.core.BaseOptions; @@ -603,7 +603,7 @@ public class GestureRecognizerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate<Landmark, Landmark>) + (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x()) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java index 9e12d210f..c313d385d 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/handlandmarker/HandLandmarkerTest.java @@ -27,7 +27,7 @@ import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.MPImage; import com.google.mediapipe.tasks.components.containers.Category; -import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; @@ -399,7 +399,7 @@ public class HandLandmarkerTest { assertThat(actualResult.landmarks().get(0)) .comparingElementsUsing( Correspondence.from( - (Correspondence.BinaryPredicate<Landmark, Landmark>) + (Correspondence.BinaryPredicate<NormalizedLandmark, NormalizedLandmark>) (actual, expected) -> { return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) .compare(actual.x(), expected.x()) diff --git a/mediapipe/tasks/web/components/containers/landmark.d.ts b/mediapipe/tasks/web/components/containers/landmark.d.ts index c887303d0..0f916bf88 100644 --- a/mediapipe/tasks/web/components/containers/landmark.d.ts +++ b/mediapipe/tasks/web/components/containers/landmark.d.ts @@ -15,10 +15,27 @@ */ /** - * Landmark represents a point in 3D space with x, y, z coordinates. If - * normalized is true, the landmark coordinates is normalized respect to the - * dimension of image, and the coordinates values are in the range of [0,1]. - * Otherwise, it represenet a point in world coordinates.
+ * Normalized Landmark represents a point in 3D space with x, y, z coordinates. + * x and y are normalized to [0.0, 1.0] by the image width and height + * respectively. z represents the landmark depth, and the smaller the value the + * closer the landmark is to the camera. The magnitude of z uses roughly the + * same scale as x. + */ +export declare interface NormalizedLandmark { + /** The x coordinates of the normalized landmark. */ + x: number; + + /** The y coordinates of the normalized landmark. */ + y: number; + + /** The z coordinates of the normalized landmark. */ + z: number; +} + +/** + * Landmark represents a point in 3D space with x, y, z coordinates. The + * landmark coordinates are in meters. z represents the landmark depth, + * and the smaller the value the closer the world landmark is to the camera. */ export declare interface Landmark { /** The x coordinates of the landmark. */ @@ -29,7 +46,4 @@ export declare interface Landmark { /** The z coordinates of the landmark. */ z: number; - - /** Whether this landmark is normalized with respect to the image size. */ - normalized: boolean; } diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 9ec63b07a..15b6acb1a 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -27,7 +27,7 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detecto import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb'; import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; -import {Landmark} from '../../../../tasks/web/components/containers/landmark'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; @@ -67,7 +67,7 @@ FULL_IMAGE_RECT.setHeight(1); export class GestureRecognizer extends VisionTaskRunner { private gestures: Category[][] = []; - private landmarks: Landmark[][] = []; + private landmarks: NormalizedLandmark[][] = []; private worldLandmarks: Landmark[][] = []; private handednesses: Category[][] = []; @@ -306,13 +306,12 @@ export class GestureRecognizer extends for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: Landmark[] = []; + const landmarks: NormalizedLandmark[] = []; for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { landmarks.push({ x: handLandmarkProto.getX() ?? 0, y: handLandmarkProto.getY() ?? 0, - z: handLandmarkProto.getZ() ?? 0, - normalized: true + z: handLandmarkProto.getZ() ?? 0 }); } this.landmarks.push(landmarks); @@ -333,8 +332,7 @@ export class GestureRecognizer extends worldLandmarks.push({ x: handWorldLandmarkProto.getX() ?? 0, y: handWorldLandmarkProto.getY() ?? 0, - z: handWorldLandmarkProto.getZ() ?? 0, - normalized: false + z: handWorldLandmarkProto.getZ() ?? 
0
         });
       }
       this.worldLandmarks.push(worldLandmarks);
diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
index 7c295c9e9..e570270b2 100644
--- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
+++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_result.d.ts
@@ -15,14 +15,14 @@
  */
 
 import {Category} from '../../../../tasks/web/components/containers/category';
-import {Landmark} from '../../../../tasks/web/components/containers/landmark';
+import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 
 /**
  * Represents the gesture recognition results generated by `GestureRecognizer`.
  */
 export declare interface GestureRecognizerResult {
   /** Hand landmarks of detected hands. */
-  landmarks: Landmark[][];
+  landmarks: NormalizedLandmark[][];
 
   /** Hand landmarks in world coordinates of detected hands. */
   worldLandmarks: Landmark[][];
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
index 290f49455..c657275bf 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts
@@ -24,7 +24,7 @@ import {HandDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_detecto
 import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options_pb';
 import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb';
 import {Category} from '../../../../tasks/web/components/containers/category';
-import {Landmark} from '../../../../tasks/web/components/containers/landmark';
+import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
@@ -59,7 +59,7 @@ FULL_IMAGE_RECT.setHeight(1);
 
 /** Performs hand landmarks detection on images. */
 export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
-  private landmarks: Landmark[][] = [];
+  private landmarks: NormalizedLandmark[][] = [];
   private worldLandmarks: Landmark[][] = [];
   private handednesses: Category[][] = [];
 
@@ -255,13 +255,12 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
     for (const binaryProto of data) {
       const handLandmarksProto =
           NormalizedLandmarkList.deserializeBinary(binaryProto);
-      const landmarks: Landmark[] = [];
+      const landmarks: NormalizedLandmark[] = [];
       for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) {
         landmarks.push({
           x: handLandmarkProto.getX() ?? 0,
           y: handLandmarkProto.getY() ?? 0,
           z: handLandmarkProto.getZ() ?? 0,
-          normalized: true
         });
       }
       this.landmarks.push(landmarks);
@@ -269,7 +268,7 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
   }
 
   /**
-   * Converts raw data into a landmark, and adds it to our worldLandmarks
+   * Converts raw data into a world landmark, and adds it to our worldLandmarks
    * list.
    */
   private adddJsWorldLandmarks(data: Uint8Array[]): void {
@@ -283,7 +282,6 @@ export class HandLandmarker extends VisionTaskRunner<HandLandmarkerResult> {
             x: handWorldLandmarkProto.getX() ?? 0,
             y: handWorldLandmarkProto.getY() ?? 0,
             z: handWorldLandmarkProto.getZ() ?? 0,
-            normalized: false
           });
       }
       this.worldLandmarks.push(worldLandmarks);
diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
index 044bdfbe7..89f867d69 100644
--- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
+++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts
@@ -15,14 +15,14 @@
  */
 
 import {Category} from '../../../../tasks/web/components/containers/category';
-import {Landmark} from '../../../../tasks/web/components/containers/landmark';
+import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
 
 /**
  * Represents the hand landmarks detection results generated by
 * `HandLandmarker`.
  */
 export declare interface HandLandmarkerResult {
   /** Hand landmarks of detected hands. */
-  landmarks: Landmark[][];
+  landmarks: NormalizedLandmark[][];
 
   /** Hand landmarks in world coordinates of detected hands. */
   worldLandmarks: Landmark[][];

From dabc2af15baad67d92ac5e9d1b2b2a588167664f Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 2 Dec 2022 12:04:06 -0800
Subject: [PATCH 135/137] Fix base path loading in Fileset resolver

PiperOrigin-RevId: 492526041
---
 mediapipe/tasks/web/core/fileset_resolver.ts | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mediapipe/tasks/web/core/fileset_resolver.ts b/mediapipe/tasks/web/core/fileset_resolver.ts
index 7d68dbc16..d4691243b 100644
--- a/mediapipe/tasks/web/core/fileset_resolver.ts
+++ b/mediapipe/tasks/web/core/fileset_resolver.ts
@@ -48,16 +48,16 @@ async function createFileset(
   if (await isSimdSupported()) {
     return {
       wasmLoaderPath:
-          `/${basePath}/${taskName}_wasm_internal.js`,
+          `${basePath}/${taskName}_wasm_internal.js`,
       wasmBinaryPath:
-          `/${basePath}/${taskName}_wasm_internal.wasm`,
+          `${basePath}/${taskName}_wasm_internal.wasm`,
     };
   } else {
     return {
       wasmLoaderPath:
-          `/${basePath}/${taskName}_wasm_nosimd_internal.js`,
-      wasmBinaryPath: `/${basePath}/${
-          taskName}_wasm_nosimd_internal.wasm`,
+          `${basePath}/${taskName}_wasm_nosimd_internal.js`,
+      wasmBinaryPath:
+          `${basePath}/${taskName}_wasm_nosimd_internal.wasm`,
     };
   }
 }

From da9587033d118eb58672f25c8f2e541ba7037209 Mon Sep 17 00:00:00 2001
From: Sebastian Schmidt
Date: Fri, 2 Dec 2022 12:40:59 -0800
Subject: [PATCH 136/137] Move shared code to TaskRunner

PiperOrigin-RevId: 492534879
---
 .../tasks/web/audio/audio_classifier/BUILD    |  3 +-
 .../audio_classifier/audio_classifier.ts      | 38 ++++++++------
 .../audio_classifier_options.d.ts             |  4 +-
 .../tasks/web/audio/audio_embedder/BUILD      |  1 -
 .../audio/audio_embedder/audio_embedder.ts    | 48 ++++++++---------
 .../audio_embedder_options.d.ts               |  4 +-
 mediapipe/tasks/web/audio/core/BUILD          | 13 +----
 .../web/audio/core/audio_task_options.d.ts    | 23 ---------
 .../tasks/web/audio/core/audio_task_runner.ts | 17 +------
 .../tasks/web/components/processors/BUILD     |  1 -
 .../web/components/processors/base_options.ts |  2 +-
 mediapipe/tasks/web/core/BUILD                |  8 +--
 .../tasks/web/core/classifier_options.d.ts    |  2 -
 .../tasks/web/core/embedder_options.d.ts      |  2 -
 mediapipe/tasks/web/core/task_runner.ts       | 43 ++++++++++------
 ..._options.d.ts => task_runner_options.d.ts} |  8 ++-
 mediapipe/tasks/web/text/core/BUILD           | 11 ----
 .../web/text/core/text_task_options.d.ts      | 23 ---------
 .../tasks/web/text/text_classifier/BUILD      |  5 +-
 .../text/text_classifier/text_classifier.ts   | 51 +++++++++++--------
 .../text_classifier_options.d.ts              |  4 +-
mediapipe/tasks/web/text/text_embedder/BUILD | 4 +- .../web/text/text_embedder/text_embedder.ts | 51 +++++++++++-------- .../text_embedder/text_embedder_options.d.ts | 4 +- mediapipe/tasks/web/vision/core/BUILD | 2 - .../web/vision/core/vision_task_options.d.ts | 8 +-- .../web/vision/core/vision_task_runner.ts | 15 ++---- .../gesture_recognizer/gesture_recognizer.ts | 30 +++++------ .../vision/hand_landmarker/hand_landmarker.ts | 30 +++++------ .../image_classifier/image_classifier.ts | 38 ++++++++------ .../vision/image_embedder/image_embedder.ts | 38 ++++++++------ .../vision/object_detector/object_detector.ts | 36 +++++++------ 32 files changed, 262 insertions(+), 305 deletions(-) delete mode 100644 mediapipe/tasks/web/audio/core/audio_task_options.d.ts rename mediapipe/tasks/web/core/{base_options.d.ts => task_runner_options.d.ts} (85%) delete mode 100644 mediapipe/tasks/web/text/core/BUILD delete mode 100644 mediapipe/tasks/web/text/core/text_task_options.d.ts diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index c419d3b98..6f785dd0d 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -25,7 +25,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -36,7 +36,6 @@ mediapipe_ts_declaration( "audio_classifier_result.d.ts", ], deps = [ - "//mediapipe/tasks/web/audio/core:audio_task_options", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/core", diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index e606019f2..4e12780d2 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -22,8 +22,8 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {AudioTaskRunner} from '../../../../tasks/web/audio/core/audio_task_runner'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; -import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; @@ -56,13 +56,12 @@ export class AudioClassifier extends AudioTaskRunner { * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
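A minimal usage sketch of the reworked factory described here (the Wasm base path and model URL are placeholders, and the published `@mediapipe/tasks-audio` package name is an assumption):

    import {AudioClassifier, FilesetResolver} from '@mediapipe/tasks-audio';

    async function buildClassifier(): Promise<AudioClassifier> {
      const fileset = await FilesetResolver.forAudioTasks('/wasm');  // placeholder path
      // With this patch, createFromOptions() applies the options during
      // construction instead of requiring a separate setOptions() round trip.
      return AudioClassifier.createFromOptions(fileset, {
        baseOptions: {modelAssetPath: '/models/audio_classifier.tflite'},  // placeholder
        maxResults: 3,
      });
    }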
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, audioClassifierOptions: AudioClassifierOptions): Promise { - const classifier = await TaskRunner.createInstance( - AudioClassifier, /* initializeCanvas= */ false, wasmFileset); - await classifier.setOptions(audioClassifierOptions); - return classifier; + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + audioClassifierOptions); } /** @@ -75,8 +74,9 @@ export class AudioClassifier extends AudioTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -86,20 +86,26 @@ export class AudioClassifier extends AudioTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return AudioClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return AudioTaskRunner.createInstance( + AudioClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts index 975b1e315..dc3c494bf 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_options.d.ts @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Audio Classifier Task */ export declare interface AudioClassifierOptions extends ClassifierOptions, - AudioTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/audio_embedder/BUILD b/mediapipe/tasks/web/audio/audio_embedder/BUILD index 1a66464bd..0555bb639 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/BUILD +++ b/mediapipe/tasks/web/audio/audio_embedder/BUILD @@ -36,7 +36,6 @@ mediapipe_ts_declaration( "audio_embedder_result.d.ts", ], deps = [ - "//mediapipe/tasks/web/audio/core:audio_task_options", "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index c87aceabe..d08eb4791 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -25,7 +25,7 @@ import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/pr import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; -import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/graph_runner'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {AudioEmbedderOptions} from './audio_embedder_options'; @@ -58,23 +58,12 @@ export class AudioEmbedder extends AudioTaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, audioEmbedderOptions: AudioEmbedderOptions): Promise { - // Create a file locator based on the loader options - const fileLocator: FileLocator = { - locateFile() { - // The only file we load is the Wasm binary - return wasmFileset.wasmBinaryPath.toString(); - } - }; - - const embedder = await createMediaPipeLib( - AudioEmbedder, wasmFileset.wasmLoaderPath, - /* assetLoaderScript= */ undefined, - /* glCanvas= */ undefined, fileLocator); - await embedder.setOptions(audioEmbedderOptions); - return embedder; + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + audioEmbedderOptions); } /** @@ -87,8 +76,9 @@ export class AudioEmbedder extends AudioTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return AudioEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -98,20 +88,26 @@ export class AudioEmbedder extends AudioTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. 
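For reference, a sketch of the embedder built through the same factory path (placeholder URLs; result shapes follow the surrounding code, but treat the exact signatures as assumptions):

    import {AudioEmbedder, FilesetResolver} from '@mediapipe/tasks-audio';

    async function compareClips(a: Float32Array, b: Float32Array): Promise<number> {
      const fileset = await FilesetResolver.forAudioTasks('/wasm');   // placeholder
      const embedder = await AudioEmbedder.createFromModelPath(
          fileset, '/models/audio_embedder.tflite');                  // placeholder
      // embed() yields one result per audio block; compare the first blocks.
      // cosineSimilarity() mirrors the text embedder helper; assumed static.
      const [resultA] = embedder.embed(a);
      const [resultB] = embedder.embed(b);
      return AudioEmbedder.cosineSimilarity(
          resultA.embeddings[0], resultB.embeddings[0]);
    }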
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return AudioEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return AudioTaskRunner.createInstance( + AudioEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts index 98f412d0f..ac22728ab 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_options.d.ts @@ -14,9 +14,9 @@ * limitations under the License. */ -import {AudioTaskOptions} from '../../../../tasks/web/audio/core/audio_task_options'; import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Audio Embedder Task */ export declare interface AudioEmbedderOptions extends EmbedderOptions, - AudioTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/audio/core/BUILD b/mediapipe/tasks/web/audio/core/BUILD index 91ebbf524..9ab6c7bee 100644 --- a/mediapipe/tasks/web/audio/core/BUILD +++ b/mediapipe/tasks/web/audio/core/BUILD @@ -1,24 +1,13 @@ # This package contains options shared by all MediaPipe Audio Tasks for Web. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") package(default_visibility = ["//mediapipe/tasks:internal"]) -mediapipe_ts_declaration( - name = "audio_task_options", - srcs = ["audio_task_options.d.ts"], - deps = [ - "//mediapipe/tasks/web/core", - ], -) - mediapipe_ts_library( name = "audio_task_runner", srcs = ["audio_task_runner.ts"], deps = [ - ":audio_task_options", - "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner", ], diff --git a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts b/mediapipe/tasks/web/audio/core/audio_task_options.d.ts deleted file mode 100644 index e3068625d..000000000 --- a/mediapipe/tasks/web/audio/core/audio_task_options.d.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -/** The options for configuring a MediaPipe Audio Task. */ -export declare interface AudioTaskOptions { - /** Options to configure the loading of the model assets. */ - baseOptions?: BaseOptions; -} diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index ceff3895b..00cfe0253 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -14,26 +14,13 @@ * limitations under the License. */ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; - -import {AudioTaskOptions} from './audio_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Base class for all MediaPipe Audio Tasks. */ -export abstract class AudioTaskRunner extends TaskRunner { - protected abstract baseOptions?: BaseOptionsProto|undefined; +export abstract class AudioTaskRunner extends TaskRunner { private defaultSampleRate = 48000; - /** Configures the shared options of an audio task. */ - async setOptions(options: AudioTaskOptions): Promise { - this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); - } - } - /** * Sets the sample rate for API calls that omit an explicit sample rate. * `48000` is used as a default if this method is not called. 
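Concretely, a task fed 16 kHz microphone audio only needs to state that rate once; later calls can omit it (a sketch — `classifier` and `pcmSamples` are placeholders):

    import {AudioClassifier} from '@mediapipe/tasks-audio';  // assumed package name

    declare const classifier: AudioClassifier;  // built via createFromOptions()
    declare const pcmSamples: Float32Array;     // 16 kHz microphone capture

    classifier.setDefaultSampleRate(16000);
    const results = classifier.classify(pcmSamples);          // uses 16000 Hz
    const explicit = classifier.classify(pcmSamples, 48000);  // per-call override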
diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index 1b56bf4c9..86e743928 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -17,7 +17,6 @@ mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], deps = [ - "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/web/components/containers:classification_result", ], diff --git a/mediapipe/tasks/web/components/processors/base_options.ts b/mediapipe/tasks/web/components/processors/base_options.ts index ac24a8db6..16d562262 100644 --- a/mediapipe/tasks/web/components/processors/base_options.ts +++ b/mediapipe/tasks/web/components/processors/base_options.ts @@ -18,7 +18,7 @@ import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inferen import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb'; -import {BaseOptions} from '../../../../tasks/web/core/base_options'; +import {BaseOptions} from '../../../../tasks/web/core/task_runner_options'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index d709e3409..de429690d 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -7,18 +7,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) mediapipe_ts_declaration( name = "core", srcs = [ - "base_options.d.ts", + "task_runner_options.d.ts", "wasm_fileset.d.ts", ], ) mediapipe_ts_library( name = "task_runner", - srcs = [ - "task_runner.ts", - ], + srcs = ["task_runner.ts"], deps = [ ":core", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", "//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:register_model_resources_graph_service_ts", diff --git a/mediapipe/tasks/web/core/classifier_options.d.ts b/mediapipe/tasks/web/core/classifier_options.d.ts index 1d804d629..08e7a7664 100644 --- a/mediapipe/tasks/web/core/classifier_options.d.ts +++ b/mediapipe/tasks/web/core/classifier_options.d.ts @@ -14,8 +14,6 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - /** Options to configure a MediaPipe Classifier Task. */ export declare interface ClassifierOptions { /** diff --git a/mediapipe/tasks/web/core/embedder_options.d.ts b/mediapipe/tasks/web/core/embedder_options.d.ts index 3ec2a170c..8669acfcb 100644 --- a/mediapipe/tasks/web/core/embedder_options.d.ts +++ b/mediapipe/tasks/web/core/embedder_options.d.ts @@ -14,8 +14,6 @@ * limitations under the License. */ -import {BaseOptions} from '../../../tasks/web/core/base_options'; - /** Options to configure a MediaPipe Embedder Task */ export declare interface EmbedderOptions { /** diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 4085be697..c2691fc76 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -14,6 +14,9 @@ * limitations under the License. 
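The task_runner.ts hunk that follows is the heart of this patch: `TaskRunner` becomes generic over its options type, declares the abstract `baseOptions` proto, and `createInstance()` now receives and applies the options itself. A minimal sketch of the contract a subclass fulfills afterwards (hypothetical `FakeTask`, written against this file's imports; the real subclasses appear in later hunks):

    // Hypothetical minimal subclass: supply the baseOptions accessors and the
    // generic base class provides option conversion and Wasm loading.
    class FakeTask extends TaskRunner<TaskRunnerOptions> {
      private proto = new BaseOptionsProto();
      protected override get baseOptions(): BaseOptionsProto { return this.proto; }
      protected override set baseOptions(p: BaseOptionsProto) { this.proto = p; }

      static create(fileset: WasmFileset, options: TaskRunnerOptions) {
        // createInstance() now takes the options and awaits setOptions() itself.
        return TaskRunner.createInstance(
            FakeTask, /* initializeCanvas= */ false, fileset, options);
      }
    }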
 */
 
+import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb';
+import {convertBaseOptionsToProto} from '../../../tasks/web/components/processors/base_options';
+import {TaskRunnerOptions} from '../../../tasks/web/core/task_runner_options';
 import {createMediaPipeLib, FileLocator, GraphRunner, WasmMediaPipeConstructor, WasmModule} from '../../../web/graph_runner/graph_runner';
 import {SupportImage} from '../../../web/graph_runner/graph_runner_image_lib';
 import {SupportModelResourcesGraphService} from '../../../web/graph_runner/register_model_resources_graph_service';
@@ -28,7 +31,9 @@ const WasmMediaPipeImageLib =
     SupportModelResourcesGraphService(SupportImage(GraphRunner));
 
 /** Base class for all MediaPipe Tasks. */
-export abstract class TaskRunner extends WasmMediaPipeImageLib {
+export abstract class TaskRunner<O extends TaskRunnerOptions> extends
+    WasmMediaPipeImageLib {
+  protected abstract baseOptions: BaseOptionsProto;
   private processingErrors: Error[] = [];
 
   /**
@@ -36,9 +41,10 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
    * supported and loads the relevant WASM binary.
    * @return A fully instantiated instance of `T`.
    */
-  protected static async createInstance<T extends TaskRunner>(
+  protected static async createInstance<T extends TaskRunner<O>,
+      O extends TaskRunnerOptions>(
       type: WasmMediaPipeConstructor<T>, initializeCanvas: boolean,
-      fileset: WasmFileset): Promise<T> {
+      fileset: WasmFileset, options: O): Promise<T> {
     const fileLocator: FileLocator = {
       locateFile() {
         // The only file loaded with this mechanism is the Wasm binary
         return fileset.wasmBinaryPath.toString();
       }
     };
 
-    if (initializeCanvas) {
-      // Fall back to an OffscreenCanvas created by the GraphRunner if
-      // OffscreenCanvas is available
-      const canvas = typeof OffscreenCanvas === 'undefined' ?
-          document.createElement('canvas') :
-          undefined;
-      return createMediaPipeLib(
-          type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
-    } else {
-      return createMediaPipeLib(
-          type, fileset.wasmLoaderPath, NO_ASSETS, /* glCanvas= */ null,
-          fileLocator);
-    }
+    // Initialize a canvas if requested. If OffscreenCanvas is available, we
+    // let the graph runner initialize it by passing `undefined`.
+    const canvas = initializeCanvas ? (typeof OffscreenCanvas === 'undefined' ?
+                                           document.createElement('canvas') :
+                                           undefined) :
+                                      null;
+    const instance = await createMediaPipeLib(
+        type, fileset.wasmLoaderPath, NO_ASSETS, canvas, fileLocator);
+    await instance.setOptions(options);
+    return instance;
   }
 
   constructor(
@@ -74,6 +77,14 @@ export abstract class TaskRunner extends WasmMediaPipeImageLib {
     this.registerModelResourcesGraphService();
   }
 
+  /** Configures the shared options of a MediaPipe Task. */
+  async setOptions(options: O): Promise<void> {
+    if (options.baseOptions) {
+      this.baseOptions = await convertBaseOptionsToProto(
+          options.baseOptions, this.baseOptions);
+    }
+  }
+
   /**
    * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run
Will replace the previously running MediaPipe graph, diff --git a/mediapipe/tasks/web/core/base_options.d.ts b/mediapipe/tasks/web/core/task_runner_options.d.ts similarity index 85% rename from mediapipe/tasks/web/core/base_options.d.ts rename to mediapipe/tasks/web/core/task_runner_options.d.ts index 86635b8c7..aa0b4a028 100644 --- a/mediapipe/tasks/web/core/base_options.d.ts +++ b/mediapipe/tasks/web/core/task_runner_options.d.ts @@ -16,7 +16,7 @@ // Placeholder for internal dependency on trusted resource url -/** Options to configure MediaPipe Tasks in general. */ +/** Options to configure MediaPipe model loading and processing. */ export declare interface BaseOptions { /** * The model path to the model asset file. Only one of `modelAssetPath` or @@ -33,3 +33,9 @@ export declare interface BaseOptions { /** Overrides the default backend to use for the provided model. */ delegate?: 'cpu'|'gpu'|undefined; } + +/** Options to configure MediaPipe Tasks in general. */ +export declare interface TaskRunnerOptions { + /** Options to configure the loading of the model assets. */ + baseOptions?: BaseOptions; +} diff --git a/mediapipe/tasks/web/text/core/BUILD b/mediapipe/tasks/web/text/core/BUILD deleted file mode 100644 index 3e7faec93..000000000 --- a/mediapipe/tasks/web/text/core/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -# This package contains options shared by all MediaPipe Texxt Tasks for Web. - -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration") - -package(default_visibility = ["//mediapipe/tasks:internal"]) - -mediapipe_ts_declaration( - name = "text_task_options", - srcs = ["text_task_options.d.ts"], - deps = ["//mediapipe/tasks/web/core"], -) diff --git a/mediapipe/tasks/web/text/core/text_task_options.d.ts b/mediapipe/tasks/web/text/core/text_task_options.d.ts deleted file mode 100644 index 4874e35bf..000000000 --- a/mediapipe/tasks/web/text/core/text_task_options.d.ts +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2022 The MediaPipe Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import {BaseOptions} from '../../../../tasks/web/core/base_options'; - -/** The options for configuring a MediaPipe Text task. */ -export declare interface TextTaskOptions { - /** Options to configure the loading of the model assets. 
*/ - baseOptions?: BaseOptions; -} diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index f3d272daa..2a7de21d6 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -17,15 +17,16 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -38,7 +39,7 @@ mediapipe_ts_declaration( deps = [ "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:classification_result", + "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", - "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index 197869a36..bd2a207ce 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -17,12 +17,13 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; import {TextClassifierGraphOptions} from '../../../../tasks/cc/text/text_classifier/proto/text_classifier_graph_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/processors/classifier_options'; import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; @@ -40,7 +41,7 @@ const TEXT_CLASSIFIER_GRAPH = // tslint:disable:jspb-use-builder-pattern /** Performs Natural Language classification. */ -export class TextClassifier extends TaskRunner { +export class TextClassifier extends TaskRunner { private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); @@ -53,13 +54,12 @@ export class TextClassifier extends TaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). 
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, textClassifierOptions: TextClassifierOptions): Promise { - const classifier = await TaskRunner.createInstance( - TextClassifier, /* initializeCanvas= */ false, wasmFileset); - await classifier.setOptions(textClassifierOptions); - return classifier; + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + textClassifierOptions); } /** @@ -72,8 +72,9 @@ export class TextClassifier extends TaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -83,13 +84,19 @@ export class TextClassifier extends TaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextClassifier, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -101,18 +108,20 @@ export class TextClassifier extends TaskRunner { * * @param options The options for the text classifier. 
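A short end-to-end sketch of the resulting API (placeholder paths; the `@mediapipe/tasks-text` package name is an assumption):

    import {FilesetResolver, TextClassifier} from '@mediapipe/tasks-text';

    async function demo(): Promise<void> {
      const fileset = await FilesetResolver.forTextTasks('/wasm');  // placeholder
      const classifier = await TextClassifier.createFromModelPath(
          fileset, '/models/bert_classifier.tflite');               // placeholder
      // classify() is synchronous once the task has been constructed.
      const result = classifier.classify('MediaPipe tasks run in the browser');
      console.log(result.classifications);
    }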
*/ - async setOptions(options: TextClassifierOptions): Promise { - if (options.baseOptions) { - const baseOptionsProto = await convertBaseOptionsToProto( - options.baseOptions, this.options.getBaseOptions()); - this.options.setBaseOptions(baseOptionsProto); - } - + override async setOptions(options: TextClassifierOptions): Promise { + await super.setOptions(options); this.options.setClassifierOptions(convertClassifierOptionsToProto( options, this.options.getClassifierOptions())); this.refreshGraph(); } + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } /** * Performs Natural Language classification on the provided text and waits diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts index b50767e1a..25592deb5 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_options.d.ts @@ -15,8 +15,8 @@ */ import {ClassifierOptions} from '../../../../tasks/web/core/classifier_options'; -import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options'; +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; /** Options to configure the MediaPipe Text Classifier Task */ export declare interface TextClassifierOptions extends ClassifierOptions, - TextTaskOptions {} + TaskRunnerOptions {} diff --git a/mediapipe/tasks/web/text/text_embedder/BUILD b/mediapipe/tasks/web/text/text_embedder/BUILD index b858f6b83..17d105258 100644 --- a/mediapipe/tasks/web/text/text_embedder/BUILD +++ b/mediapipe/tasks/web/text/text_embedder/BUILD @@ -17,15 +17,16 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/text/text_embedder/proto:text_embedder_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:embedding_result", - "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:embedder_options", "//mediapipe/tasks/web/components/processors:embedder_result", "//mediapipe/tasks/web/components/utils:cosine_similarity", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", "//mediapipe/tasks/web/core:task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", ], ) @@ -39,6 +40,5 @@ mediapipe_ts_declaration( "//mediapipe/tasks/web/components/containers:embedding_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:embedder_options", - "//mediapipe/tasks/web/text/core:text_task_options", ], ) diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 511fd2411..d2899fbe2 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -17,14 +17,15 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {EmbeddingResult} from '../../../../tasks/cc/components/containers/proto/embeddings_pb'; +import {BaseOptions as BaseOptionsProto} from 
'../../../../tasks/cc/core/proto/base_options_pb'; import {TextEmbedderGraphOptions as TextEmbedderGraphOptionsProto} from '../../../../tasks/cc/text/text_embedder/proto/text_embedder_graph_options_pb'; import {Embedding} from '../../../../tasks/web/components/containers/embedding_result'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {convertEmbedderOptionsToProto} from '../../../../tasks/web/components/processors/embedder_options'; import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/processors/embedder_result'; import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {TextEmbedderOptions} from './text_embedder_options'; @@ -44,7 +45,7 @@ const TEXT_EMBEDDER_CALCULATOR = /** * Performs embedding extraction on text. */ -export class TextEmbedder extends TaskRunner { +export class TextEmbedder extends TaskRunner { private embeddingResult: TextEmbedderResult = {embeddings: []}; private readonly options = new TextEmbedderGraphOptionsProto(); @@ -57,13 +58,12 @@ export class TextEmbedder extends TaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, textEmbedderOptions: TextEmbedderOptions): Promise { - const embedder = await TaskRunner.createInstance( - TextEmbedder, /* initializeCanvas= */ false, wasmFileset); - await embedder.setOptions(textEmbedderOptions); - return embedder; + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + textEmbedderOptions); } /** @@ -76,8 +76,9 @@ export class TextEmbedder extends TaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return TextEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,13 +88,19 @@ export class TextEmbedder extends TaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. */ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return TextEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return TaskRunner.createInstance( + TextEmbedder, /* initializeCanvas= */ false, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } /** @@ -105,17 +112,21 @@ export class TextEmbedder extends TaskRunner { * * @param options The options for the text embedder. 
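And the embedder counterpart, a sketch with placeholder model and fileset paths (exact result shapes are assumptions based on the surrounding code):

    import {FilesetResolver, TextEmbedder} from '@mediapipe/tasks-text';

    async function similarity(): Promise<number> {
      const fileset = await FilesetResolver.forTextTasks('/wasm');   // placeholder
      const embedder = await TextEmbedder.createFromModelPath(
          fileset, '/models/text_embedder.tflite');                  // placeholder
      const first = embedder.embed('MediaPipe runs on the web');
      const second = embedder.embed('Web-based MediaPipe');
      // Cosine similarity lies in [-1, 1]; values near 1 mean similar texts.
      return TextEmbedder.cosineSimilarity(
          first.embeddings[0], second.embeddings[0]);
    }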
 */
-  async setOptions(options: TextEmbedderOptions): Promise<void> {
-    if (options.baseOptions) {
-      const baseOptionsProto = await convertBaseOptionsToProto(
-          options.baseOptions, this.options.getBaseOptions());
-      this.options.setBaseOptions(baseOptionsProto);
-    }
+  override async setOptions(options: TextEmbedderOptions): Promise<void> {
+    await super.setOptions(options);
     this.options.setEmbedderOptions(convertEmbedderOptionsToProto(
         options, this.options.getEmbedderOptions()));
     this.refreshGraph();
   }
 
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
+  }
+
   /**
    * Performs embedding extraction on the provided text and waits synchronously
    * for the response.
diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts
index 9ea570304..7689ee0c1 100644
--- a/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts
+++ b/mediapipe/tasks/web/text/text_embedder/text_embedder_options.d.ts
@@ -15,8 +15,8 @@
  */
 
 import {EmbedderOptions} from '../../../../tasks/web/core/embedder_options';
-import {TextTaskOptions} from '../../../../tasks/web/text/core/text_task_options';
+import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
 
 /** Options to configure the MediaPipe Text Embedder Task */
 export declare interface TextEmbedderOptions extends EmbedderOptions,
-    TextTaskOptions {}
+    TaskRunnerOptions {}
diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD
index 1d8944f14..b389a9b01 100644
--- a/mediapipe/tasks/web/vision/core/BUILD
+++ b/mediapipe/tasks/web/vision/core/BUILD
@@ -17,8 +17,6 @@ mediapipe_ts_library(
     srcs = ["vision_task_runner.ts"],
     deps = [
         ":vision_task_options",
-        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
-        "//mediapipe/tasks/web/components/processors:base_options",
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:task_runner",
         "//mediapipe/web/graph_runner:graph_runner_ts",
diff --git a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts
index e04eb6596..76c0177a0 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_options.d.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_options.d.ts
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-import {BaseOptions} from '../../../../tasks/web/core/base_options';
+import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options';
 
 /**
  * The two running modes of a vision task.
 *
@@ -23,12 +23,8 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
  */
 export type RunningMode = 'image'|'video';
 
-
 /** The options for configuring a MediaPipe vision task. */
-export declare interface VisionTaskOptions {
-  /** Options to configure the loading of the model assets. */
-  baseOptions?: BaseOptions;
-
+export declare interface VisionTaskOptions extends TaskRunnerOptions {
   /**
    * The running mode of the task. Defaults to the image mode.
    * Vision tasks have two running modes:
diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
index 79ff45156..78b4859f2 100644
--- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts
+++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts
@@ -14,24 +14,17 @@
  * limitations under the License.
*/ -import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; -import {convertBaseOptionsToProto} from '../../../../tasks/web/components/processors/base_options'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {ImageSource} from '../../../../web/graph_runner/graph_runner'; import {VisionTaskOptions} from './vision_task_options'; /** Base class for all MediaPipe Vision Tasks. */ -export abstract class VisionTaskRunner extends TaskRunner { - protected abstract baseOptions?: BaseOptionsProto|undefined; - +export abstract class VisionTaskRunner extends + TaskRunner { /** Configures the shared options of a vision task. */ - async setOptions(options: VisionTaskOptions): Promise { - this.baseOptions = this.baseOptions ?? new BaseOptionsProto(); - if (options.baseOptions) { - this.baseOptions = await convertBaseOptionsToProto( - options.baseOptions, this.baseOptions); - } + override async setOptions(options: VisionTaskOptions): Promise { + await super.setOptions(options); if ('runningMode' in options) { const useStreamMode = !!options.runningMode && options.runningMode !== 'image'; diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 15b6acb1a..8baee5ce3 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -88,14 +88,13 @@ export class GestureRecognizer extends * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, gestureRecognizerOptions: GestureRecognizerOptions): Promise { - const recognizer = await VisionTaskRunner.createInstance( - GestureRecognizer, /* initializeCanvas= */ true, wasmFileset); - await recognizer.setOptions(gestureRecognizerOptions); - return recognizer; + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + gestureRecognizerOptions); } /** @@ -108,8 +107,9 @@ export class GestureRecognizer extends static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return GestureRecognizer.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -119,13 +119,12 @@ export class GestureRecognizer extends * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
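A usage sketch for the vision-side factories (placeholder fileset path and model URL; `@mediapipe/tasks-vision` is the assumed package name):

    import {FilesetResolver, GestureRecognizer} from '@mediapipe/tasks-vision';

    async function recognizeOnce(image: HTMLImageElement) {
      const fileset = await FilesetResolver.forVisionTasks('/wasm');       // placeholder
      const recognizer = await GestureRecognizer.createFromOptions(fileset, {
        baseOptions: {modelAssetPath: '/models/gesture_recognizer.task'},  // placeholder
        runningMode: 'image',
      });
      const result = recognizer.recognize(image);
      // result.landmarks are NormalizedLandmark[][]; worldLandmarks are metric.
      return result.gestures;
    }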
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return GestureRecognizer.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + GestureRecognizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } constructor( @@ -134,6 +133,7 @@ export class GestureRecognizer extends super(wasmModule, glCanvas); this.options = new GestureRecognizerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarkerGraphOptions = new HandLandmarkerGraphOptions(); this.options.setHandLandmarkerGraphOptions(this.handLandmarkerGraphOptions); this.handLandmarksDetectorGraphOptions = @@ -151,11 +151,11 @@ export class GestureRecognizer extends this.initDefaults(); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index c657275bf..263ed4b48 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -77,13 +77,12 @@ export class HandLandmarker extends VisionTaskRunner { * Note that either a path to the model asset or a model buffer needs to * be provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, handLandmarkerOptions: HandLandmarkerOptions): Promise { - const landmarker = await VisionTaskRunner.createInstance( - HandLandmarker, /* initializeCanvas= */ true, wasmFileset); - await landmarker.setOptions(handLandmarkerOptions); - return landmarker; + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + handLandmarkerOptions); } /** @@ -96,8 +95,9 @@ export class HandLandmarker extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return HandLandmarker.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -107,13 +107,12 @@ export class HandLandmarker extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
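Video mode works the same way but feeds timestamped frames; a sketch in which the option fields and `detectForVideo` call shape are assumptions based on the Web API of this era:

    import {FilesetResolver, HandLandmarker} from '@mediapipe/tasks-vision';

    async function buildVideoLandmarker(): Promise<HandLandmarker> {
      const fileset = await FilesetResolver.forVisionTasks('/wasm');    // placeholder
      return HandLandmarker.createFromOptions(fileset, {
        baseOptions: {modelAssetPath: '/models/hand_landmarker.task'},  // placeholder
        runningMode: 'video',  // requires monotonically increasing timestamps
        numHands: 2,
      });
    }

    // Per frame: landmarker.detectForVideo(videoElement, performance.now());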
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return HandLandmarker.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + HandLandmarker, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } constructor( @@ -122,6 +121,7 @@ export class HandLandmarker extends VisionTaskRunner { super(wasmModule, glCanvas); this.options = new HandLandmarkerGraphOptions(); + this.options.setBaseOptions(new BaseOptionsProto()); this.handLandmarksDetectorGraphOptions = new HandLandmarksDetectorGraphOptions(); this.options.setHandLandmarksDetectorGraphOptions( @@ -132,11 +132,11 @@ export class HandLandmarker extends VisionTaskRunner { this.initDefaults(); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 185ddf9ea..90dbf9798 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -23,7 +23,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {convertFromClassificationResultProto} from '../../../../tasks/web/components/processors/classifier_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; @@ -55,13 +55,12 @@ export class ImageClassifier extends VisionTaskRunner { * that either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, imageClassifierOptions: ImageClassifierOptions): Promise { - const classifier = await VisionTaskRunner.createInstance( - ImageClassifier, /* initializeCanvas= */ true, wasmFileset); - await classifier.setOptions(imageClassifierOptions); - return classifier; + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + imageClassifierOptions); } /** @@ -74,8 +73,9 @@ export class ImageClassifier extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageClassifier.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -85,20 +85,26 @@ export class ImageClassifier extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the model asset. 
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageClassifier.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageClassifier, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 91352e934..559332650 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -25,7 +25,7 @@ import {convertFromEmbeddingResultProto} from '../../../../tasks/web/components/ import {computeCosineSimilarity} from '../../../../tasks/web/components/utils/cosine_similarity'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageEmbedderOptions} from './image_embedder_options'; @@ -57,13 +57,12 @@ export class ImageEmbedder extends VisionTaskRunner { * either a path to the TFLite model or the model itself needs to be * provided (via `baseOptions`). */ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, imageEmbedderOptions: ImageEmbedderOptions): Promise { - const embedder = await VisionTaskRunner.createInstance( - ImageEmbedder, /* initializeCanvas= */ true, wasmFileset); - await embedder.setOptions(imageEmbedderOptions); - return embedder; + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + imageEmbedderOptions); } /** @@ -76,8 +75,9 @@ export class ImageEmbedder extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ImageEmbedder.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,20 +87,26 @@ export class ImageEmbedder extends VisionTaskRunner { * Wasm binary and its loader. * @param modelAssetPath The path to the TFLite model. 
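 *
 * A sketch of the intended call (the asset path is hypothetical):
 *
 *   const imageEmbedder = await ImageEmbedder.createFromModelPath(
 *       wasmFileset, '/assets/mobilenet_v3_embedder.tflite');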
*/ - static async createFromModelPath( + static createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ImageEmbedder.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ImageEmbedder, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 7711c39e9..03171003f 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -21,7 +21,7 @@ import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/b import {ObjectDetectorOptions as ObjectDetectorOptionsProto} from '../../../../tasks/cc/vision/object_detector/proto/object_detector_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; -import {ImageSource} from '../../../../web/graph_runner/graph_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ObjectDetectorOptions} from './object_detector_options'; @@ -54,13 +54,12 @@ export class ObjectDetector extends VisionTaskRunner { * either a path to the model asset or a model buffer needs to be * provided (via `baseOptions`). 
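 *
 * For instance (the asset path and threshold value are illustrative; the
 * option names follow ObjectDetectorOptions):
 *
 *   const objectDetector = await ObjectDetector.createFromOptions(
 *       wasmFileset, {
 *         baseOptions: {modelAssetPath: '/assets/object_detector.tflite'},
 *         scoreThreshold: 0.5,
 *       });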
*/ - static async createFromOptions( + static createFromOptions( wasmFileset: WasmFileset, objectDetectorOptions: ObjectDetectorOptions): Promise { - const detector = await VisionTaskRunner.createInstance( - ObjectDetector, /* initializeCanvas= */ true, wasmFileset); - await detector.setOptions(objectDetectorOptions); - return detector; + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + objectDetectorOptions); } /** @@ -73,8 +72,9 @@ export class ObjectDetector extends VisionTaskRunner { static createFromModelBuffer( wasmFileset: WasmFileset, modelAssetBuffer: Uint8Array): Promise { - return ObjectDetector.createFromOptions( - wasmFileset, {baseOptions: {modelAssetBuffer}}); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); } /** @@ -87,17 +87,23 @@ export class ObjectDetector extends VisionTaskRunner { static async createFromModelPath( wasmFileset: WasmFileset, modelAssetPath: string): Promise { - const response = await fetch(modelAssetPath.toString()); - const graphData = await response.arrayBuffer(); - return ObjectDetector.createFromModelBuffer( - wasmFileset, new Uint8Array(graphData)); + return VisionTaskRunner.createInstance( + ObjectDetector, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); } - protected override get baseOptions(): BaseOptionsProto|undefined { - return this.options.getBaseOptions(); + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super(wasmModule, glCanvas); + this.options.setBaseOptions(new BaseOptionsProto()); } - protected override set baseOptions(proto: BaseOptionsProto|undefined) { + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { this.options.setBaseOptions(proto); } From e457039fc6350fbd2e75aa2d034f9b68af6d3410 Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 2 Dec 2022 16:16:34 -0800 Subject: [PATCH 137/137] Don't inherit from GraphRunner PiperOrigin-RevId: 492584486 --- .../audio_classifier/audio_classifier.ts | 9 +++-- .../audio/audio_embedder/audio_embedder.ts | 25 ++++++++------ mediapipe/tasks/web/core/task_runner.ts | 24 +++++++------- .../text/text_classifier/text_classifier.ts | 11 ++++--- .../web/text/text_embedder/text_embedder.ts | 4 +-- .../gesture_recognizer/gesture_recognizer.ts | 33 +++++++++++-------- .../vision/hand_landmarker/hand_landmarker.ts | 26 ++++++++------- .../image_classifier/image_classifier.ts | 11 ++++--- .../vision/image_embedder/image_embedder.ts | 4 +-- .../vision/object_detector/object_detector.ts | 9 ++--- .../graph_runner/graph_runner_image_lib.ts | 2 +- .../register_model_resources_graph_service.ts | 4 +-- 12 files changed, 92 insertions(+), 70 deletions(-) diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 4e12780d2..265ba2b33 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -145,8 +145,11 @@ export class AudioClassifier extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioClassifierResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, 
/* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.classificationResults = []; this.finishProcessing(); @@ -189,7 +192,7 @@ export class AudioClassifier extends AudioTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoVectorListener( + this.graphRunner.attachProtoVectorListener( TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { this.addJsAudioClassificationResults(binaryProtos); }); diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts index d08eb4791..445dd5172 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder.ts @@ -158,8 +158,11 @@ export class AudioEmbedder extends AudioTaskRunner { protected override process( audioData: Float32Array, sampleRate: number, timestampMs: number): AudioEmbedderResult[] { - this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestampMs); - this.addAudioToStreamWithShape(audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, AUDIO_STREAM, timestampMs); + this.graphRunner.addDoubleToStream( + sampleRate, SAMPLE_RATE_STREAM, timestampMs); + this.graphRunner.addAudioToStreamWithShape( + audioData, /* numChannels= */ 1, /* numSamples= */ audioData.length, + AUDIO_STREAM, timestampMs); this.embeddingResults = []; this.finishProcessing(); @@ -189,19 +192,21 @@ export class AudioEmbedder extends AudioTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResults.push( convertFromEmbeddingResultProto(embeddingResult)); }); - this.attachProtoVectorListener(TIMESTAMPED_EMBEDDINGS_STREAM, data => { - for (const binaryProto of data) { - const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); - this.embeddingResults.push( - convertFromEmbeddingResultProto(embeddingResult)); - } - }); + this.graphRunner.attachProtoVectorListener( + TIMESTAMPED_EMBEDDINGS_STREAM, data => { + for (const binaryProto of data) { + const embeddingResult = + EmbeddingResult.deserializeBinary(binaryProto); + this.embeddingResults.push( + convertFromEmbeddingResultProto(embeddingResult)); + } + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index c2691fc76..d769139bc 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -27,13 +27,15 @@ import {WasmFileset} from './wasm_fileset'; const NO_ASSETS = undefined; // tslint:disable-next-line:enforce-name-casing -const WasmMediaPipeImageLib = +const GraphRunnerImageLibType = SupportModelResourcesGraphService(SupportImage(GraphRunner)); +/** An implementation of the GraphRunner that supports image operations */ +export class GraphRunnerImageLib extends GraphRunnerImageLibType {} /** Base class for all MediaPipe Tasks. 
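 *
 * As of this change, TaskRunner owns a GraphRunnerImageLib instance rather
 * than inheriting from the mixin chain, and forwards calls to it, e.g.:
 *
 *   this.graphRunner.setAutoRenderToScreen(false);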
*/ -export abstract class TaskRunner extends - WasmMediaPipeImageLib { +export abstract class TaskRunner { protected abstract baseOptions: BaseOptionsProto; + protected graphRunner: GraphRunnerImageLib; private processingErrors: Error[] = []; /** @@ -67,14 +69,14 @@ export abstract class TaskRunner extends constructor( wasmModule: WasmModule, glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { - super(wasmModule, glCanvas); + this.graphRunner = new GraphRunnerImageLib(wasmModule, glCanvas); // Disables the automatic render-to-screen code, which allows for pure // CPU processing. - this.setAutoRenderToScreen(false); + this.graphRunner.setAutoRenderToScreen(false); // Enables use of our model resource caching graph service. - this.registerModelResourcesGraphService(); + this.graphRunner.registerModelResourcesGraphService(); } /** Configures the shared options of a MediaPipe Task. */ @@ -95,11 +97,11 @@ export abstract class TaskRunner extends * @param isBinary This should be set to true if the graph is in * binary format, and false if it is in human-readable text format. */ - override setGraph(graphData: Uint8Array, isBinary: boolean): void { - this.attachErrorListener((code, message) => { + protected setGraph(graphData: Uint8Array, isBinary: boolean): void { + this.graphRunner.attachErrorListener((code, message) => { this.processingErrors.push(new Error(message)); }); - super.setGraph(graphData, isBinary); + this.graphRunner.setGraph(graphData, isBinary); this.handleErrors(); } @@ -108,8 +110,8 @@ export abstract class TaskRunner extends * far as possible, performing all processing until no more processing can be * done. */ - override finishProcessing(): void { - super.finishProcessing(); + protected finishProcessing(): void { + this.graphRunner.finishProcessing(); this.handleErrors(); } diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index bd2a207ce..8810d4b42 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -133,7 +133,7 @@ export class TextClassifier extends TaskRunner { classify(text: string): TextClassifierResult { // Get classification result by running our MediaPipe graph. 
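    // The previous result is cleared, then the input text is pushed into the
    // graph through the owned GraphRunner; performance.now() provides a
    // monotonically increasing packet timestamp for the stream.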
this.classificationResult = {classifications: []}; - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.classificationResult; @@ -157,10 +157,11 @@ export class TextClassifier extends TaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index d2899fbe2..62f9b06db 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -136,7 +136,7 @@ export class TextEmbedder extends TaskRunner { */ embed(text: string): TextEmbedderResult { // Get text embeddings by running our MediaPipe graph. - this.addStringToStream( + this.graphRunner.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); return this.embeddingResult; @@ -173,7 +173,7 @@ export class TextEmbedder extends TaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { const embeddingResult = EmbeddingResult.deserializeBinary(binaryProto); this.embeddingResult = convertFromEmbeddingResultProto(embeddingResult); }); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 8baee5ce3..69a8118a6 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -257,8 +257,9 @@ export class GestureRecognizer extends this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -365,18 +366,22 @@ export class GestureRecognizer extends graphConfig.addNode(recognizerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HAND_GESTURES_STREAM, binaryProto => { - this.gestures.push(...this.toJsCategories(binaryProto)); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + 
this.graphRunner.attachProtoVectorListener( + HAND_GESTURES_STREAM, binaryProto => { + this.gestures.push(...this.toJsCategories(binaryProto)); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 263ed4b48..9a0823f23 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -208,8 +208,9 @@ export class HandLandmarker extends VisionTaskRunner { this.worldLandmarks = []; this.handednesses = []; - this.addGpuBufferAsImageToStream(imageSource, IMAGE_STREAM, timestamp); - this.addProtoToStream( + this.graphRunner.addGpuBufferAsImageToStream( + imageSource, IMAGE_STREAM, timestamp); + this.graphRunner.addProtoToStream( FULL_IMAGE_RECT.serializeBinary(), 'mediapipe.NormalizedRect', NORM_RECT_STREAM, timestamp); this.finishProcessing(); @@ -312,15 +313,18 @@ export class HandLandmarker extends VisionTaskRunner { graphConfig.addNode(landmarkerNode); - this.attachProtoVectorListener(LANDMARKS_STREAM, binaryProto => { - this.addJsLandmarks(binaryProto); - }); - this.attachProtoVectorListener(WORLD_LANDMARKS_STREAM, binaryProto => { - this.adddJsWorldLandmarks(binaryProto); - }); - this.attachProtoVectorListener(HANDEDNESS_STREAM, binaryProto => { - this.handednesses.push(...this.toJsCategories(binaryProto)); - }); + this.graphRunner.attachProtoVectorListener( + LANDMARKS_STREAM, binaryProto => { + this.addJsLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + WORLD_LANDMARKS_STREAM, binaryProto => { + this.adddJsWorldLandmarks(binaryProto); + }); + this.graphRunner.attachProtoVectorListener( + HANDEDNESS_STREAM, binaryProto => { + this.handednesses.push(...this.toJsCategories(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index 90dbf9798..40e8b5099 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -155,7 +155,7 @@ export class ImageClassifier extends VisionTaskRunner { ImageClassifierResult { // Get classification result by running our MediaPipe graph. this.classificationResult = {classifications: []}; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? 
performance.now()); this.finishProcessing(); return this.classificationResult; @@ -181,10 +181,11 @@ export class ImageClassifier extends VisionTaskRunner { graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { - this.classificationResult = convertFromClassificationResultProto( - ClassificationResult.deserializeBinary(binaryProto)); - }); + this.graphRunner.attachProtoListener( + CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts index 559332650..f8b0204ee 100644 --- a/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts +++ b/mediapipe/tasks/web/vision/image_embedder/image_embedder.ts @@ -169,7 +169,7 @@ export class ImageEmbedder extends VisionTaskRunner { protected process(image: ImageSource, timestamp: number): ImageEmbedderResult { // Get embeddings by running our MediaPipe graph. - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( image, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return this.embeddings; @@ -201,7 +201,7 @@ export class ImageEmbedder extends VisionTaskRunner { graphConfig.addNode(embedderNode); - this.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { + this.graphRunner.attachProtoListener(EMBEDDINGS_STREAM, binaryProto => { this.addJsImageEmdedding(binaryProto); }); diff --git a/mediapipe/tasks/web/vision/object_detector/object_detector.ts b/mediapipe/tasks/web/vision/object_detector/object_detector.ts index 03171003f..e2cfe0575 100644 --- a/mediapipe/tasks/web/vision/object_detector/object_detector.ts +++ b/mediapipe/tasks/web/vision/object_detector/object_detector.ts @@ -185,7 +185,7 @@ export class ObjectDetector extends VisionTaskRunner { Detection[] { // Get detections by running our MediaPipe graph. this.detections = []; - this.addGpuBufferAsImageToStream( + this.graphRunner.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); return [...this.detections]; @@ -242,9 +242,10 @@ export class ObjectDetector extends VisionTaskRunner { graphConfig.addNode(detectorNode); - this.attachProtoVectorListener(DETECTIONS_STREAM, binaryProto => { - this.addJsObjectDetections(binaryProto); - }); + this.graphRunner.attachProtoVectorListener( + DETECTIONS_STREAM, binaryProto => { + this.addJsObjectDetections(binaryProto); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts index e886999cb..7a4ea09e2 100644 --- a/mediapipe/web/graph_runner/graph_runner_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -22,7 +22,7 @@ export declare interface WasmImageModule { * An implementation of GraphRunner that supports binding GPU image data as * `mediapipe::Image` instances. We implement as a proper TS mixin, to allow for * effective multiple inheritance. 
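 * Since each mixin is a plain class-to-class function, mixins compose; the
 * task runner layer chains them as
 * `SupportModelResourcesGraphService(SupportImage(GraphRunner))`.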
Example usage: - * `const WasmMediaPipeImageLib = SupportImage(GraphRunner);` + * `const GraphRunnerImageLib = SupportImage(GraphRunner);` */ // tslint:disable-next-line:enforce-name-casing export function SupportImage(Base: TBase) { diff --git a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts index bc9c93e8a..9f2791d80 100644 --- a/mediapipe/web/graph_runner/register_model_resources_graph_service.ts +++ b/mediapipe/web/graph_runner/register_model_resources_graph_service.ts @@ -20,8 +20,8 @@ export declare interface WasmModuleRegisterModelResources { * An implementation of GraphRunner that supports registering model * resources to a cache, in the form of a GraphService C++-side. We implement as * a proper TS mixin, to allow for effective multiple inheritance. Sample usage: - * `const WasmMediaPipeImageLib = SupportModelResourcesGraphService( - * GraphRunner);` + * `const GraphRunnerWithModelResourcesLib = + * SupportModelResourcesGraphService(GraphRunner);` */ // tslint:disable:enforce-name-casing export function SupportModelResourcesGraphService(