OpenSource MediaPipe Tasks Java

PiperOrigin-RevId: 477747787

parent a8ca669f05
commit 227cc20bff
@@ -34,6 +34,7 @@ android_library(
 android_library(
     name = "android_framework_no_mff",
     proguard_specs = [":proguard.pgcfg"],
+    visibility = ["//visibility:public"],
     exports = [
         ":android_framework_no_proguard",
     ],
@@ -11,3 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+android_library(
+    name = "category",
+    srcs = ["Category.java"],
+    deps = [
+        "//third_party:autovalue",
+        "@maven//:com_google_guava_guava",
+    ],
+)
+
+android_library(
+    name = "detection",
+    srcs = ["Detection.java"],
+    deps = [
+        ":category",
+        "//third_party:autovalue",
+        "@maven//:com_google_guava_guava",
+    ],
+)
@@ -0,0 +1,86 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.components.containers;

import com.google.auto.value.AutoValue;
import java.util.Objects;

/**
 * Category is a utility class that contains a category name, its display name, a float score, and
 * the index of the label in the corresponding label file. It is typically used as a result of
 * classification or detection tasks.
 */
@AutoValue
public abstract class Category {
  private static final float TOLERANCE = 1e-6f;

  /**
   * Creates a {@link Category} instance.
   *
   * @param score the probability score of this label category.
   * @param index the index of the label in the corresponding label file.
   * @param categoryName the label of this category object.
   * @param displayName the display name of the label.
   */
  public static Category create(float score, int index, String categoryName, String displayName) {
    return new AutoValue_Category(score, index, categoryName, displayName);
  }

  /** The probability score of this label category. */
  public abstract float score();

  /** The index of the label in the corresponding label file. Returns -1 if the index is not set. */
  public abstract int index();

  /** The label of this category object. */
  public abstract String categoryName();

  /**
   * The display name of the label, which may be translated for different locales. For example, the
   * label "apple" may be translated into Spanish for display purposes, so that the display name is
   * "manzana".
   */
  public abstract String displayName();

  @Override
  public final boolean equals(Object o) {
    if (!(o instanceof Category)) {
      return false;
    }
    Category other = (Category) o;
    return Math.abs(other.score() - this.score()) < TOLERANCE
        && other.index() == this.index()
        && other.categoryName().equals(this.categoryName())
        && other.displayName().equals(this.displayName());
  }

  @Override
  public final int hashCode() {
    return Objects.hash(categoryName(), displayName(), score(), index());
  }

  @Override
  public final String toString() {
    return "<Category \""
        + categoryName()
        + "\" (displayName="
        + displayName()
        + " score="
        + score()
        + " index="
        + index()
        + ")>";
  }
}
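[Editor's note] A minimal usage sketch of the Category container above; the values are illustrative only. Because equals() compares scores with a 1e-6 tolerance, the two instances below still compare equal even though their scores differ slightly.

Category apple = Category.create(/*score=*/ 0.95f, /*index=*/ 17, "apple", "manzana");
Category nearlyEqual = Category.create(0.9500001f, 17, "apple", "manzana");
// The score difference is below TOLERANCE (1e-6), so equals() returns true.
boolean same = apple.equals(nearlyEqual);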
@@ -0,0 +1,50 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.components.containers;

import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;

/**
 * Represents one detected object in the results of {@link
 * com.google.mediapipe.tasks.version.objectdetector.ObjectDetector}.
 */
@AutoValue
public abstract class Detection {

  /**
   * Creates a {@link Detection} instance from a list of {@link Category} and a bounding box.
   *
   * @param categories a list of {@link Category} objects that contain category name, display name,
   *     score, and the label index.
   * @param boundingBox a {@link RectF} object to represent the bounding box.
   */
  public static Detection create(List<Category> categories, RectF boundingBox) {
    // As an open source project, we've been trying to avoid depending on common Java libraries,
    // such as Guava, because they may introduce conflicts with clients who also happen to use
    // those libraries. Therefore, instead of using ImmutableList here, we convert the List into
    // an unmodifiableList.
    return new AutoValue_Detection(Collections.unmodifiableList(categories), boundingBox);
  }

  /** A list of {@link Category} objects. */
  public abstract List<Category> categories();

  /** A {@link RectF} object to represent the bounding box of the detected object. */
  public abstract RectF boundingBox();
}
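[Editor's note] A minimal sketch of constructing a Detection from the containers above; the category values and box coordinates are illustrative only. Note that the returned categories() list is unmodifiable.

List<Category> categories = new ArrayList<>();
categories.add(Category.create(0.8f, 0, "cat", "cat"));
Detection detection =
    Detection.create(categories, new RectF(/*left=*/ 10f, /*top=*/ 10f, /*right=*/ 200f, /*bottom=*/ 150f));
// detection.categories().add(...) would throw UnsupportedOperationException.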
@@ -11,3 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+android_library(
+    name = "core",
+    srcs = glob(["*.java"]),
+    javacopts = [
+        "-Xep:AndroidJdkLibsChecker:OFF",
+    ],
+    deps = [
+        "//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
+        "//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
+        "//mediapipe/framework:calculator_java_proto_lite",
+        "//mediapipe/framework:calculator_options_java_proto_lite",
+        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+        "//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
+        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+        "//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
+        "//third_party:autovalue",
+        "//third_party/java/protobuf:protobuf_lite",
+        "@maven//:com_google_guava_guava",
+    ],
+)
@@ -0,0 +1,96 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import com.google.auto.value.AutoValue;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.Optional;

/** Options to configure MediaPipe Tasks in general. */
@AutoValue
public abstract class BaseOptions {
  /** Builder for {@link BaseOptions}. */
  @AutoValue.Builder
  public abstract static class Builder {
    /**
     * Sets the model path to a tflite model with metadata in the assets.
     *
     * <p>Note: when the model path is set, both the model file descriptor and the model buffer should be empty.
     */
    public abstract Builder setModelAssetPath(String value);

    /**
     * Sets the native fd int of a tflite model with metadata.
     *
     * <p>Note: when the model file descriptor is set, both the model path and the model buffer should be empty.
     */
    public abstract Builder setModelAssetFileDescriptor(Integer value);

    /**
     * Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model
     * with metadata.
     *
     * <p>Note: when the model buffer is set, both the model file and the model file descriptor should be empty.
     */
    public abstract Builder setModelAssetBuffer(ByteBuffer value);

    /**
     * Sets the device {@link Delegate} to run the MediaPipe pipeline. If the delegate is not set,
     * the default delegate CPU is used.
     */
    public abstract Builder setDelegate(Delegate delegate);

    abstract BaseOptions autoBuild();

    /**
     * Validates and builds the {@link BaseOptions} instance.
     *
     * @throws IllegalArgumentException if {@link BaseOptions} is invalid, or the provided model
     *     buffer is not a direct {@link ByteBuffer} or a {@link MappedByteBuffer}.
     */
    public final BaseOptions build() {
      BaseOptions options = autoBuild();
      int modelAssetPathPresent = options.modelAssetPath().isPresent() ? 1 : 0;
      int modelAssetFileDescriptorPresent = options.modelAssetFileDescriptor().isPresent() ? 1 : 0;
      int modelAssetBufferPresent = options.modelAssetBuffer().isPresent() ? 1 : 0;

      if (modelAssetPathPresent + modelAssetFileDescriptorPresent + modelAssetBufferPresent != 1) {
        throw new IllegalArgumentException(
            "Please specify only one of the model asset path, the model asset file descriptor, and"
                + " the model asset buffer.");
      }
      if (options.modelAssetBuffer().isPresent()
          && !(options.modelAssetBuffer().get().isDirect()
              || options.modelAssetBuffer().get() instanceof MappedByteBuffer)) {
        throw new IllegalArgumentException(
            "The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
      }
      return options;
    }
  }

  abstract Optional<String> modelAssetPath();

  abstract Optional<Integer> modelAssetFileDescriptor();

  abstract Optional<ByteBuffer> modelAssetBuffer();

  abstract Delegate delegate();

  public static Builder builder() {
    return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU);
  }
}
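[Editor's note] A minimal sketch of building BaseOptions from an asset path; the file name is a hypothetical placeholder. Exactly one of the model asset path, file descriptor, or buffer must be set, otherwise build() throws IllegalArgumentException.

BaseOptions baseOptions =
    BaseOptions.builder()
        .setModelAssetPath("model_with_metadata.tflite") // hypothetical asset name
        .setDelegate(Delegate.CPU) // optional; CPU is already the default
        .build();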
@@ -0,0 +1,22 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

/** MediaPipe Tasks delegate. */
// TODO implement advanced delegate setting.
public enum Delegate {
  CPU,
  GPU,
}
@@ -0,0 +1,20 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

/** Interface for the customizable MediaPipe task error listener. */
public interface ErrorListener {
  void onError(RuntimeException e);
}
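[Editor's note] Since ErrorListener declares a single abstract method, it can be supplied as a lambda; a small sketch (the log tag is illustrative only):

ErrorListener errorListener = e -> Log.e("MyMediaPipeTask", "MediaPipe task error", e);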
@@ -0,0 +1,49 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import java.util.concurrent.atomic.AtomicBoolean;

/** Facilitates creation and destruction of the native ModelResourcesCache. */
class ModelResourcesCache {
  private final long nativeHandle;
  private final AtomicBoolean isHandleValid;

  public ModelResourcesCache() {
    nativeHandle = nativeCreateModelResourcesCache();
    isHandleValid = new AtomicBoolean(true);
  }

  public boolean isHandleValid() {
    return isHandleValid.get();
  }

  public long getNativeHandle() {
    if (isHandleValid.get()) {
      return nativeHandle;
    }
    return 0;
  }

  public void release() {
    if (isHandleValid.compareAndSet(true, false)) {
      nativeReleaseModelResourcesCache(nativeHandle);
    }
  }

  private native long nativeCreateModelResourcesCache();

  private native void nativeReleaseModelResourcesCache(long nativeHandle);
}
@@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import com.google.mediapipe.framework.GraphService;

/** Java wrapper for the graph service of ModelResourcesCacheService. */
class ModelResourcesCacheService implements GraphService<ModelResourcesCache> {
  public ModelResourcesCacheService() {}

  @Override
  public void installServiceObject(long context, ModelResourcesCache object) {
    nativeInstallServiceObject(context, object.getNativeHandle());
  }

  public native void nativeInstallServiceObject(long context, long object);
}
@@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import android.util.Log;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import java.util.List;

/** Base class for handling MediaPipe task graph outputs. */
public class OutputHandler<OutputT extends TaskResult, InputT> {
  /**
   * Interface for converting MediaPipe graph output {@link Packet}s to a task result object and a
   * task input object.
   */
  public interface OutputPacketConverter<OutputT extends TaskResult, InputT> {
    OutputT convertToTaskResult(List<Packet> packets);

    InputT convertToTaskInput(List<Packet> packets);
  }

  /** Interface for the customizable MediaPipe task result listener. */
  public interface ResultListener<OutputT extends TaskResult, InputT> {
    void run(OutputT result, InputT input);
  }

  private static final String TAG = "OutputHandler";
  // A task-specific graph output packet converter that should be implemented per task.
  private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
  // The user-defined task result listener.
  private ResultListener<OutputT, InputT> resultListener;
  // The user-defined error listener.
  protected ErrorListener errorListener;
  // The cached task result for non-latency-sensitive use cases.
  protected OutputT cachedTaskResult;
  // Whether the output handler should react to timestamp-bound changes by outputting empty packets.
  private boolean handleTimestampBoundChanges = false;

  /**
   * Sets a callback to be invoked to convert a {@link Packet} list to a task result object and a
   * task input object.
   *
   * @param converter the task-specific {@link OutputPacketConverter} callback.
   */
  public void setOutputPacketConverter(OutputPacketConverter<OutputT, InputT> converter) {
    this.outputPacketConverter = converter;
  }

  /**
   * Sets a callback to be invoked when task result objects become available.
   *
   * @param listener the user-defined {@link ResultListener} callback.
   */
  public void setResultListener(ResultListener<OutputT, InputT> listener) {
    this.resultListener = listener;
  }

  /**
   * Sets a callback to be invoked when exceptions are thrown from the task graph.
   *
   * @param listener the user-defined {@link ErrorListener} callback.
   */
  public void setErrorListener(ErrorListener listener) {
    this.errorListener = listener;
  }

  /**
   * Sets whether the output handler should react to the timestamp bound changes that are
   * represented as empty output {@link Packet}s.
   *
   * @param handleTimestampBoundChanges a boolean value.
   */
  public void setHandleTimestampBoundChanges(boolean handleTimestampBoundChanges) {
    this.handleTimestampBoundChanges = handleTimestampBoundChanges;
  }

  /** Returns true if the task graph is set to handle timestamp bound changes. */
  boolean handleTimestampBoundChanges() {
    return handleTimestampBoundChanges;
  }

  /* Returns the cached task result object. */
  public OutputT retrieveCachedTaskResult() {
    OutputT taskResult = cachedTaskResult;
    cachedTaskResult = null;
    return taskResult;
  }

  /**
   * Handles a list of output {@link Packet}s. Invoked when a packet list becomes available.
   *
   * @param packets a list of output {@link Packet}s.
   */
  void run(List<Packet> packets) {
    OutputT taskResult = null;
    try {
      taskResult = outputPacketConverter.convertToTaskResult(packets);
      if (resultListener == null) {
        cachedTaskResult = taskResult;
      } else {
        InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
        resultListener.run(taskResult, taskInput);
      }
    } catch (MediaPipeException e) {
      if (errorListener != null) {
        errorListener.onError(e);
      } else {
        Log.e(TAG, "Error occurred when getting MediaPipe vision task result. " + e);
      }
    } finally {
      for (Packet packet : packets) {
        if (packet != null) {
          packet.release();
        }
      }
    }
  }
}
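[Editor's note] The fragment below is a hedged sketch of how a task might wire up the handler above. MyTaskResult and the trivial packet parsing are hypothetical placeholders, not part of this commit; a real converter would parse the output packets into a task-specific result.

// Hypothetical task result type, used only for illustration.
class MyTaskResult implements TaskResult {
  private final long timestampMs;

  MyTaskResult(long timestampMs) {
    this.timestampMs = timestampMs;
  }

  @Override
  public long timestampMs() {
    return timestampMs;
  }
}

OutputHandler<MyTaskResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
    new OutputHandler.OutputPacketConverter<MyTaskResult, Void>() {
      @Override
      public MyTaskResult convertToTaskResult(List<Packet> packets) {
        // A real task would parse its output packets here.
        return new MyTaskResult(/*timestampMs=*/ 0);
      }

      @Override
      public Void convertToTaskInput(List<Packet> packets) {
        return null;
      }
    });
// Without a ResultListener, results are cached and returned by TaskRunner#process;
// setting a ResultListener instead delivers them asynchronously.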
@@ -0,0 +1,156 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig.Node;
import com.google.mediapipe.proto.CalculatorProto.InputStreamInfo;
import com.google.mediapipe.calculator.proto.FlowLimiterCalculatorProto.FlowLimiterCalculatorOptions;
import java.util.ArrayList;
import java.util.List;

/**
 * {@link TaskInfo} contains all needed information to initialize a MediaPipe Task {@link
 * com.google.mediapipe.framework.Graph}.
 */
@AutoValue
public abstract class TaskInfo<T extends TaskOptions> {
  /** Builder for {@link TaskInfo}. */
  @AutoValue.Builder
  public abstract static class Builder<T extends TaskOptions> {
    /** Sets the MediaPipe task graph name. */
    public abstract Builder<T> setTaskGraphName(String value);

    /** Sets a list of task graph input stream info {@link String}s in the form TAG:name. */
    public abstract Builder<T> setInputStreams(List<String> value);

    /** Sets a list of task graph output stream info {@link String}s in the form TAG:name. */
    public abstract Builder<T> setOutputStreams(List<String> value);

    /** Sets to true if the task requires a flow limiter. */
    public abstract Builder<T> setEnableFlowLimiting(Boolean value);

    /**
     * Sets a task-specific options instance.
     *
     * @param value a task-specific options instance that is derived from {@link TaskOptions}.
     */
    public abstract Builder<T> setTaskOptions(T value);

    public abstract TaskInfo<T> autoBuild();

    /**
     * Validates and builds the {@link TaskInfo} instance.
     *
     * @throws IllegalArgumentException if any of the required information, such as the task graph
     *     name, the graph input streams, or the graph output streams, is empty.
     */
    public final TaskInfo<T> build() {
      TaskInfo<T> taskInfo = autoBuild();
      if (taskInfo.taskGraphName().isEmpty()
          || taskInfo.inputStreams().isEmpty()
          || taskInfo.outputStreams().isEmpty()) {
        throw new IllegalArgumentException(
            "Task graph's name, input streams, and output streams should be non-empty.");
      }
      return taskInfo;
    }
  }

  abstract String taskGraphName();

  abstract T taskOptions();

  abstract List<String> inputStreams();

  abstract List<String> outputStreams();

  abstract Boolean enableFlowLimiting();

  public static <T extends TaskOptions> Builder<T> builder() {
    return new AutoValue_TaskInfo.Builder<T>();
  }

  /* Returns a list of the output stream names without the stream tags. */
  List<String> outputStreamNames() {
    List<String> streamNames = new ArrayList<>(outputStreams().size());
    for (String stream : outputStreams()) {
      streamNames.add(stream.substring(stream.lastIndexOf(':') + 1));
    }
    return streamNames;
  }

  /**
   * Creates a MediaPipe Task {@link CalculatorGraphConfig} protobuf message from the {@link
   * TaskInfo} instance.
   */
  CalculatorGraphConfig generateGraphConfig() {
    CalculatorGraphConfig.Builder graphBuilder = CalculatorGraphConfig.newBuilder();
    Node.Builder taskSubgraphBuilder =
        Node.newBuilder()
            .setCalculator(taskGraphName())
            .setOptions(taskOptions().convertToCalculatorOptionsProto());
    for (String outputStream : outputStreams()) {
      taskSubgraphBuilder.addOutputStream(outputStream);
      graphBuilder.addOutputStream(outputStream);
    }
    if (!enableFlowLimiting()) {
      for (String inputStream : inputStreams()) {
        taskSubgraphBuilder.addInputStream(inputStream);
        graphBuilder.addInputStream(inputStream);
      }
      graphBuilder.addNode(taskSubgraphBuilder.build());
      return graphBuilder.build();
    }
    Node.Builder flowLimiterCalculatorBuilder =
        Node.newBuilder()
            .setCalculator("FlowLimiterCalculator")
            .addInputStreamInfo(
                InputStreamInfo.newBuilder().setTagIndex("FINISHED").setBackEdge(true).build())
            .setOptions(
                CalculatorOptions.newBuilder()
                    .setExtension(
                        FlowLimiterCalculatorOptions.ext,
                        FlowLimiterCalculatorOptions.newBuilder()
                            .setMaxInFlight(1)
                            .setMaxInQueue(1)
                            .build())
                    .build());
    for (String inputStream : inputStreams()) {
      graphBuilder.addInputStream(inputStream);
      flowLimiterCalculatorBuilder.addInputStream(stripTagIndex(inputStream));
      String taskInputStream = addStreamNamePrefix(inputStream);
      flowLimiterCalculatorBuilder.addOutputStream(stripTagIndex(taskInputStream));
      taskSubgraphBuilder.addInputStream(taskInputStream);
    }
    flowLimiterCalculatorBuilder.addInputStream(
        "FINISHED:" + stripTagIndex(outputStreams().get(0)));
    graphBuilder.addNode(flowLimiterCalculatorBuilder.build());
    graphBuilder.addNode(taskSubgraphBuilder.build());
    return graphBuilder.build();
  }

  private String stripTagIndex(String tagIndexName) {
    return tagIndexName.substring(tagIndexName.lastIndexOf(':') + 1);
  }

  private String addStreamNamePrefix(String tagIndexName) {
    return tagIndexName.substring(0, tagIndexName.lastIndexOf(':') + 1)
        + "throttled_"
        + stripTagIndex(tagIndexName);
  }
}
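[Editor's note] A hedged sketch of building a TaskInfo; the graph name, stream names, and the MyTaskOptions class (sketched after TaskOptions below) are hypothetical placeholders, not names defined by this commit.

TaskInfo<MyTaskOptions> taskInfo =
    TaskInfo.<MyTaskOptions>builder()
        .setTaskGraphName("mediapipe.tasks.MyTaskGraph") // hypothetical subgraph name
        .setInputStreams(Arrays.asList("IMAGE:image_in"))
        .setOutputStreams(Arrays.asList("DETECTIONS:detections_out"))
        .setTaskOptions(new MyTaskOptions(baseOptions))
        .setEnableFlowLimiting(false)
        .build();
// generateGraphConfig() then expands this into a CalculatorGraphConfig, inserting a
// FlowLimiterCalculator in front of the task subgraph when flow limiting is enabled.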
@@ -0,0 +1,75 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package com.google.mediapipe.tasks.core;

import com.google.mediapipe.calculator.proto.InferenceCalculatorProto;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.tasks.core.proto.AccelerationProto;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.core.proto.ExternalFileProto;
import com.google.protobuf.ByteString;

/**
 * MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
 * {@link TaskOptions}.
 */
public abstract class TaskOptions {
  /**
   * Converts a MediaPipe Tasks task-specific options instance to a {@link CalculatorOptions}
   * protobuf message.
   */
  public abstract CalculatorOptions convertToCalculatorOptionsProto();

  /**
   * Converts a {@link BaseOptions} instance to a {@link BaseOptionsProto.BaseOptions} protobuf
   * message.
   */
  protected BaseOptionsProto.BaseOptions convertBaseOptionsToProto(BaseOptions options) {
    ExternalFileProto.ExternalFile.Builder externalFileBuilder =
        ExternalFileProto.ExternalFile.newBuilder();
    options.modelAssetPath().ifPresent(externalFileBuilder::setFileName);
    options
        .modelAssetFileDescriptor()
        .ifPresent(
            fd ->
                externalFileBuilder.setFileDescriptorMeta(
                    ExternalFileProto.FileDescriptorMeta.newBuilder().setFd(fd).build()));
    options
        .modelAssetBuffer()
        .ifPresent(
            modelBuffer -> {
              modelBuffer.rewind();
              externalFileBuilder.setFileContent(ByteString.copyFrom(modelBuffer));
            });
    AccelerationProto.Acceleration.Builder accelerationBuilder =
        AccelerationProto.Acceleration.newBuilder();
    switch (options.delegate()) {
      case CPU:
        accelerationBuilder.setXnnpack(
            InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
                .getDefaultInstance());
        break;
      case GPU:
        accelerationBuilder.setGpu(
            InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.getDefaultInstance());
        break;
    }
    return BaseOptionsProto.BaseOptions.newBuilder()
        .setModelAsset(externalFileBuilder.build())
        .setAcceleration(accelerationBuilder.build())
        .build();
  }
}
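[Editor's note] A hedged sketch of the subclassing contract above; MyTaskOptions is a hypothetical class. A real task would pack its own options proto (typically embedding the result of convertBaseOptionsToProto) as a CalculatorOptions extension rather than returning the default instance.

/** Hypothetical task-specific options, shown only to illustrate the TaskOptions contract. */
final class MyTaskOptions extends TaskOptions {
  private final BaseOptions baseOptions;

  MyTaskOptions(BaseOptions baseOptions) {
    this.baseOptions = baseOptions;
  }

  @Override
  public CalculatorOptions convertToCalculatorOptionsProto() {
    // A real implementation would call convertBaseOptionsToProto(baseOptions) and embed the
    // result in its task-specific graph options extension.
    return CalculatorOptions.getDefaultInstance();
  }
}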
@@ -0,0 +1,24 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

/**
 * Interface for the MediaPipe Task result. Any MediaPipe task-specific result class should
 * implement {@link TaskResult}.
 */
public interface TaskResult {
  /** Returns the timestamp that is associated with the task result object. */
  long timestampMs();
}
@@ -0,0 +1,265 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import android.content.Context;
import android.util.Log;
import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

/** The runner of MediaPipe task graphs. */
public class TaskRunner implements AutoCloseable {
  private static final String TAG = TaskRunner.class.getSimpleName();
  private static final long TIMESATMP_UNITS_PER_SECOND = 1000000;

  private final OutputHandler<? extends TaskResult, ?> outputHandler;
  private final AtomicBoolean graphStarted = new AtomicBoolean(false);
  private final Graph graph;
  private final ModelResourcesCache modelResourcesCache;
  private final AndroidPacketCreator packetCreator;
  private long lastSeenTimestamp = Long.MIN_VALUE;
  private ErrorListener errorListener;

  /**
   * Creates a {@link TaskRunner} instance.
   *
   * @param context an Android {@link Context}.
   * @param taskInfo a {@link TaskInfo} instance that contains the task graph name, the task
   *     options, and the graph input and output stream names.
   * @param outputHandler an {@link OutputHandler} instance that handles the task result object and
   *     runtime exceptions.
   * @throws MediaPipeException for any error during {@link TaskRunner} creation.
   */
  public static TaskRunner create(
      Context context,
      TaskInfo<? extends TaskOptions> taskInfo,
      OutputHandler<? extends TaskResult, ?> outputHandler) {
    AndroidAssetUtil.initializeNativeAssetManager(context);
    Graph mediapipeGraph = new Graph();
    mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
    ModelResourcesCache graphModelResourcesCache = new ModelResourcesCache();
    mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
    mediapipeGraph.addMultiStreamCallback(
        taskInfo.outputStreamNames(),
        outputHandler::run,
        /*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges());
    mediapipeGraph.startRunningGraph();
    // Waits until all calculators are opened and the graph is fully started.
    mediapipeGraph.waitUntilGraphIdle();
    return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
  }

  /**
   * Sets a callback to be invoked when exceptions are thrown by the {@link TaskRunner} instance.
   *
   * @param listener an {@link ErrorListener} callback.
   */
  public void setErrorListener(ErrorListener listener) {
    this.errorListener = listener;
  }

  /** Returns the {@link AndroidPacketCreator} associated with the {@link TaskRunner} instance. */
  public AndroidPacketCreator getPacketCreator() {
    return packetCreator;
  }

  /**
   * A synchronous method for processing batch data.
   *
   * <p>Note: This method is designed for processing batch data such as unrelated images and texts.
   * The call blocks the current thread until a failure status or a successful result is returned.
   * An internal timestamp will be assigned per invocation. This method is thread-safe and allows
   * clients to call it from different threads.
   *
   * @param inputs a map that contains (input stream {@link String}, data {@link Packet}) pairs.
   */
  public synchronized TaskResult process(Map<String, Packet> inputs) {
    addPackets(inputs, generateSyntheticTimestamp());
    graph.waitUntilGraphIdle();
    return outputHandler.retrieveCachedTaskResult();
  }

  /**
   * A synchronous method for processing offline streaming data.
   *
   * <p>Note: This method is designed for processing offline streaming data such as the decoded
   * frames from a video file or an audio file. The call blocks the current thread until a failure
   * status or a successful result is returned. The caller must ensure that the input timestamp is
   * greater than the timestamps of previous invocations. This method is thread-unsafe and it is
   * the caller's responsibility to synchronize access to this method across multiple threads and
   * to ensure that the input packet timestamps are in order.
   *
   * @param inputs a map that contains (input stream {@link String}, data {@link Packet}) pairs.
   * @param inputTimestamp the timestamp of the input packets.
   */
  public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
    validateInputTimstamp(inputTimestamp);
    addPackets(inputs, inputTimestamp);
    graph.waitUntilGraphIdle();
    return outputHandler.retrieveCachedTaskResult();
  }

  /**
   * An asynchronous method for handling live streaming data.
   *
   * <p>Note: This method is designed for handling live streaming data such as live camera and
   * microphone data. A user-defined packets callback function must be provided in the constructor
   * to receive the output packets. The caller must ensure that the input packet timestamps are
   * monotonically increasing. This method is thread-unsafe and it is the caller's responsibility
   * to synchronize access to this method across multiple threads and to ensure that the input
   * packet timestamps are in order.
   *
   * @param inputs a map that contains (input stream {@link String}, data {@link Packet}) pairs.
   * @param inputTimestamp the timestamp of the input packets.
   */
  public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
    validateInputTimstamp(inputTimestamp);
    addPackets(inputs, inputTimestamp);
  }

  /**
   * Resets and restarts the {@link TaskRunner} instance. This can be useful for resetting a
   * stateful task graph to process new data.
   */
  public void restart() {
    if (graphStarted.get()) {
      try {
        graphStarted.set(false);
        graph.closeAllPacketSources();
        graph.waitUntilGraphDone();
      } catch (MediaPipeException e) {
        reportError(e);
      }
    }
    try {
      graph.startRunningGraph();
      // Waits until all calculators are opened and the graph is fully restarted.
      graph.waitUntilGraphIdle();
      graphStarted.set(true);
    } catch (MediaPipeException e) {
      reportError(e);
    }
  }

  /** Closes and cleans up the {@link TaskRunner} instance. */
  @Override
  public void close() {
    if (!graphStarted.get()) {
      return;
    }
    try {
      graphStarted.set(false);
      graph.closeAllPacketSources();
      graph.waitUntilGraphDone();
      if (modelResourcesCache != null) {
        modelResourcesCache.release();
      }
    } catch (MediaPipeException e) {
      // Note: errors during Process are reported at the earliest opportunity,
      // which may be addPacket or waitUntilDone, depending on timing. For consistency,
      // we want to always report them using the same async handler if installed.
      reportError(e);
    }
    try {
      graph.tearDown();
    } catch (MediaPipeException e) {
      reportError(e);
    }
  }

  private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
    if (!graphStarted.get()) {
      reportError(
          new MediaPipeException(
              MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
              "The task graph hasn't been successfully started or an error occurred during graph"
                  + " initialization."));
    }
    try {
      for (Map.Entry<String, Packet> entry : inputs.entrySet()) {
        // addConsumablePacketToInputStream allows the graph to take exclusive ownership of the
        // packet, which may allow for more memory optimizations.
        graph.addConsumablePacketToInputStream(entry.getKey(), entry.getValue(), inputTimestamp);
        // If addConsumablePacket succeeded, we don't need to release the packet ourselves.
        entry.setValue(null);
      }
    } catch (MediaPipeException e) {
      // TODO: do not suppress exceptions here!
      if (errorListener == null) {
        Log.e(TAG, "Mediapipe error: ", e);
      } else {
        throw e;
      }
    } finally {
      for (Packet packet : inputs.values()) {
        // In case of error, addConsumablePacketToInputStream will not release the packet, so we
        // have to release it ourselves.
        if (packet != null) {
          packet.release();
        }
      }
    }
  }

  /**
   * Checks if the input timestamp is strictly greater than the last timestamp that has been
   * processed.
   *
   * @param inputTimestamp the input timestamp.
   */
  private void validateInputTimstamp(long inputTimestamp) {
    if (lastSeenTimestamp >= inputTimestamp) {
      reportError(
          new MediaPipeException(
              MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
              "The received packets have a smaller timestamp than the last processed timestamp."));
    }
    lastSeenTimestamp = inputTimestamp;
  }

  /** Generates a synthetic input timestamp in the batch processing mode. */
  private long generateSyntheticTimestamp() {
    long timestamp =
        lastSeenTimestamp == Long.MIN_VALUE ? 0 : lastSeenTimestamp + TIMESATMP_UNITS_PER_SECOND;
    lastSeenTimestamp = timestamp;
    return timestamp;
  }

  /** Private constructor. */
  private TaskRunner(
      Graph graph,
      ModelResourcesCache modelResourcesCache,
      OutputHandler<? extends TaskResult, ?> outputHandler) {
    this.outputHandler = outputHandler;
    this.graph = graph;
    this.modelResourcesCache = modelResourcesCache;
    this.packetCreator = new AndroidPacketCreator(graph);
    graphStarted.set(true);
  }

  /** Reports error. */
  private void reportError(MediaPipeException e) {
    if (errorListener != null) {
      errorListener.onError(e);
    } else {
      throw e;
    }
  }
}
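[Editor's note] A hedged end-to-end sketch of the intended lifecycle, reusing the hypothetical taskInfo, handler, and errorListener objects from the sketches above; the stream name, the imagePacket, and the result cast are illustrative placeholders only.

// Batch mode: no result listener on the handler, so process() returns the cached result.
TaskRunner runner = TaskRunner.create(context, taskInfo, handler);
runner.setErrorListener(errorListener);

Map<String, Packet> inputs = new HashMap<>();
inputs.put("image_in", imagePacket); // a Packet built with runner.getPacketCreator()
MyTaskResult result = (MyTaskResult) runner.process(inputs);

// Live-stream mode would instead set a ResultListener on the handler and call send() with
// monotonically increasing timestamps.
runner.close();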
@@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

cc_library_with_tflite(
    name = "model_resources_cache_jni",
    srcs = [
        "model_resources_cache_jni.cc",
    ],
    hdrs = [
        "model_resources_cache_jni.h",
    ],
    tflite_deps = [
        "//mediapipe/tasks/cc/core:model_resources_cache",
        "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
    ],
    deps = [
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
    ] + select({
        "//conditions:default": ["//third_party/java/jdk:jni"],
        "//mediapipe:android": [],
    }),
    alwayslink = 1,
)
@@ -0,0 +1,72 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")

package(default_visibility = ["//mediapipe/tasks:internal"])

cc_library_with_tflite(
    name = "model_resources_cache_jni",
    srcs = [
        "model_resources_cache_jni.cc",
    ],
    hdrs = [
        "model_resources_cache_jni.h",
    ] + select({
        # The Android toolchain makes "jni.h" available in the include path.
        # For non-Android toolchains, generate jni.h and jni_md.h.
        "//mediapipe:android": [],
        "//conditions:default": [
            ":jni.h",
            ":jni_md.h",
        ],
    }),
    tflite_deps = [
        "//mediapipe/tasks/cc/core:model_resources_cache",
        "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
    ],
    deps = [
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
    ] + select({
        "//conditions:default": [],
        "//mediapipe:android": [],
    }),
    alwayslink = 1,
)

# Silly rules to make
# #include <jni.h>
# in the source headers work
# (in combination with the "includes" attribute of the tf_cuda_library rule
# above. Not needed when using the Android toolchain).
#
# Inspired from:
# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
# but hopefully there is a simpler alternative to this.
genrule(
    name = "copy_jni_h",
    srcs = ["@bazel_tools//tools/jdk:jni_header"],
    outs = ["jni.h"],
    cmd = "cp -f $< $@",
)

genrule(
    name = "copy_jni_md_h",
    srcs = select({
        "//mediapipe:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
        "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
    }),
    outs = ["jni_md.h"],
    cmd = "cp -f $< $@",
)
@@ -0,0 +1,49 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h"

#include <utility>

#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"

namespace {
using ::mediapipe::tasks::core::kModelResourcesCacheService;
using ::mediapipe::tasks::core::ModelResourcesCache;
using HandleType = std::shared_ptr<ModelResourcesCache>*;
}  // namespace

JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
    nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) {
  auto ptr = std::make_shared<ModelResourcesCache>(
      absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
  HandleType handle = new std::shared_ptr<ModelResourcesCache>(std::move(ptr));
  return reinterpret_cast<jlong>(handle);
}

JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
    nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
                                      jlong nativeHandle) {
  delete reinterpret_cast<HandleType>(nativeHandle);
}

JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
    nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
                                jlong objectHandle) {
  mediapipe::android::GraphServiceHelper::SetServiceObject(
      contextHandle, kModelResourcesCacheService,
      *reinterpret_cast<HandleType>(objectHandle));
}
@@ -0,0 +1,45 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
#define JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_

#include <jni.h>

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

#define MODEL_RESOURCES_CACHE_METHOD(METHOD_NAME) \
  Java_com_google_mediapipe_tasks_core_ModelResourcesCache_##METHOD_NAME

#define MODEL_RESOURCES_CACHE_SERVICE_METHOD(METHOD_NAME) \
  Java_com_google_mediapipe_tasks_core_ModelResourcesCacheService_##METHOD_NAME

JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
    nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz);

JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
    nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
                                      jlong nativeHandle);

JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
    nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
                                jlong objectHandle);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#endif  // JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_

@@ -11,3 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

package(default_visibility = ["//mediapipe/tasks:internal"])

android_library(
    name = "core",
    srcs = glob(["*.java"]),
    deps = [
        ":libmediapipe_tasks_vision_jni_lib",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "@maven//:com_google_guava_guava",
    ],
)

# The native library of all MediaPipe vision tasks.
cc_binary(
    name = "libmediapipe_tasks_vision_jni.so",
    linkshared = 1,
    linkstatic = 1,
    deps = [
        "//mediapipe/calculators/core:flow_limiter_calculator",
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
        "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
    ],
)

cc_library(
    name = "libmediapipe_tasks_vision_jni_lib",
    srcs = [":libmediapipe_tasks_vision_jni.so"],
    alwayslink = 1,
)

@@ -0,0 +1,114 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.core;

import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;

/** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable {
  private static final long MICROSECONDS_PER_MILLISECOND = 1000;
  private final TaskRunner runner;
  private final RunningMode runningMode;

  static {
    System.loadLibrary("mediapipe_tasks_vision_jni");
  }

  /**
   * Constructor to initialize a {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision
   * task {@link RunningMode}.
   *
   * @param runner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   */
  public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) {
    this.runner = runner;
    this.runningMode = runningMode;
  }

  /**
   * A synchronous method to process single image inputs. The call blocks the current thread until
   * a failure status or a successful result is returned.
   *
   * @param imageStreamName the image input stream name.
   * @param image a MediaPipe {@link Image} object for processing.
   * @throws MediaPipeException if the task is not in the image mode.
   */
  protected TaskResult processImageData(String imageStreamName, Image image) {
    if (runningMode != RunningMode.IMAGE) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the image mode. Current running mode:"
              + runningMode.name());
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    return runner.process(inputPackets);
  }

  /**
   * A synchronous method to process continuous video frames. The call blocks the current thread
   * until a failure status or a successful result is returned.
   *
   * @param imageStreamName the image input stream name.
   * @param image a MediaPipe {@link Image} object for processing.
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the video mode.
   */
  protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) {
    if (runningMode != RunningMode.VIDEO) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the video mode. Current running mode:"
              + runningMode.name());
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /**
   * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
   * available in the user-defined result listener.
   *
   * @param imageStreamName the image input stream name.
   * @param image a MediaPipe {@link Image} object for processing.
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the live stream mode.
   */
  protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) {
    if (runningMode != RunningMode.LIVE_STREAM) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the live stream mode. Current running mode:"
              + runningMode.name());
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /** Closes and cleans up the MediaPipe vision task. */
  @Override
  public void close() {
    runner.close();
  }
}
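As a usage illustration, a concrete vision task would subclass BaseVisionTaskApi and forward its public API to the protected helpers above. This is a hedged sketch; MyImageTask, its stream name, and its run() method are hypothetical and not part of this change.

// Hypothetical subclass sketch: a concrete task forwards a task-specific API
// to the protected helpers of BaseVisionTaskApi.
final class MyImageTask extends BaseVisionTaskApi {
  // Stream name is an assumption for illustration; real tasks define their own.
  private static final String IMAGE_IN_STREAM_NAME = "image_in";

  MyImageTask(TaskRunner runner, RunningMode runningMode) {
    super(runner, runningMode);
  }

  // Synchronous, image-mode entry point; BaseVisionTaskApi throws a
  // MediaPipeException if the task was not created in RunningMode.IMAGE.
  TaskResult run(Image image) {
    return processImageData(IMAGE_IN_STREAM_NAME, image);
  }
}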

@@ -0,0 +1,32 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.core;

/**
 * MediaPipe vision task running mode. A MediaPipe vision task can be run in three different
 * modes:
 *
 * <ul>
 *   <li>IMAGE: The mode for running a mediapipe vision task on single image inputs.
 *   <li>VIDEO: The mode for running a mediapipe vision task on the decoded frames of a video.
 *   <li>LIVE_STREAM: The mode for running a mediapipe vision task on a live stream of input data,
 *       such as from a camera.
 * </ul>
 */
public enum RunningMode {
  IMAGE,
  VIDEO,
  LIVE_STREAM
}
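A small, hedged sketch of how the three modes map onto task entry points, using the ObjectDetector methods introduced later in this change; the dispatch helper itself is illustrative only.

// Hedged dispatch sketch: which entry point each RunningMode corresponds to.
ObjectDetectionResult runOnce(
    ObjectDetector detector, RunningMode mode, Image image, long timestampMs) {
  switch (mode) {
    case IMAGE:
      return detector.detect(image);
    case VIDEO:
      return detector.detectForVideo(image, timestampMs);
    case LIVE_STREAM:
      // Result is delivered asynchronously to the listener registered in the options.
      detector.detectAsync(image, timestampMs);
      return null;
  }
  return null;
}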

@@ -11,3 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

android_library(
    name = "objectdetector",
    srcs = [
        "ObjectDetectionResult.java",
        "ObjectDetector.java",
    ],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = ":AndroidManifest.xml",
    deps = [
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/framework/formats:detection_java_proto_lite",
        "//mediapipe/framework/formats:location_data_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)

@@ -0,0 +1,75 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.objectdetector;

import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue
public abstract class ObjectDetectionResult implements TaskResult {
  private static final int DEFAULT_CATEGORY_INDEX = -1;

  @Override
  public abstract long timestampMs();

  public abstract List<com.google.mediapipe.tasks.components.containers.Detection> detections();

  /**
   * Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
   * messages.
   *
   * @param detectionList a list of {@link Detection} protobuf messages.
   * @param timestampMs the timestamp (in milliseconds) associated with the detection results.
   */
  static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
    List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
    for (Detection detectionProto : detectionList) {
      List<Category> categories = new ArrayList<>();
      for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
        categories.add(
            Category.create(
                detectionProto.getScore(idx),
                detectionProto.getLabelIdCount() > idx
                    ? detectionProto.getLabelId(idx)
                    : DEFAULT_CATEGORY_INDEX,
                detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
                detectionProto.getDisplayNameCount() > idx
                    ? detectionProto.getDisplayName(idx)
                    : ""));
      }
      RectF boundingBox = new RectF();
      if (detectionProto.getLocationData().hasBoundingBox()) {
        BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
        boundingBox.set(
            /*left=*/ boundingBoxProto.getXmin(),
            /*top=*/ boundingBoxProto.getYmin(),
            /*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
            /*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
      }
      detections.add(
          com.google.mediapipe.tasks.components.containers.Detection.create(
              categories, boundingBox));
    }
    return new AutoValue_ObjectDetectionResult(
        timestampMs, Collections.unmodifiableList(detections));
  }
}
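A short, hedged sketch of how a caller might walk the converted result. It assumes the same imports as the file above, and that the containers' Detection exposes categories() and boundingBox() accessors and Category exposes categoryName() and score(), matching how they are constructed here; the printer class itself is illustrative only.

// Hedged consumption sketch for an ObjectDetectionResult.
final class DetectionResultPrinter {
  static void print(ObjectDetectionResult result) {
    for (com.google.mediapipe.tasks.components.containers.Detection detection :
        result.detections()) {
      RectF box = detection.boundingBox();
      for (Category category : detection.categories()) {
        // Prints every category of every detection along with its bounding box.
        System.out.printf(
            "%s (%.2f) in box [%.1f, %.1f, %.1f, %.1f]%n",
            category.categoryName(), category.score(), box.left, box.top, box.right, box.bottom);
      }
    }
  }
}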

@@ -0,0 +1,415 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.objectdetector;

import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.proto.ObjectDetectorOptionsProto;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
 * Performs object detection on images.
 *
 * <p>The API expects a TFLite model with <a
 * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
 *
 * <p>The API supports models with one image input tensor and four output tensors. To be more
 * specific, here are the requirements.
 *
 * <ul>
 *   <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
 *       <ul>
 *         <li>image input of size {@code [batch x height x width x channels]}.
 *         <li>batch inference is not supported ({@code batch} is required to be 1).
 *         <li>only RGB inputs are supported ({@code channels} is required to be 3).
 *         <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
 *             to the metadata for input normalization.
 *       </ul>
 *   <li>Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e.:
 *       <ul>
 *         <li>Location tensor ({@code kTfLiteFloat32}):
 *             <ul>
 *               <li>tensor of size {@code [1 x num_results x 4]}, the inner array representing
 *                   bounding boxes in the form [top, left, right, bottom].
 *               <li>{@code BoundingBoxProperties} are required to be attached to the metadata and
 *                   must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}.
 *             </ul>
 *         <li>Classes tensor ({@code kTfLiteFloat32}):
 *             <ul>
 *               <li>tensor of size {@code [1 x num_results]}, each value representing the integer
 *                   index of a class.
 *               <li>if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS}
 *                   associated files, they are used to convert the tensor values into labels.
 *             </ul>
 *         <li>Scores tensor ({@code kTfLiteFloat32}):
 *             <ul>
 *               <li>tensor of size {@code [1 x num_results]}, each value representing the score of
 *                   the detected object.
 *             </ul>
 *         <li>Number of detection tensor ({@code kTfLiteFloat32}):
 *             <ul>
 *               <li>integer num_results as a tensor of size {@code [1]}.
 *             </ul>
 *       </ul>
 * </ul>
 *
 * <p>An example of such a model can be found on <a
 * href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow
 * Hub</a>.
 */
public final class ObjectDetector extends BaseVisionTaskApi {
  private static final String TAG = ObjectDetector.class.getSimpleName();
  private static final String IMAGE_IN_STREAM_NAME = "image_in";
  private static final List<String> INPUT_STREAMS =
      Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
  private static final List<String> OUTPUT_STREAMS =
      Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
  private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
  private static final int IMAGE_OUT_STREAM_INDEX = 1;
  private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";

  /**
   * Creates an {@link ObjectDetector} instance from a model file and the default {@link
   * ObjectDetectorOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelPath path to the detection model with metadata in the assets.
   * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
   */
  public static ObjectDetector createFromFile(Context context, String modelPath) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
    return createFromOptions(
        context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates an {@link ObjectDetector} instance from a model file and the default {@link
   * ObjectDetectorOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelFile the detection model {@link File} instance.
   * @throws IOException if an I/O error occurs when opening the tflite model file.
   * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
   */
  public static ObjectDetector createFromFile(Context context, File modelFile) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      BaseOptions baseOptions =
          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
      return createFromOptions(
          context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
    }
  }

  /**
   * Creates an {@link ObjectDetector} instance from a model buffer and the default {@link
   * ObjectDetectorOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
   *     model.
   * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
   */
  public static ObjectDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
    return createFromOptions(
        context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
   *
   * @param context an Android {@link Context}.
   * @param detectorOptions an {@link ObjectDetectorOptions} instance.
   * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
   */
  public static ObjectDetector createFromOptions(
      Context context, ObjectDetectorOptions detectorOptions) {
    // TODO: Consolidate OutputHandler and TaskRunner.
    OutputHandler<ObjectDetectionResult, Image> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<ObjectDetectionResult, Image>() {
          @Override
          public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
            return ObjectDetectionResult.create(
                PacketGetter.getProtoVector(
                    packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
                packets.get(DETECTIONS_OUT_STREAM_INDEX).getTimestamp());
          }

          @Override
          public Image convertToTaskInput(List<Packet> packets) {
            return new BitmapImageBuilder(
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
                .build();
          }
        });
    detectorOptions.resultListener().ifPresent(handler::setResultListener);
    detectorOptions.errorListener().ifPresent(handler::setErrorListener);
    TaskRunner runner =
        TaskRunner.create(
            context,
            TaskInfo.<ObjectDetectorOptions>builder()
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(INPUT_STREAMS)
                .setOutputStreams(OUTPUT_STREAMS)
                .setTaskOptions(detectorOptions)
                .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
                .build(),
            handler);
    detectorOptions.errorListener().ifPresent(runner::setErrorListener);
    return new ObjectDetector(runner, detectorOptions.runningMode());
  }

  /**
   * Constructor to initialize an {@link ObjectDetector} from a {@link TaskRunner} and a {@link
   * RunningMode}.
   *
   * @param taskRunner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   */
  private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
    super(taskRunner, runningMode);
  }

  /**
   * Performs object detection on the provided single image. Only use this method when the {@link
   * ObjectDetector} is created with {@link RunningMode.IMAGE}.
   *
   * <p>{@link ObjectDetector} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link Image} object for processing.
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detect(Image inputImage) {
    return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage);
  }

  /**
   * Performs object detection on the provided video frame. Only use this method when the {@link
   * ObjectDetector} is created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   *
   * <p>{@link ObjectDetector} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link Image} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
    return (ObjectDetectionResult)
        processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
  }

  /**
   * Sends live image data to perform object detection, and the results will be available via the
   * {@link ResultListener} provided in the {@link ObjectDetectorOptions}. Only use this method when
   * the {@link ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the object detector. The input timestamps must be monotonically increasing.
   *
   * <p>{@link ObjectDetector} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param inputImage a MediaPipe {@link Image} object for processing.
   * @param inputTimestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void detectAsync(Image inputImage, long inputTimestampMs) {
    sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
  }

  /** Options for setting up an {@link ObjectDetector}. */
  @AutoValue
  public abstract static class ObjectDetectorOptions extends TaskOptions {

    /** Builder for {@link ObjectDetectorOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the base options for the object detector task. */
      public abstract Builder setBaseOptions(BaseOptions value);

      /**
       * Sets the running mode for the object detector task. Defaults to the image mode. Object
       * detector has three modes:
       *
       * <ul>
       *   <li>IMAGE: The mode for detecting objects on single image inputs.
       *   <li>VIDEO: The mode for detecting objects on the decoded frames of a video.
       *   <li>LIVE_STREAM: The mode for detecting objects on a live stream of input data, such
       *       as from a camera. In this mode, {@code setResultListener} must be called to set up a
       *       listener to receive the detection results asynchronously.
       * </ul>
       */
      public abstract Builder setRunningMode(RunningMode value);

      /**
       * Sets the optional locale to use for display names specified through the TFLite Model
       * Metadata, if any.
       */
      public abstract Builder setDisplayNamesLocale(String value);

      /**
       * Sets the optional maximum number of top-scored detection results to return.
       *
       * <p>Overrides the value provided in the model metadata, if any.
       */
      public abstract Builder setMaxResults(Integer value);

      /**
       * Sets the optional score threshold that overrides the one provided in the model metadata (if
       * any). Results below this value are rejected.
       */
      public abstract Builder setScoreThreshold(Float value);

      /**
       * Sets the optional allowlist of category names.
       *
       * <p>If non-empty, detection results whose category name is not in this set will be filtered
       * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
       * categoryDenylist}.
       */
      public abstract Builder setCategoryAllowlist(List<String> value);

      /**
       * Sets the optional denylist of category names.
       *
       * <p>If non-empty, detection results whose category name is in this set will be filtered out.
       * Duplicate or unknown category names are ignored. Mutually exclusive with {@code
       * categoryAllowlist}.
       */
      public abstract Builder setCategoryDenylist(List<String> value);

      /**
       * Sets the result listener to receive the detection results asynchronously when the object
       * detector is in the live stream mode.
       */
      public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);

      /** Sets an optional error listener. */
      public abstract Builder setErrorListener(ErrorListener value);

      abstract ObjectDetectorOptions autoBuild();

      /**
       * Validates and builds the {@link ObjectDetectorOptions} instance.
       *
       * @throws IllegalArgumentException if the result listener and the running mode are not
       *     properly configured. The result listener should only be set when the object detector is
       *     in the live stream mode.
       */
      public final ObjectDetectorOptions build() {
        ObjectDetectorOptions options = autoBuild();
        if (options.runningMode() == RunningMode.LIVE_STREAM) {
          if (!options.resultListener().isPresent()) {
            throw new IllegalArgumentException(
                "The object detector is in the live stream mode, a user-defined result listener"
                    + " must be provided in ObjectDetectorOptions.");
          }
        } else if (options.resultListener().isPresent()) {
          throw new IllegalArgumentException(
              "The object detector is in the image or the video mode, a user-defined result"
                  + " listener shouldn't be provided in ObjectDetectorOptions.");
        }
        return options;
      }
    }

    abstract BaseOptions baseOptions();

    abstract RunningMode runningMode();

    abstract Optional<String> displayNamesLocale();

    abstract Optional<Integer> maxResults();

    abstract Optional<Float> scoreThreshold();

    abstract List<String> categoryAllowlist();

    abstract List<String> categoryDenylist();

    abstract Optional<ResultListener<ObjectDetectionResult, Image>> resultListener();

    abstract Optional<ErrorListener> errorListener();

    public static Builder builder() {
      return new AutoValue_ObjectDetector_ObjectDetectorOptions.Builder()
          .setRunningMode(RunningMode.IMAGE)
          .setCategoryAllowlist(Collections.emptyList())
          .setCategoryDenylist(Collections.emptyList());
    }

    /** Converts an {@link ObjectDetectorOptions} to a {@link CalculatorOptions} protobuf message. */
    @Override
    public CalculatorOptions convertToCalculatorOptionsProto() {
      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
          BaseOptionsProto.BaseOptions.newBuilder();
      baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
      ObjectDetectorOptionsProto.ObjectDetectorOptions.Builder taskOptionsBuilder =
          ObjectDetectorOptionsProto.ObjectDetectorOptions.newBuilder()
              .setBaseOptions(baseOptionsBuilder);
      displayNamesLocale().ifPresent(taskOptionsBuilder::setDisplayNamesLocale);
      maxResults().ifPresent(taskOptionsBuilder::setMaxResults);
      scoreThreshold().ifPresent(taskOptionsBuilder::setScoreThreshold);
      if (!categoryAllowlist().isEmpty()) {
        taskOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
      }
      if (!categoryDenylist().isEmpty()) {
        taskOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
      }
      return CalculatorOptions.newBuilder()
          .setExtension(
              ObjectDetectorOptionsProto.ObjectDetectorOptions.ext, taskOptionsBuilder.build())
          .build();
    }
  }
}
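Putting the factory methods, options, and running modes above together, here is a hedged end-to-end usage sketch. The model asset name and the Context/Image values passed to the helper are placeholders supplied by the caller, not part of this change.

// Hedged usage sketch for ObjectDetector.
void runObjectDetector(
    Context context, Image image, Image cameraFrame, long frameTimestampMs) {
  // Image mode: synchronous call, the result is returned directly.
  ObjectDetectorOptions imageOptions =
      ObjectDetectorOptions.builder()
          .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
          .setRunningMode(RunningMode.IMAGE)
          .setMaxResults(5)
          .build();
  try (ObjectDetector detector = ObjectDetector.createFromOptions(context, imageOptions)) {
    ObjectDetectionResult result = detector.detect(image);
  }

  // Live-stream mode: a result listener is mandatory and results arrive asynchronously.
  ObjectDetectorOptions liveOptions =
      ObjectDetectorOptions.builder()
          .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
          .setRunningMode(RunningMode.LIVE_STREAM)
          .setResultListener((detectionResult, inputImage) -> { /* consume detectionResult */ })
          .build();
  try (ObjectDetector liveDetector = ObjectDetector.createFromOptions(context, liveOptions)) {
    liveDetector.detectAsync(cameraFrame, frameTimestampMs);
  }
}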

@@ -11,3 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

android_library(
    name = "test_utils",
    srcs = ["TestUtils.java"],
    deps = [
        "//third_party/java/android_libs/guava_jdk5:io",
    ],
)

@@ -0,0 +1,83 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.core;

import android.content.Context;
import android.content.res.AssetManager;
import com.google.common.io.ByteStreams;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Helper class for the Java tests in MediaPipe Tasks. */
public final class TestUtils {

  /**
   * Loads a file from the asset directory and creates a {@link File} object for it. Simulates
   * downloading or reading a file that's not precompiled with the app.
   *
   * @return a {@link File} object for the model.
   */
  public static File loadFile(Context context, String fileName) {
    File target = new File(context.getFilesDir(), fileName);
    try (InputStream is = context.getAssets().open(fileName);
        FileOutputStream os = new FileOutputStream(target)) {
      ByteStreams.copy(is, os);
    } catch (IOException e) {
      throw new AssertionError("Failed to load model file at " + fileName, e);
    }
    return target;
  }

  /**
   * Reads a file into a direct {@link ByteBuffer} object from the asset directory.
   *
   * @return a {@link ByteBuffer} object for the file.
   */
  public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName)
      throws IOException {
    AssetManager assetManager = context.getAssets();
    InputStream inputStream = assetManager.open(fileName);
    byte[] bytes = ByteStreams.toByteArray(inputStream);

    ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
    buffer.put(bytes);
    return buffer;
  }

  /**
   * Reads a file into a non-direct {@link ByteBuffer} object from the asset directory.
   *
   * @return a {@link ByteBuffer} object for the file.
   */
  public static ByteBuffer loadToNonDirectByteBuffer(Context context, String fileName)
      throws IOException {
    AssetManager assetManager = context.getAssets();
    InputStream inputStream = assetManager.open(fileName);
    byte[] bytes = ByteStreams.toByteArray(inputStream);
    return ByteBuffer.wrap(bytes);
  }

  public enum ByteBufferType {
    DIRECT,
    BACK_UP_ARRAY,
    OTHER // Non-direct ByteBuffer without a back-up array.
  }

  private TestUtils() {}
}

@@ -12,4 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@build_bazel_rules_android//android:rules.bzl", "android_library_test")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# TODO: Enable this in OSS
@ -0,0 +1,456 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.vision.objectdetector;
|
||||||
|
|
||||||
|
import static com.google.common.truth.Truth.assertThat;
|
||||||
|
import static org.junit.Assert.assertThrows;
|
||||||
|
|
||||||
|
import android.content.res.AssetManager;
|
||||||
|
import android.graphics.BitmapFactory;
|
||||||
|
import android.graphics.RectF;
|
||||||
|
import androidx.test.core.app.ApplicationProvider;
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||||
|
import com.google.mediapipe.framework.image.Image;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Category;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.Detection;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TestUtils;
|
||||||
|
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||||
|
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Suite;
|
||||||
|
import org.junit.runners.Suite.SuiteClasses;
|
||||||
|
|
||||||
|
/** Test for {@link ObjectDetector}. */
|
||||||
|
@RunWith(Suite.class)
|
||||||
|
@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class})
|
||||||
|
public class ObjectDetectorTest {
|
||||||
|
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
|
||||||
|
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
|
||||||
|
private static final int IMAGE_WIDTH = 1200;
|
||||||
|
private static final int IMAGE_HEIGHT = 600;
|
||||||
|
private static final float CAT_SCORE = 0.69f;
|
||||||
|
private static final RectF catBoundingBox = new RectF(611, 164, 986, 596);
|
||||||
|
// TODO: Figure out why android_x86 and android_arm tests have slightly different
|
||||||
|
// scores (0.6875 vs 0.69921875).
|
||||||
|
private static final float SCORE_DIFF_TOLERANCE = 0.01f;
|
||||||
|
private static final float PIXEL_DIFF_TOLERANCE = 5.0f;
|
||||||
|
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public static final class General extends ObjectDetectorTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_successWithValidModels() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setMaxResults(1)
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_successWithNoOptions() throws Exception {
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// Check if the object with the highest score is cat.
|
||||||
|
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithMaxResultsOption() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setMaxResults(8)
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// results should have 8 detected objects because maxResults was set to 8.
|
||||||
|
assertThat(results.detections()).hasSize(8);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithScoreThresholdOption() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setScoreThreshold(0.68f)
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// The score threshold should block all other other objects, except cat.
|
||||||
|
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithAllowListOption() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setCategoryAllowlist(Arrays.asList("cat"))
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// Because of the allowlist, results should only contain cat, and there are 6 detected
|
||||||
|
// bounding boxes of cats in CAT_AND_DOG_IMAGE.
|
||||||
|
assertThat(results.detections()).hasSize(5);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithDenyListOption() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setCategoryDenylist(Arrays.asList("cat"))
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// Because of the denylist, the highest result is not cat anymore.
|
||||||
|
assertThat(results.detections().get(0).categories().get(0).categoryName())
|
||||||
|
.isNotEqualTo("cat");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithModelFileObject() throws Exception {
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromFile(
|
||||||
|
ApplicationProvider.getApplicationContext(),
|
||||||
|
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// Check if the object with the highest score is cat.
|
||||||
|
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithModelBuffer() throws Exception {
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromBuffer(
|
||||||
|
ApplicationProvider.getApplicationContext(),
|
||||||
|
TestUtils.loadToDirectByteBuffer(
|
||||||
|
ApplicationProvider.getApplicationContext(), MODEL_FILE));
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
// Check if the object with the highest score is cat.
|
||||||
|
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_succeedsWithModelBufferAndOptions() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder()
|
||||||
|
.setModelAssetBuffer(
|
||||||
|
TestUtils.loadToDirectByteBuffer(
|
||||||
|
ApplicationProvider.getApplicationContext(), MODEL_FILE))
|
||||||
|
.build())
|
||||||
|
.setMaxResults(1)
|
||||||
|
.build();
|
||||||
|
ObjectDetector objectDetector =
|
||||||
|
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
|
||||||
|
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithMissingModel() throws Exception {
|
||||||
|
String nonexistentFile = "/path/to/non/existent/file";
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
ObjectDetector.createFromFile(
|
||||||
|
ApplicationProvider.getApplicationContext(), nonexistentFile));
|
||||||
|
assertThat(exception).hasMessageThat().contains(nonexistentFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithInvalidModelBuffer() throws Exception {
|
||||||
|
// Create a non-direct model ByteBuffer.
|
||||||
|
ByteBuffer modelBuffer =
|
||||||
|
TestUtils.loadToNonDirectByteBuffer(
|
||||||
|
ApplicationProvider.getApplicationContext(), MODEL_FILE);
|
||||||
|
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ObjectDetector.createFromBuffer(
|
||||||
|
ApplicationProvider.getApplicationContext(), modelBuffer));
|
||||||
|
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void detect_failsWithBothAllowAndDenyListOption() throws Exception {
|
||||||
|
ObjectDetectorOptions options =
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setCategoryAllowlist(Arrays.asList("cat"))
|
||||||
|
.setCategoryDenylist(Arrays.asList("dog"))
|
||||||
|
.build();
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
ObjectDetector.createFromOptions(
|
||||||
|
ApplicationProvider.getApplicationContext(), options));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
|
||||||
|
// detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
|
||||||
|
// detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.
|
||||||
|
}
|
||||||
|
|
||||||
|
@RunWith(AndroidJUnit4.class)
|
||||||
|
public static final class RunningModeTest extends ObjectDetectorTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
|
||||||
|
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setRunningMode(mode)
|
||||||
|
.setResultListener((objectDetectionResult, inputImage) -> {})
|
||||||
|
.build());
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("a user-defined result listener shouldn't be provided");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
|
||||||
|
IllegalArgumentException exception =
|
||||||
|
assertThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
ObjectDetectorOptions.builder()
|
||||||
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.build());
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("a user-defined result listener must be provided");
|
||||||
|
}

    @Test
    public void detect_failsWithCallingWrongApiInImageMode() throws Exception {
      ObjectDetectorOptions options =
          ObjectDetectorOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
              .setRunningMode(RunningMode.IMAGE)
              .build();

      ObjectDetector objectDetector =
          ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      MediaPipeException exception =
          assertThrows(
              MediaPipeException.class,
              () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
      exception =
          assertThrows(
              MediaPipeException.class,
              () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
    }

    @Test
    public void detect_failsWithCallingWrongApiInVideoMode() throws Exception {
      ObjectDetectorOptions options =
          ObjectDetectorOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
              .setRunningMode(RunningMode.VIDEO)
              .build();

      ObjectDetector objectDetector =
          ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      MediaPipeException exception =
          assertThrows(
              MediaPipeException.class,
              () -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
      assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
      exception =
          assertThrows(
              MediaPipeException.class,
              () -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
    }

    @Test
    public void detect_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
      ObjectDetectorOptions options =
          ObjectDetectorOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
              .setRunningMode(RunningMode.LIVE_STREAM)
              .setResultListener((objectDetectionResult, inputImage) -> {})
              .build();

      ObjectDetector objectDetector =
          ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);

      MediaPipeException exception =
          assertThrows(
              MediaPipeException.class,
              () -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
      assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
      exception =
          assertThrows(
              MediaPipeException.class,
              () -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
      assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
    }

    @Test
    public void detect_successWithImageMode() throws Exception {
      ObjectDetectorOptions options =
          ObjectDetectorOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
              .setRunningMode(RunningMode.IMAGE)
              .setMaxResults(1)
              .build();
      ObjectDetector objectDetector =
          ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
      assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
    }

    @Test
    public void detect_successWithVideoMode() throws Exception {
      ObjectDetectorOptions options =
          ObjectDetectorOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
              .setRunningMode(RunningMode.VIDEO)
              .setMaxResults(1)
              .build();
      ObjectDetector objectDetector =
          ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      for (int i = 0; i < 3; i++) {
        ObjectDetectionResult results =
            objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i);
        assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
      }
    }

    @Test
    public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
      Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
      ObjectDetectorOptions options =
          ObjectDetectorOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
              .setRunningMode(RunningMode.LIVE_STREAM)
              .setResultListener(
                  (objectDetectionResult, inputImage) -> {
                    assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
                    assertImageSizeIsExpected(inputImage);
                  })
              .setMaxResults(1)
              .build();
      try (ObjectDetector objectDetector =
          ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
        objectDetector.detectAsync(image, 1);
        MediaPipeException exception =
            assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0));
        assertThat(exception)
            .hasMessageThat()
            .contains("having a smaller timestamp than the processed timestamp");
      }
    }

    @Test
    public void detect_successWithLiveStreamMode() throws Exception {
      Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
      ObjectDetectorOptions options =
          ObjectDetectorOptions.builder()
              .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
              .setRunningMode(RunningMode.LIVE_STREAM)
              .setResultListener(
                  (objectDetectionResult, inputImage) -> {
                    assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
                    assertImageSizeIsExpected(inputImage);
                  })
              .setMaxResults(1)
              .build();
      try (ObjectDetector objectDetector =
          ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
        for (int i = 0; i < 3; i++) {
          objectDetector.detectAsync(image, i);
        }
      }
    }
  }

  private static Image getImageFromAsset(String filePath) throws Exception {
    AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
    InputStream istr = assetManager.open(filePath);
    return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
  }

  // Checks that the result contains one and only one detection, which is a cat.
  private static void assertContainsOnlyCat(
      ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) {
    assertThat(result.detections()).hasSize(1);
    Detection catResult = result.detections().get(0);
    assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);
    // We only support one category for each detected object at this point.
    assertIsCat(catResult.categories().get(0), expectedScore);
  }

  private static void assertIsCat(Category category, float expectedScore) {
    assertThat(category.categoryName()).isEqualTo("cat");
    // coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite does not support label locale.
    assertThat(category.displayName()).isEmpty();
    assertThat((double) category.score()).isWithin(SCORE_DIFF_TOLERANCE).of(expectedScore);
    assertThat(category.index()).isEqualTo(-1);
  }

  private static void assertApproximatelyEqualBoundingBoxes(
      RectF boundingBox1, RectF boundingBox2) {
    assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left);
    assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top);
    assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right);
    assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom);
  }

  private static void assertImageSizeIsExpected(Image inputImage) {
    assertThat(inputImage).isNotNull();
    assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
    assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
  }
}