OpenSource MediaPipe Tasks Java

PiperOrigin-RevId: 477747787
This commit is contained in:
Sebastian Schmidt 2022-09-29 09:41:10 -07:00 committed by Copybara-Service
parent a8ca669f05
commit 227cc20bff
29 changed files with 2515 additions and 0 deletions

View File

@ -34,6 +34,7 @@ android_library(
android_library(
name = "android_framework_no_mff",
proguard_specs = [":proguard.pgcfg"],
visibility = ["//visibility:public"],
exports = [
":android_framework_no_proguard",
],

View File

@ -11,3 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "category",
srcs = ["Category.java"],
deps = [
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "detection",
srcs = ["Detection.java"],
deps = [
":category",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,86 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Objects;
/**
* Category is a utility class that contains a category name, its display name, a float score, and
* the index of the label in the corresponding label file. It is typically used as the result of
* classification or detection tasks.
*/
@AutoValue
public abstract class Category {
private static final float TOLERANCE = 1e-6f;
/**
* Creates a {@link Category} instance.
*
* @param score the probability score of this label category.
* @param index the index of the label in the corresponding label file.
* @param categoryName the label of this category object.
* @param displayName the display name of the label.
*/
public static Category create(float score, int index, String categoryName, String displayName) {
return new AutoValue_Category(score, index, categoryName, displayName);
}
/** The probability score of this label category. */
public abstract float score();
/** The index of the label in the corresponding label file. Returns -1 if the index is not set. */
public abstract int index();
/** The label of this category object. */
public abstract String categoryName();
/**
* The display name of the label, which may be translated for different locales. For example, a
* label, "apple", may be translated into Spanish for display purposes, so that the display name is
* "manzana".
*/
public abstract String displayName();
@Override
public final boolean equals(Object o) {
if (!(o instanceof Category)) {
return false;
}
Category other = (Category) o;
return Math.abs(other.score() - this.score()) < TOLERANCE
&& other.index() == this.index()
&& other.categoryName().equals(this.categoryName())
&& other.displayName().equals(this.displayName());
}
@Override
public final int hashCode() {
return Objects.hash(categoryName(), displayName(), score(), index());
}
@Override
public final String toString() {
return "<Category \""
+ categoryName()
+ "\" (displayName="
+ displayName()
+ " score="
+ score()
+ " index="
+ index()
+ ")>";
}
}
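
A quick usage sketch (the values are hypothetical; in practice Category instances are produced by the task APIs rather than built by hand):

Category apple = Category.create(/*score=*/ 0.95f, /*index=*/ 12, "apple", "manzana");
Category other = Category.create(/*score=*/ 0.9500001f, /*index=*/ 12, "apple", "manzana");
// equals() compares scores with the 1e-6 tolerance above, so these two are equal.
boolean sameCategory = apple.equals(other); // true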

View File

@ -0,0 +1,50 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents one detected object in the results of {@link
* com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector}.
*/
@AutoValue
public abstract class Detection {
/**
* Creates a {@link Detection} instance from a list of {@link Category} and a bounding box.
*
* @param categories a list of {@link Category} objects that contain category name, display name,
* score, and the label index.
* @param boundingBox a {@link RectF} object to represent the bounding box.
*/
public static Detection create(List<Category> categories, RectF boundingBox) {
// As an open source project, we try to avoid depending on common Java libraries,
// such as Guava, because doing so may introduce conflicts with clients who also happen to use
// those libraries. Therefore, instead of using ImmutableList here, we wrap the List in an
// unmodifiable list.
return new AutoValue_Detection(Collections.unmodifiableList(categories), boundingBox);
}
/** A list of {@link Category} objects. */
public abstract List<Category> categories();
/** A {@link RectF} object to represent the bounding box of the detected object. */
public abstract RectF boundingBox();
}
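
A minimal construction sketch with illustrative values:

List<Category> categories =
    Collections.singletonList(Category.create(0.8f, 0, "cat", ""));
RectF boundingBox = new RectF(/*left=*/ 10f, /*top=*/ 20f, /*right=*/ 110f, /*bottom=*/ 220f);
Detection detection = Detection.create(categories, boundingBox);
// categories() is unmodifiable; attempting to mutate it throws
// UnsupportedOperationException.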

View File

@ -11,3 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
"//third_party:autovalue",
"//third_party/java/protobuf:protobuf_lite",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,96 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import com.google.auto.value.AutoValue;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.Optional;
/** Options to configure MediaPipe Tasks in general. */
@AutoValue
public abstract class BaseOptions {
/** Builder for {@link BaseOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets the model path to a tflite model with metadata in the assets.
*
* <p>Note: when the model path is set, the model file descriptor and model buffer should be empty.
*/
public abstract Builder setModelAssetPath(String value);
/**
* Sets the native file descriptor of a tflite model with metadata.
*
* <p>Note: when the model file descriptor is set, the model path and model buffer should be empty.
*/
public abstract Builder setModelAssetFileDescriptor(Integer value);
/**
* Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model
* with metadata.
*
* <p>Note: when the model buffer is set, the model path and model file descriptor should be empty.
*/
public abstract Builder setModelAssetBuffer(ByteBuffer value);
/**
* Sets the device {@link Delegate} to run the MediaPipe pipeline. If the delegate is not set,
* the default delegate, CPU, is used.
*/
public abstract Builder setDelegate(Delegate delegate);
abstract BaseOptions autoBuild();
/**
* Validates and builds the {@link BaseOptions} instance.
*
* @throws IllegalArgumentException if {@link BaseOptions} is invalid, or the provided model
* buffer is not a direct {@link ByteBuffer} or a {@link MappedByteBuffer}.
*/
public final BaseOptions build() {
BaseOptions options = autoBuild();
int modelAssetPathPresent = options.modelAssetPath().isPresent() ? 1 : 0;
int modelAssetFileDescriptorPresent = options.modelAssetFileDescriptor().isPresent() ? 1 : 0;
int modelAssetBufferPresent = options.modelAssetBuffer().isPresent() ? 1 : 0;
if (modelAssetPathPresent + modelAssetFileDescriptorPresent + modelAssetBufferPresent != 1) {
throw new IllegalArgumentException(
"Please specify only one of the model asset path, the model asset file descriptor, and"
+ " the model asset buffer.");
}
if (options.modelAssetBuffer().isPresent()
&& !(options.modelAssetBuffer().get().isDirect()
|| options.modelAssetBuffer().get() instanceof MappedByteBuffer)) {
throw new IllegalArgumentException(
"The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
}
return options;
}
}
abstract Optional<String> modelAssetPath();
abstract Optional<Integer> modelAssetFileDescriptor();
abstract Optional<ByteBuffer> modelAssetBuffer();
abstract Delegate delegate();
public static Builder builder() {
return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU);
}
}
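
A minimal builder sketch (the asset path is hypothetical). Exactly one of the three model sources may be set, otherwise build() throws IllegalArgumentException:

BaseOptions options =
    BaseOptions.builder()
        .setModelAssetPath("model.tflite")
        .setDelegate(Delegate.GPU) // optional; defaults to Delegate.CPU
        .build();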

View File

@ -0,0 +1,22 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
/** MediaPipe Tasks delegate. */
// TODO: implement advanced delegate setting.
public enum Delegate {
CPU,
GPU,
}

View File

@ -0,0 +1,20 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
/** Interface for the customizable MediaPipe task error listener. */
public interface ErrorListener {
void onError(RuntimeException e);
}

View File

@ -0,0 +1,49 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import java.util.concurrent.atomic.AtomicBoolean;
/** Facilitates creation and destruction of the native ModelResourcesCache. */
class ModelResourcesCache {
private final long nativeHandle;
private final AtomicBoolean isHandleValid;
public ModelResourcesCache() {
nativeHandle = nativeCreateModelResourcesCache();
isHandleValid = new AtomicBoolean(true);
}
public boolean isHandleValid() {
return isHandleValid.get();
}
public long getNativeHandle() {
if (isHandleValid.get()) {
return nativeHandle;
}
return 0;
}
public void release() {
if (isHandleValid.compareAndSet(true, false)) {
nativeReleaseModelResourcesCache(nativeHandle);
}
}
private native long nativeCreateModelResourcesCache();
private native void nativeReleaseModelResourcesCache(long nativeHandle);
}
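
A lifecycle sketch of the native handle managed by this class:

ModelResourcesCache cache = new ModelResourcesCache();
long handle = cache.getNativeHandle(); // non-zero while the handle is valid
cache.release(); // frees the native object; the AtomicBoolean guards against double-release
// After release(), isHandleValid() returns false and getNativeHandle() returns 0.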

View File

@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import com.google.mediapipe.framework.GraphService;
/** Java wrapper for the ModelResourcesCacheService graph service. */
class ModelResourcesCacheService implements GraphService<ModelResourcesCache> {
public ModelResourcesCacheService() {}
@Override
public void installServiceObject(long context, ModelResourcesCache object) {
nativeInstallServiceObject(context, object.getNativeHandle());
}
public native void nativeInstallServiceObject(long context, long object);
}

View File

@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import android.util.Log;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import java.util.List;
/** Base class for handling MediaPipe task graph outputs. */
public class OutputHandler<OutputT extends TaskResult, InputT> {
/**
* Interface for converting MediaPipe graph output {@link Packet}s to a task result object and a
* task input object.
*/
public interface OutputPacketConverter<OutputT extends TaskResult, InputT> {
OutputT convertToTaskResult(List<Packet> packets);
InputT convertToTaskInput(List<Packet> packets);
}
/** Interface for the customizable MediaPipe task result listener. */
public interface ResultListener<OutputT extends TaskResult, InputT> {
void run(OutputT result, InputT input);
}
private static final String TAG = "OutputHandler";
// A task-specific graph output packet converter that should be implemented per task.
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
// The user-defined task result listener.
private ResultListener<OutputT, InputT> resultListener;
// The user-defined error listener.
protected ErrorListener errorListener;
// The cached task result for non-latency-sensitive use cases.
protected OutputT cachedTaskResult;
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
private boolean handleTimestampBoundChanges = false;
/**
* Sets a callback to be invoked to convert a {@link Packet} list to a task result object and a
* task input object.
*
* @param converter the task-specific {@link OutputPacketConverter} callback.
*/
public void setOutputPacketConverter(OutputPacketConverter<OutputT, InputT> converter) {
this.outputPacketConverter = converter;
}
/**
* Sets a callback to be invoked when task result objects become available.
*
* @param listener the user-defined {@link ResultListener} callback.
*/
public void setResultListener(ResultListener<OutputT, InputT> listener) {
this.resultListener = listener;
}
/**
* Sets a callback to be invoked when exceptions are thrown from the task graph.
*
* @param listener The user-defined {@link ErrorListener} callback.
*/
public void setErrorListener(ErrorListener listener) {
this.errorListener = listener;
}
/**
* Sets whether the output handler should react to the timestamp bound changes that are represented
* as empty output {@link Packet}s.
*
* @param handleTimestampBoundChanges A boolean value.
*/
public void setHandleTimestampBoundChanges(boolean handleTimestampBoundChanges) {
this.handleTimestampBoundChanges = handleTimestampBoundChanges;
}
/** Returns true if the task graph is set to handle timestamp bound changes. */
boolean handleTimestampBoundChanges() {
return handleTimestampBoundChanges;
}
/* Returns the cached task result object. */
public OutputT retrieveCachedTaskResult() {
OutputT taskResult = cachedTaskResult;
cachedTaskResult = null;
return taskResult;
}
/**
* Handles a list of output {@link Packet}s. Invoked when a packet list becomes available.
*
* @param packets A list of output {@link Packet}s.
*/
void run(List<Packet> packets) {
OutputT taskResult = null;
try {
taskResult = outputPacketConverter.convertToTaskResult(packets);
if (resultListener == null) {
cachedTaskResult = taskResult;
} else {
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
resultListener.run(taskResult, taskInput);
}
} catch (MediaPipeException e) {
if (errorListener != null) {
errorListener.onError(e);
} else {
Log.e(TAG, "Error occurred while getting the MediaPipe task result. " + e);
}
} finally {
for (Packet packet : packets) {
if (packet != null) {
packet.release();
}
}
}
}
}
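
A hedged sketch of wiring an OutputHandler; MyResult (a TaskResult implementation) and MyInput are hypothetical task-specific types:

OutputHandler<MyResult, MyInput> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
    new OutputHandler.OutputPacketConverter<MyResult, MyInput>() {
      @Override
      public MyResult convertToTaskResult(List<Packet> packets) {
        // Decode the task-specific output packets here (hypothetical factory).
        return MyResult.create(packets.get(0));
      }

      @Override
      public MyInput convertToTaskInput(List<Packet> packets) {
        // Recover the task input that produced these outputs (hypothetical factory).
        return MyInput.fromPacket(packets.get(1));
      }
    });
// With no result listener set, results are cached and fetched synchronously via
// retrieveCachedTaskResult(); with a listener, they are delivered asynchronously.
handler.setResultListener((result, input) -> { /* consume result */ });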

View File

@ -0,0 +1,156 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig.Node;
import com.google.mediapipe.proto.CalculatorProto.InputStreamInfo;
import com.google.mediapipe.calculator.proto.FlowLimiterCalculatorProto.FlowLimiterCalculatorOptions;
import java.util.ArrayList;
import java.util.List;
/**
* {@link TaskInfo} contains all needed information to initialize a MediaPipe Task {@link
* com.google.mediapipe.framework.Graph}.
*/
@AutoValue
public abstract class TaskInfo<T extends TaskOptions> {
/** Builder for {@link TaskInfo}. */
@AutoValue.Builder
public abstract static class Builder<T extends TaskOptions> {
/** Sets the MediaPipe task graph name. */
public abstract Builder<T> setTaskGraphName(String value);
/** Sets a list of task graph input stream info {@link String}s in the form TAG:name. */
public abstract Builder<T> setInputStreams(List<String> value);
/** Sets a list of task graph output stream info {@link String}s in the form TAG:name. */
public abstract Builder<T> setOutputStreams(List<String> value);
/** Sets to true if the task requires a flow limiter. */
public abstract Builder<T> setEnableFlowLimiting(Boolean value);
/**
* Sets a task-specific options instance.
*
* @param value a task-specific options that is derived from {@link TaskOptions}.
*/
public abstract Builder<T> setTaskOptions(T value);
public abstract TaskInfo<T> autoBuild();
/**
* Validates and builds the {@link TaskInfo} instance.
*
* @throws IllegalArgumentException if the required information such as task graph name, graph
* input streams, and the graph output streams are empty.
*/
public final TaskInfo<T> build() {
TaskInfo<T> taskInfo = autoBuild();
if (taskInfo.taskGraphName().isEmpty()
|| taskInfo.inputStreams().isEmpty()
|| taskInfo.outputStreams().isEmpty()) {
throw new IllegalArgumentException(
"Task graph's name, input streams, and output streams should be non-empty.");
}
return taskInfo;
}
}
abstract String taskGraphName();
abstract T taskOptions();
abstract List<String> inputStreams();
abstract List<String> outputStreams();
abstract Boolean enableFlowLimiting();
public static <T extends TaskOptions> Builder<T> builder() {
return new AutoValue_TaskInfo.Builder<T>();
}
/* Returns a list of the output stream names without the stream tags. */
List<String> outputStreamNames() {
List<String> streamNames = new ArrayList<>(outputStreams().size());
for (String stream : outputStreams()) {
streamNames.add(stream.substring(stream.lastIndexOf(':') + 1));
}
return streamNames;
}
/**
* Creates a MediaPipe Task {@link CalculatorGraphConfig} protobuf message from the {@link
* TaskInfo} instance.
*/
CalculatorGraphConfig generateGraphConfig() {
CalculatorGraphConfig.Builder graphBuilder = CalculatorGraphConfig.newBuilder();
Node.Builder taskSubgraphBuilder =
Node.newBuilder()
.setCalculator(taskGraphName())
.setOptions(taskOptions().convertToCalculatorOptionsProto());
for (String outputStream : outputStreams()) {
taskSubgraphBuilder.addOutputStream(outputStream);
graphBuilder.addOutputStream(outputStream);
}
if (!enableFlowLimiting()) {
for (String inputStream : inputStreams()) {
taskSubgraphBuilder.addInputStream(inputStream);
graphBuilder.addInputStream(inputStream);
}
graphBuilder.addNode(taskSubgraphBuilder.build());
return graphBuilder.build();
}
Node.Builder flowLimiterCalculatorBuilder =
Node.newBuilder()
.setCalculator("FlowLimiterCalculator")
.addInputStreamInfo(
InputStreamInfo.newBuilder().setTagIndex("FINISHED").setBackEdge(true).build())
.setOptions(
CalculatorOptions.newBuilder()
.setExtension(
FlowLimiterCalculatorOptions.ext,
FlowLimiterCalculatorOptions.newBuilder()
.setMaxInFlight(1)
.setMaxInQueue(1)
.build())
.build());
for (String inputStream : inputStreams()) {
graphBuilder.addInputStream(inputStream);
flowLimiterCalculatorBuilder.addInputStream(stripTagIndex(inputStream));
String taskInputStream = addStreamNamePrefix(inputStream);
flowLimiterCalculatorBuilder.addOutputStream(stripTagIndex(taskInputStream));
taskSubgraphBuilder.addInputStream(taskInputStream);
}
flowLimiterCalculatorBuilder.addInputStream(
"FINISHED:" + stripTagIndex(outputStreams().get(0)));
graphBuilder.addNode(flowLimiterCalculatorBuilder.build());
graphBuilder.addNode(taskSubgraphBuilder.build());
return graphBuilder.build();
}
private String stripTagIndex(String tagIndexName) {
return tagIndexName.substring(tagIndexName.lastIndexOf(':') + 1);
}
private String addStreamNamePrefix(String tagIndexName) {
return tagIndexName.substring(0, tagIndexName.lastIndexOf(':') + 1)
+ "throttled_"
+ stripTagIndex(tagIndexName);
}
}
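
A sketch of assembling a TaskInfo; the graph and stream names are illustrative and MyTaskOptions is a hypothetical TaskOptions subclass. With flow limiting enabled, generateGraphConfig() inserts a FlowLimiterCalculator and renames each task input stream, e.g. "IMAGE:image_in" becomes "IMAGE:throttled_image_in":

TaskInfo<MyTaskOptions> taskInfo =
    TaskInfo.<MyTaskOptions>builder()
        .setTaskGraphName("mediapipe.tasks.MyTaskGraph")
        .setInputStreams(Collections.singletonList("IMAGE:image_in"))
        .setOutputStreams(Collections.singletonList("RESULT:result_out"))
        .setTaskOptions(myTaskOptions)
        .setEnableFlowLimiting(true)
        .build();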

View File

@ -0,0 +1,75 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.tasks.core;
import com.google.mediapipe.calculator.proto.InferenceCalculatorProto;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.tasks.core.proto.AccelerationProto;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.core.proto.ExternalFileProto;
import com.google.protobuf.ByteString;
/**
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
* {@link TaskOptions}.
*/
public abstract class TaskOptions {
/**
* Converts a MediaPipe Tasks task-specific options to a {@link CalculatorOptions} protobuf
* message.
*/
public abstract CalculatorOptions convertToCalculatorOptionsProto();
/**
* Converts a {@link BaseOptions} instance to a {@link BaseOptionsProto.BaseOptions} protobuf
* message.
*/
protected BaseOptionsProto.BaseOptions convertBaseOptionsToProto(BaseOptions options) {
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
ExternalFileProto.ExternalFile.newBuilder();
options.modelAssetPath().ifPresent(externalFileBuilder::setFileName);
options
.modelAssetFileDescriptor()
.ifPresent(
fd ->
externalFileBuilder.setFileDescriptorMeta(
ExternalFileProto.FileDescriptorMeta.newBuilder().setFd(fd).build()));
options
.modelAssetBuffer()
.ifPresent(
modelBuffer -> {
modelBuffer.rewind();
externalFileBuilder.setFileContent(ByteString.copyFrom(modelBuffer));
});
AccelerationProto.Acceleration.Builder accelerationBuilder =
AccelerationProto.Acceleration.newBuilder();
switch (options.delegate()) {
case CPU:
accelerationBuilder.setXnnpack(
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
.getDefaultInstance());
break;
case GPU:
accelerationBuilder.setGpu(
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.getDefaultInstance());
break;
}
return BaseOptionsProto.BaseOptions.newBuilder()
.setModelAsset(externalFileBuilder.build())
.setAcceleration(accelerationBuilder.build())
.build();
}
}
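
A hedged sketch of a task-specific options class; MyTaskOptionsProto and its CalculatorOptions extension are hypothetical stand-ins for a real task's generated protos:

public final class MyTaskOptions extends TaskOptions {
  private final BaseOptions baseOptions;

  MyTaskOptions(BaseOptions baseOptions) {
    this.baseOptions = baseOptions;
  }

  @Override
  public CalculatorOptions convertToCalculatorOptionsProto() {
    return CalculatorOptions.newBuilder()
        .setExtension(
            MyTaskOptionsProto.MyTaskOptions.ext,
            MyTaskOptionsProto.MyTaskOptions.newBuilder()
                // Reuse the shared conversion for the embedded base options.
                .setBaseOptions(convertBaseOptionsToProto(baseOptions))
                .build())
        .build();
  }
}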

View File

@ -0,0 +1,24 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
/**
* Interface for the MediaPipe Task result. Any MediaPipe task-specific result class should
* implement {@link TaskResult}.
*/
public interface TaskResult {
/** Returns the timestamp that is associated with the task result object. */
long timestampMs();
}

View File

@ -0,0 +1,265 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import android.content.Context;
import android.util.Log;
import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
/** The runner of MediaPipe task graphs. */
public class TaskRunner implements AutoCloseable {
private static final String TAG = TaskRunner.class.getSimpleName();
private static final long TIMESTAMP_UNITS_PER_SECOND = 1000000;
private final OutputHandler<? extends TaskResult, ?> outputHandler;
private final AtomicBoolean graphStarted = new AtomicBoolean(false);
private final Graph graph;
private final ModelResourcesCache modelResourcesCache;
private final AndroidPacketCreator packetCreator;
private long lastSeenTimestamp = Long.MIN_VALUE;
private ErrorListener errorListener;
/**
* Creates a {@link TaskRunner} instance.
*
* @param context an Android {@link Context}.
* @param taskInfo a {@link TaskInfo} instance that contains the task graph name, task options,
*     and graph input and output stream names.
* @param outputHandler an {@link OutputHandler} instance that handles task result objects and
*     runtime exceptions.
* @throws MediaPipeException for any error during {@link TaskRunner} creation.
*/
public static TaskRunner create(
Context context,
TaskInfo<? extends TaskOptions> taskInfo,
OutputHandler<? extends TaskResult, ?> outputHandler) {
AndroidAssetUtil.initializeNativeAssetManager(context);
Graph mediapipeGraph = new Graph();
mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
ModelResourcesCache graphModelResourcesCache = new ModelResourcesCache();
mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
mediapipeGraph.addMultiStreamCallback(
taskInfo.outputStreamNames(),
outputHandler::run,
/*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges());
mediapipeGraph.startRunningGraph();
// Waits until all calculators are opened and the graph is fully started.
mediapipeGraph.waitUntilGraphIdle();
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
}
/**
* Sets a callback to be invoked when exceptions are thrown by the {@link TaskRunner} instance.
*
* @param listener an {@link ErrorListener} callback.
*/
public void setErrorListener(ErrorListener listener) {
this.errorListener = listener;
}
/** Returns the {@link AndroidPacketCreator} associated with the {@link TaskRunner} instance. */
public AndroidPacketCreator getPacketCreator() {
return packetCreator;
}
/**
* A synchronous method for processing batch data.
*
* <p>Note: This method is designed for processing batch data such as unrelated images and texts.
* The call blocks the current thread until a failure status or a successful result is returned.
* An internal timestamp will be assigned per invocation. This method is thread-safe and allows
* clients to call it from different threads.
*
* @param inputs a map that contains (input stream {@link String}, data {@link Packet}) pairs.
*/
public synchronized TaskResult process(Map<String, Packet> inputs) {
addPackets(inputs, generateSyntheticTimestamp());
graph.waitUntilGraphIdle();
return outputHandler.retrieveCachedTaskResult();
}
/**
* A synchronous method for processing offline streaming data.
*
* <p>Note: This method is designed for processing offline streaming data, such as decoded
* frames from a video file or an audio file. The call blocks the current thread until a failure
* status or a successful result is returned. The caller must ensure that the input timestamp is
* greater than the timestamps of previous invocations. This method is not thread-safe; it is the
* caller's responsibility to synchronize access to this method across multiple threads and to
* ensure that the input packet timestamps are in order.
*
* @param inputs a map that contains (input stream {@link String}, data {@link Packet}) pairs.
* @param inputTimestamp the timestamp of the input packets.
*/
public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimestamp(inputTimestamp);
addPackets(inputs, inputTimestamp);
graph.waitUntilGraphIdle();
return outputHandler.retrieveCachedTaskResult();
}
/**
* An asynchronous method for handling live streaming data.
*
* <p>Note: This method is designed for handling live streaming data, such as live camera and
* microphone data. A user-defined packet callback function must be provided in the constructor
* to receive the output packets. The caller must ensure that the input packet timestamps are
* monotonically increasing. This method is not thread-safe; it is the caller's responsibility to
* synchronize access to this method across multiple threads and to ensure that the input packet
* timestamps are in order.
*
* @param inputs a map that contains (input stream {@link String}, data {@link Packet}) pairs.
* @param inputTimestamp the timestamp of the input packets.
*/
public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimestamp(inputTimestamp);
addPackets(inputs, inputTimestamp);
}
/**
* Resets and restarts the {@link TaskRunner} instance. This can be useful for resetting a
* stateful task graph to process new data.
*/
public void restart() {
if (graphStarted.get()) {
try {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
} catch (MediaPipeException e) {
reportError(e);
}
}
try {
graph.startRunningGraph();
// Waits until all calculators are opened and the graph is fully restarted.
graph.waitUntilGraphIdle();
graphStarted.set(true);
} catch (MediaPipeException e) {
reportError(e);
}
}
/** Closes and cleans up the {@link TaskRunner} instance. */
@Override
public void close() {
if (!graphStarted.get()) {
return;
}
try {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
if (modelResourcesCache != null) {
modelResourcesCache.release();
}
} catch (MediaPipeException e) {
// Note: errors during Process are reported at the earliest opportunity,
// which may be addPacket or waitUntilDone, depending on timing. For consistency,
// we want to always report them using the same async handler if installed.
reportError(e);
}
try {
graph.tearDown();
} catch (MediaPipeException e) {
reportError(e);
}
}
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
if (!graphStarted.get()) {
reportError(
new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"The task graph hasn't been successfully started or an error occurred during graph"
+ " initialization."));
}
try {
for (Map.Entry<String, Packet> entry : inputs.entrySet()) {
// addConsumablePacketToInputStream allows the graph to take exclusive ownership of the
// packet, which may allow for more memory optimizations.
graph.addConsumablePacketToInputStream(entry.getKey(), entry.getValue(), inputTimestamp);
// If addConsumablePacket succeeded, we don't need to release the packet ourselves.
entry.setValue(null);
}
} catch (MediaPipeException e) {
// TODO: do not suppress exceptions here!
if (errorListener == null) {
Log.e(TAG, "Mediapipe error: ", e);
} else {
throw e;
}
} finally {
for (Packet packet : inputs.values()) {
// In case of error, addConsumablePacketToInputStream will not release the packet, so we
// have to release it ourselves.
if (packet != null) {
packet.release();
}
}
}
}
/**
* Checks if the input timestamp is strictly greater than the last timestamp that has been
* processed.
*
* @param inputTimestamp the input timestamp.
*/
private void validateInputTimestamp(long inputTimestamp) {
if (lastSeenTimestamp >= inputTimestamp) {
reportError(
new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"The received packets have a timestamp that is not greater than the previously processed timestamp."));
}
lastSeenTimestamp = inputTimestamp;
}
/** Generates a synthetic input timestamp in the batch processing mode. */
private long generateSyntheticTimestamp() {
long timestamp =
lastSeenTimestamp == Long.MIN_VALUE ? 0 : lastSeenTimestamp + TIMESTAMP_UNITS_PER_SECOND;
lastSeenTimestamp = timestamp;
return timestamp;
}
/** Private constructor. */
private TaskRunner(
Graph graph,
ModelResourcesCache modelResourcesCache,
OutputHandler<? extends TaskResult, ?> outputHandler) {
this.outputHandler = outputHandler;
this.graph = graph;
this.modelResourcesCache = modelResourcesCache;
this.packetCreator = new AndroidPacketCreator(graph);
graphStarted.set(true);
}
/** Reports error. */
private void reportError(MediaPipeException e) {
if (errorListener != null) {
errorListener.onError(e);
} else {
throw e;
}
}
}
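
A minimal batch-mode sketch, reusing the hypothetical taskInfo and handler from the sketches above; the stream name is illustrative:

TaskRunner runner = TaskRunner.create(context, taskInfo, handler);
Map<String, Packet> inputs = new HashMap<>();
inputs.put("image_in", runner.getPacketCreator().createImage(image));
// process() assigns a synthetic timestamp and blocks until the graph is idle.
TaskResult result = runner.process(inputs);
runner.close();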

View File

@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library_with_tflite(
name = "model_resources_cache_jni",
srcs = [
"model_resources_cache_jni.cc",
],
hdrs = [
"model_resources_cache_jni.h",
],
tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources_cache",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
] + select({
"//conditions:default": ["//third_party/java/jdk:jni"],
"//mediapipe:android": [],
}),
alwayslink = 1,
)

View File

@ -0,0 +1,72 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
package(default_visibility = ["//mediapipe/tasks:internal"])
cc_library_with_tflite(
name = "model_resources_cache_jni",
srcs = [
"model_resources_cache_jni.cc",
],
hdrs = [
"model_resources_cache_jni.h",
] + select({
# The Android toolchain makes "jni.h" available in the include path.
# For non-Android toolchains, generate jni.h and jni_md.h.
"//mediapipe:android": [],
"//conditions:default": [
":jni.h",
":jni_md.h",
],
}),
tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources_cache",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
] + select({
"//conditions:default": [],
"//mediapipe:android": [],
}),
alwayslink = 1,
)
# Silly rules to make
# #include <jni.h>
# in the source headers work
# (in combination with the "includes" attribute of the library rule above).
# Not needed when using the Android toolchain.
#
# Inspired from:
# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
# but hopefully there is a simpler alternative to this.
genrule(
name = "copy_jni_h",
srcs = ["@bazel_tools//tools/jdk:jni_header"],
outs = ["jni.h"],
cmd = "cp -f $< $@",
)
genrule(
name = "copy_jni_md_h",
srcs = select({
"//mediapipe:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
"//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
}),
outs = ["jni_md.h"],
cmd = "cp -f $< $@",
)

View File

@ -0,0 +1,49 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h"
#include <utility>
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
namespace {
using ::mediapipe::tasks::core::kModelResourcesCacheService;
using ::mediapipe::tasks::core::ModelResourcesCache;
using HandleType = std::shared_ptr<ModelResourcesCache>*;
} // namespace
JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) {
auto ptr = std::make_shared<ModelResourcesCache>(
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
HandleType handle = new std::shared_ptr<ModelResourcesCache>(std::move(ptr));
return reinterpret_cast<jlong>(handle);
}
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
jlong nativeHandle) {
delete reinterpret_cast<HandleType>(nativeHandle);
}
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
jlong objectHandle) {
mediapipe::android::GraphServiceHelper::SetServiceObject(
contextHandle, kModelResourcesCacheService,
*reinterpret_cast<HandleType>(objectHandle));
}

View File

@ -0,0 +1,45 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
#define JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
#define MODEL_RESOURCES_CACHE_METHOD(METHOD_NAME) \
Java_com_google_mediapipe_tasks_core_ModelResourcesCache_##METHOD_NAME
#define MODEL_RESOURCES_CACHE_SERVICE_METHOD(METHOD_NAME) \
Java_com_google_mediapipe_tasks_core_ModelResourcesCacheService_##METHOD_NAME
JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz);
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
jlong nativeHandle);
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
jlong objectHandle);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_

View File

@ -11,3 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["*.java"]),
deps = [
":libmediapipe_tasks_vision_jni_lib",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"@maven//:com_google_guava_guava",
],
)
# The native library of all MediaPipe vision tasks.
cc_binary(
name = "libmediapipe_tasks_vision_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
cc_library(
name = "libmediapipe_tasks_vision_jni_lib",
srcs = [":libmediapipe_tasks_vision_jni.so"],
alwayslink = 1,
)

View File

@ -0,0 +1,114 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.core;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
/** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable {
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
private final TaskRunner runner;
private final RunningMode runningMode;
static {
System.loadLibrary("mediapipe_tasks_vision_jni");
}
/**
* Constructor to initialize a {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision
* task {@link RunningMode}.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) {
this.runner = runner;
this.runningMode = runningMode;
}
/**
* A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if the task is not in the image mode.
*/
protected TaskResult processImageData(String imageStreamName, Image image) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode: "
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets);
}
/**
* A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode.
*/
protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) {
if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode: "
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the live stream mode.
*/
protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode: "
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/** Closes and cleans up the MediaPipe vision task. */
@Override
public void close() {
runner.close();
}
}
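
A hedged sketch of a vision task built on this base class; MyVisionTaskResult and the stream name are hypothetical. Note that millisecond timestamps are converted to microseconds (timestampMs * 1000) before reaching the TaskRunner:

class MyVisionTask extends BaseVisionTaskApi {
  MyVisionTask(TaskRunner runner) {
    super(runner, RunningMode.VIDEO);
  }

  MyVisionTaskResult detectForVideo(Image image, long timestampMs) {
    return (MyVisionTaskResult) processVideoData("image_in", image, timestampMs);
  }
}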

View File

@ -0,0 +1,32 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.core;
/**
* MediaPipe vision task running mode. A MediaPipe vision task can be run with three different
* modes:
*
* <ul>
* <li>IMAGE: The mode for running a mediapipe vision task on single image inputs.
* <li>VIDEO: The mode for running a mediapipe vision task on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for running a mediapipe vision task on a live stream of input data,
*     such as from a camera.
* </ul>
*/
public enum RunningMode {
IMAGE,
VIDEO,
LIVE_STREAM
}

View File

@ -11,3 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "objectdetector",
srcs = [
"ObjectDetectionResult.java",
"ObjectDetector.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = ":AndroidManifest.xml",
deps = [
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,75 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue
public abstract class ObjectDetectionResult implements TaskResult {
private static final int DEFAULT_CATEGORY_INDEX = -1;
@Override
public abstract long timestampMs();
public abstract List<com.google.mediapipe.tasks.components.containers.Detection> detections();
/**
* Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
* messages.
*
* @param detectionList a list of {@link Detection} protobuf messages.
*/
static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) {
List<Category> categories = new ArrayList<>();
for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
categories.add(
Category.create(
detectionProto.getScore(idx),
detectionProto.getLabelIdCount() > idx
? detectionProto.getLabelId(idx)
: DEFAULT_CATEGORY_INDEX,
detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
detectionProto.getDisplayNameCount() > idx
? detectionProto.getDisplayName(idx)
: ""));
}
RectF boundingBox = new RectF();
if (detectionProto.getLocationData().hasBoundingBox()) {
BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
boundingBox.set(
/*left=*/ boundingBoxProto.getXmin(),
/*top=*/ boundingBoxProto.getYmin(),
/*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
/*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
}
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.create(
categories, boundingBox));
}
return new AutoValue_ObjectDetectionResult(
timestampMs, Collections.unmodifiableList(detections));
}
}
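
An illustrative way to consume a result; the detection list is unmodifiable:

for (com.google.mediapipe.tasks.components.containers.Detection detection :
    result.detections()) {
  RectF box = detection.boundingBox();
  Category topCategory = detection.categories().get(0);
  Log.d("Example", topCategory.categoryName() + " (" + topCategory.score() + ") at " + box);
}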

View File

@ -0,0 +1,415 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs object detection on images.
*
* <p>The API expects a TFLite model with <a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
*
* <p>The API supports models with one image input tensor and four output tensors. To be more
* specific, here are the requirements.
*
* <ul>
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
* <ul>
* <li>image input of size {@code [batch x height x width x channels]}.
* <li>batch inference is not supported ({@code batch} is required to be 1).
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
* <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
* to the metadata for input normalization.
* </ul>
* <li>Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e.:
* <ul>
* <li>Location tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results x 4]}, the inner array representing
* bounding boxes in the form [top, left, right, bottom].
* <li>{@code BoundingBoxProperties} are required to be attached to the metadata and
* must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}.
* </ul>
* <li>Classes tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results]}, each value representing the integer
* index of a class.
* <li>if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS}
* associated files, they are used to convert the tensor values into labels.
* </ul>
* <li>Scores tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results]}, each value representing the score of
* the detected object.
* </ul>
* <li>Number of detections tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>integer num_results as a tensor of size {@code [1]}.
* </ul>
* </ul>
* </ul>
*
* <p>An example of such a model can be found on <a
* href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow
* Hub</a>.
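*
* <p>A minimal usage sketch (the asset path {@code "model.tflite"} and the surrounding setup
* such as {@code context} and {@code mpImage} are placeholders, not part of this API):
*
* <pre>{@code
* ObjectDetector detector = ObjectDetector.createFromFile(context, "model.tflite");
* ObjectDetectionResult result = detector.detect(mpImage);
* List<com.google.mediapipe.tasks.components.containers.Detection> detections =
*     result.detections();
* }</pre>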
*/
public final class ObjectDetector extends BaseVisionTaskApi {
private static final String TAG = ObjectDetector.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
/**
* Creates an {@link ObjectDetector} instance from a model file and the default {@link
* ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the detection model with metadata in the assets.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link ObjectDetector} instance from a model file and the default {@link
* ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the detection model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates an {@link ObjectDetector} instance from a model buffer and the default {@link
* ObjectDetectorOptions}.
*
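* <p>For example, a model can be memory-mapped from a file (a sketch; {@code modelFile} and
* {@code context} are assumed to exist in the caller's scope):
*
* <pre>{@code
* try (FileInputStream fis = new FileInputStream(modelFile)) {
*   MappedByteBuffer buf =
*       fis.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, modelFile.length());
*   ObjectDetector detector = ObjectDetector.createFromBuffer(context, buf);
* }
* }</pre>
*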
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link java.nio.MappedByteBuffer} of the
*     detection model.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param detectorOptions an {@link ObjectDetectorOptions} instance.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromOptions(
Context context, ObjectDetectorOptions detectorOptions) {
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ObjectDetectionResult, Image> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, Image>() {
@Override
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
return ObjectDetectionResult.create(
PacketGetter.getProtoVector(
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
packets.get(DETECTIONS_OUT_STREAM_INDEX).getTimestamp());
}
@Override
public Image convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
detectorOptions.resultListener().ifPresent(handler::setResultListener);
detectorOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<ObjectDetectorOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(detectorOptions)
.setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
detectorOptions.errorListener().ifPresent(runner::setErrorListener);
return new ObjectDetector(runner, detectorOptions.runningMode());
}
/**
* Constructor to initialize an {@link ObjectDetector} from a {@link TaskRunner} and a {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode);
}
/**
* Performs object detection on the provided single image. Only use this method when the {@link
* ObjectDetector} is created with {@link RunningMode#IMAGE}.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
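*
* <p>A minimal call sketch ({@code bitmap} is assumed to be an {@code ARGB_8888} Bitmap obtained
* elsewhere):
*
* <pre>{@code
* Image mpImage = new BitmapImageBuilder(bitmap).build();
* ObjectDetectionResult result = detector.detect(mpImage);
* }</pre>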
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detect(Image inputImage) {
return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage);
}
/**
* Performs object detection on the provided video frame. Only use this method when the {@link
* ObjectDetector} is created with {@link RunningMode#VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
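*
* <p>Sketch (the frame images and millisecond timestamps are assumed to come from the caller's
* video decoding loop):
*
* <pre>{@code
* for (long timestampMs : frameTimestampsMs) {
*   ObjectDetectionResult result = detector.detectForVideo(frameImage, timestampMs);
* }
* }</pre>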
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
return (ObjectDetectionResult)
processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
}
/**
* Sends live image data to perform object detection, and the results will be available via the
* {@link ResultListener} provided in the {@link ObjectDetectorOptions}. Only use this method when
* the {@link ObjectDetector} is created with {@link RunningMode#LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the object detector. The input timestamps must be monotonically increasing.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul>
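*
* <p>Sketch ({@code android.os.SystemClock.uptimeMillis()} is one way to obtain a monotonically
* increasing millisecond timestamp):
*
* <pre>{@code
* detector.detectAsync(mpImage, SystemClock.uptimeMillis());
* }</pre>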
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(Image inputImage, long inputTimestampMs) {
sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
}
/** Options for setting up an {@link ObjectDetector}. */
@AutoValue
public abstract static class ObjectDetectorOptions extends TaskOptions {
/** Builder for {@link ObjectDetectorOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the object detector task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the running mode for the object detector task. Defaults to the image mode. Object
* detector has three modes:
*
* <ul>
* <li>IMAGE: The mode for detecting objects on single image inputs.
* <li>VIDEO: The mode for detecting objects on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for detecting objects on a live stream of input data, such as
*     from a camera. In this mode, {@code setResultListener} must be called to set up a
*     listener to receive the detection results asynchronously (see the sketch below).
* </ul>
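*
* <p>A live-stream configuration sketch (the model path and the listener body are
* placeholders):
*
* <pre>{@code
* ObjectDetectorOptions options =
*     ObjectDetectorOptions.builder()
*         .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
*         .setRunningMode(RunningMode.LIVE_STREAM)
*         .setResultListener(
*             (result, inputImage) -> {
*               // Consume the detection result here.
*             })
*         .build();
* }</pre>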
*/
public abstract Builder setRunningMode(RunningMode value);
/**
* Sets the locale to use for display names specified through the TFLite Model Metadata, if
* any. Defaults to English.
*/
public abstract Builder setDisplayNamesLocale(String value);
/**
* Sets the optional maximum number of top-scored detection results to return.
*
* <p>Overrides the value provided in the model metadata, if any.
*/
public abstract Builder setMaxResults(Integer value);
/**
* Sets the optional score threshold that overrides the one provided in the model metadata (if
* any). Results below this value are rejected.
*/
public abstract Builder setScoreThreshold(Float value);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
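*
* <p>For example: {@code builder.setCategoryAllowlist(Arrays.asList("cat", "dog"))} (the
* category names are placeholders; valid names come from the model's label file).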
*/
public abstract Builder setCategoryAllowlist(List<String> value);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> value);
/**
* Sets the result listener to receive the detection results asynchronously when the object
* detector is in the live stream mode.
*/
public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);
/** Sets an optional error listener. */
public abstract Builder setErrorListener(ErrorListener value);
abstract ObjectDetectorOptions autoBuild();
/**
* Validates and builds the {@link ObjectDetectorOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the object detector is
* in the live stream mode.
*/
public final ObjectDetectorOptions build() {
ObjectDetectorOptions options = autoBuild();
if (options.runningMode() == RunningMode.LIVE_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The object detector is in the live stream mode, a user-defined result listener"
+ " must be provided in ObjectDetectorOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The object detector is in the image or the video mode, a user-defined result"
+ " listener shouldn't be provided in ObjectDetectorOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
abstract Optional<ResultListener<ObjectDetectionResult, Image>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_ObjectDetector_ObjectDetectorOptions.Builder()
.setRunningMode(RunningMode.IMAGE)
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
}
/** Converts an {@link ObjectDetectorOptions} to a {@link CalculatorOptions} protobuf message. */
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ObjectDetectorOptionsProto.ObjectDetectorOptions.Builder taskOptionsBuilder =
ObjectDetectorOptionsProto.ObjectDetectorOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
displayNamesLocale().ifPresent(taskOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(taskOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(taskOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
taskOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
taskOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
return CalculatorOptions.newBuilder()
.setExtension(
ObjectDetectorOptionsProto.ObjectDetectorOptions.ext, taskOptionsBuilder.build())
.build();
}
}
}

View File

@ -11,3 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "test_utils",
srcs = ["TestUtils.java"],
deps = [
"//third_party/java/android_libs/guava_jdk5:io",
],
)

View File

@ -0,0 +1,83 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import android.content.Context;
import android.content.res.AssetManager;
import com.google.common.io.ByteStreams;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/** Helper class for the Java tests in MediaPipe Tasks. */
public final class TestUtils {
/**
* Loads a file from the asset directory and creates a {@link File} object for it in the app's
* files directory. Simulates downloading or reading a file that's not precompiled with the app.
*
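* <p>For example: {@code File model = TestUtils.loadFile(context, "model.tflite");} (the asset
* name is a placeholder).
*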
* @return a {@link File} object for the model.
*/
public static File loadFile(Context context, String fileName) {
File target = new File(context.getFilesDir(), fileName);
try (InputStream is = context.getAssets().open(fileName);
FileOutputStream os = new FileOutputStream(target)) {
ByteStreams.copy(is, os);
} catch (IOException e) {
throw new AssertionError("Failed to load model file at " + fileName, e);
}
return target;
}
/**
* Reads a file into a direct {@link ByteBuffer} object from the asset directory.
*
* @return a {@link ByteBuffer} object for the file.
*/
public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName)
throws IOException {
AssetManager assetManager = context.getAssets();
byte[] bytes;
// Close the asset stream after reading so the file handle isn't leaked.
try (InputStream inputStream = assetManager.open(fileName)) {
bytes = ByteStreams.toByteArray(inputStream);
}
ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
buffer.put(bytes);
// Reset the position so consumers read from the start of the buffer.
buffer.rewind();
return buffer;
}
/**
* Reads a file into a non-direct {@link ByteBuffer} object from the asset directory.
*
* @return a {@link ByteBuffer} object for the file.
*/
public static ByteBuffer loadToNonDirectByteBuffer(Context context, String fileName)
throws IOException {
AssetManager assetManager = context.getAssets();
// Close the asset stream after reading so the file handle isn't leaked.
try (InputStream inputStream = assetManager.open(fileName)) {
return ByteBuffer.wrap(ByteStreams.toByteArray(inputStream));
}
}
public enum ByteBufferType {
DIRECT,
BACK_UP_ARRAY,
OTHER // A non-direct ByteBuffer without a backing array.
}
private TestUtils() {}
}

View File

@ -12,4 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@build_bazel_rules_android//android:rules.bzl", "android_library_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,456 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link ObjectDetector}. */
@RunWith(Suite.class)
@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class})
public class ObjectDetectorTest {
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
private static final int IMAGE_WIDTH = 1200;
private static final int IMAGE_HEIGHT = 600;
private static final float CAT_SCORE = 0.69f;
private static final RectF catBoundingBox = new RectF(611, 164, 986, 596);
// TODO: Figure out why android_x86 and android_arm tests have slightly different
// scores (0.6875 vs 0.69921875).
private static final float SCORE_DIFF_TOLERANCE = 0.01f;
private static final float PIXEL_DIFF_TOLERANCE = 5.0f;
@RunWith(AndroidJUnit4.class)
public static final class General extends ObjectDetectorTest {
@Test
public void detect_successWithValidModels() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void detect_successWithNoOptions() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check that the object with the highest score is a cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@Test
public void detect_succeedsWithMaxResultsOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setMaxResults(8)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Results should contain 8 detected objects because maxResults was set to 8.
assertThat(results.detections()).hasSize(8);
}
@Test
public void detect_succeedsWithScoreThresholdOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setScoreThreshold(0.68f)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// The score threshold should filter out all objects except the cat.
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void detect_succeedsWithAllowListOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("cat"))
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the allowlist, results should only contain cats, and there are 5 detected
// bounding boxes of cats in CAT_AND_DOG_IMAGE.
assertThat(results.detections()).hasSize(5);
}
@Test
public void detect_succeedsWithDenyListOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setCategoryDenylist(Arrays.asList("cat"))
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Because of the denylist, the highest-scored result is no longer a cat.
assertThat(results.detections().get(0).categories().get(0).categoryName())
.isNotEqualTo("cat");
}
@Test
public void detect_succeedsWithModelFileObject() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(
ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check that the object with the highest score is a cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@Test
public void detect_succeedsWithModelBuffer() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromBuffer(
ApplicationProvider.getApplicationContext(),
TestUtils.loadToDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// Check that the object with the highest score is a cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@Test
public void detect_succeedsWithModelBufferAndOptions() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetBuffer(
TestUtils.loadToDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE))
.build())
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void create_failsWithMissingModel() throws Exception {
String nonexistentFile = "/path/to/non/existent/file";
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
ObjectDetector.createFromFile(
ApplicationProvider.getApplicationContext(), nonexistentFile));
assertThat(exception).hasMessageThat().contains(nonexistentFile);
}
@Test
public void create_failsWithInvalidModelBuffer() throws Exception {
// Create a non-direct model ByteBuffer.
ByteBuffer modelBuffer =
TestUtils.loadToNonDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE);
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ObjectDetector.createFromBuffer(
ApplicationProvider.getApplicationContext(), modelBuffer));
assertThat(exception)
.hasMessageThat()
.contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
}
@Test
public void detect_failsWithBothAllowAndDenyListOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("cat"))
.setCategoryDenylist(Arrays.asList("dog"))
.build();
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
ObjectDetector.createFromOptions(
ApplicationProvider.getApplicationContext(), options));
assertThat(exception)
.hasMessageThat()
.contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
}
// TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
// detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
// detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends ObjectDetectorTest {
@Test
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(mode)
.setResultListener((objectDetectionResult, inputImage) -> {})
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener shouldn't be provided");
}
}
@Test
public void create_failsWithMissingResultListenerInLiveStreamMode() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener must be provided");
}
@Test
public void detect_failsWithCallingWrongApiInImageMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void detect_failsWithCallingWrongApiInVideoMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void detect_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((objectDetectionResult, inputImage) -> {})
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@Test
public void detect_successWithImageMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void detect_successWithVideoMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
ObjectDetectionResult results =
objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i);
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
}
@Test
public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
.build();
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
objectDetector.detectAsync(image, 1);
MediaPipeException exception =
assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
}
}
@Test
public void detect_successWithLiveStreamMode() throws Exception {
Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
.build();
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
objectDetector.detectAsync(image, i);
}
}
}
}
private static Image getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
// Checks that the result has one and only one detection, which is a cat.
private static void assertContainsOnlyCat(
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) {
assertThat(result.detections()).hasSize(1);
Detection catResult = result.detections().get(0);
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);
// We only support one category for each detected object at this point.
assertIsCat(catResult.categories().get(0), expectedScore);
}
private static void assertIsCat(Category category, float expectedScore) {
assertThat(category.categoryName()).isEqualTo("cat");
// coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite does not support label locale.
assertThat(category.displayName()).isEmpty();
assertThat((double) category.score()).isWithin(SCORE_DIFF_TOLERANCE).of(expectedScore);
assertThat(category.index()).isEqualTo(-1);
}
private static void assertApproximatelyEqualBoundingBoxes(
RectF boundingBox1, RectF boundingBox2) {
assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left);
assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top);
assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right);
assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom);
}
private static void assertImageSizeIsExpected(Image inputImage) {
assertThat(inputImage).isNotNull();
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
}
}