diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto
index 16aa86aeb..5d4ba3296 100644
--- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto
+++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto
@@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
+option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto";
+option java_outer_classname = "AudioClassifierGraphOptionsProto";
+
message AudioClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional AudioClassifierGraphOptions ext = 451755788;
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
new file mode 100644
index 000000000..c5cb6f8a3
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
@@ -0,0 +1,76 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+licenses(["notice"])
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+android_library(
+ name = "core",
+ srcs = glob(["core/*.java"]),
+ javacopts = [
+ "-Xep:AndroidJdkLibsChecker:OFF",
+ ],
+ deps = [
+ ":libmediapipe_tasks_audio_jni_lib",
+ "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+ "@maven//:com_google_guava_guava",
+ ],
+)
+
+# The native library of all MediaPipe audio tasks.
+cc_binary(
+ name = "libmediapipe_tasks_audio_jni.so",
+ linkshared = 1,
+ linkstatic = 1,
+ deps = [
+ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
+ "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
+ ],
+)
+
+cc_library(
+ name = "libmediapipe_tasks_audio_jni_lib",
+ srcs = [":libmediapipe_tasks_audio_jni.so"],
+ alwayslink = 1,
+)
+
+android_library(
+ name = "audioclassifier",
+ srcs = [
+ "audioclassifier/AudioClassifier.java",
+ "audioclassifier/AudioClassifierResult.java",
+ ],
+ javacopts = [
+ "-Xep:AndroidJdkLibsChecker:OFF",
+ ],
+ manifest = "audioclassifier/AndroidManifest.xml",
+ deps = [
+ ":core",
+ "//mediapipe/framework:calculator_options_java_proto_lite",
+ "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+ "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
+ "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
+ "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
+ "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+ "//third_party:autovalue",
+ "@maven//:com_google_guava_guava",
+ ],
+)
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml
new file mode 100644
index 000000000..1e9817efd
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java
new file mode 100644
index 000000000..88b6daf0a
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java
@@ -0,0 +1,399 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.audio.audioclassifier;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.PacketGetter;
+import com.google.mediapipe.framework.ProtoUtil;
+import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
+import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
+import com.google.mediapipe.tasks.audio.core.RunningMode;
+import com.google.mediapipe.tasks.components.containers.AudioData;
+import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
+import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.ErrorListener;
+import com.google.mediapipe.tasks.core.OutputHandler;
+import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
+import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
+import com.google.mediapipe.tasks.core.TaskInfo;
+import com.google.mediapipe.tasks.core.TaskOptions;
+import com.google.mediapipe.tasks.core.TaskRunner;
+import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Performs audio classification on audio clips or audio stream.
+ *
+ *
+ * <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
+ * mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
+ * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
+ *
+ *
+ * <p>Input tensor: (kTfLiteFloat32)
+ *
+ *
+ * - input audio buffer of size `[batch * samples]`.
+ *
+ * - batch inference is not supported (`batch` is required to be 1).
+ *
+ * - for multi-channel models, the channels need to be interleaved.
+ *
+ *
+ * At least one output tensor with: (kTfLiteFloat32)
+ *
+ *
+ * - `[1 x N]` array with `N` represents the number of categories.
+ *
+ * - optional (but recommended) label items as AssociatedFiles with type TENSOR_AXIS_LABELS,
+ * containing one label per line. The first such AssociatedFile (if any) is used to fill the
+ * `category_name` field of the results. The `display_name` field is filled from the
+ * AssociatedFile (if any) whose locale matches the `display_names_locale` field of the
+ * `AudioClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
+ * these are available, only the `index` field of the results will be filled.
+ *
+ */
+public final class AudioClassifier extends BaseAudioTaskApi {
+ private static final String TAG = AudioClassifier.class.getSimpleName();
+ private static final String AUDIO_IN_STREAM_NAME = "audio_in";
+ private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
+ private static final List<String> INPUT_STREAMS =
+ Collections.unmodifiableList(
+ Arrays.asList(
+ "AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
+ private static final List<String> OUTPUT_STREAMS =
+ Collections.unmodifiableList(
+ Arrays.asList(
+ "CLASSIFICATIONS:classifications_out",
+ "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
+ private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
+ private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
+ private static final String TASK_GRAPH_NAME =
+ "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
+ private static final long MICROSECONDS_PER_MILLISECOND = 1000;
+
+ static {
+ ProtoUtil.registerTypeName(
+ ClassificationsProto.ClassificationResult.class,
+ "mediapipe.tasks.components.containers.proto.ClassificationResult");
+ }
+
+ /**
+ * Creates an {@link AudioClassifier} instance from a model file and default {@link
+ * AudioClassifierOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelPath path to the classification model in the assets.
+ * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
+ */
+ public static AudioClassifier createFromFile(Context context, String modelPath) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+ return createFromOptions(
+ context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates an {@link AudioClassifier} instance from a model file and default {@link
+ * AudioClassifierOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelFile the classification model {@link File} instance.
+ * @throws IOException if an I/O error occurs when opening the tflite model file.
+ * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
+ */
+ public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ BaseOptions baseOptions =
+ BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+ return createFromOptions(
+ context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
+ }
+ }
+
+ /**
+ * Creates an {@link AudioClassifier} instance from a model buffer and default {@link
+ * AudioClassifierOptions}.
+ *
+ * @param context an Android {@link Context}.
+ * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
+ * classification model.
+ * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
+ */
+ public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
+ BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
+ return createFromOptions(
+ context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
+ }
+
+ /**
+ * Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
+ *
+ * @param context an Android {@link Context}.
+ * @param options an {@link AudioClassifierOptions} instance.
+ * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
+ */
+ public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
+ OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
+ handler.setOutputPacketConverter(
+ new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
+ @Override
+ public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
+ try {
+ if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
+ // For audio stream mode.
+ return AudioClassifierResult.createFromProto(
+ PacketGetter.getProto(
+ packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
+ ClassificationsProto.ClassificationResult.getDefaultInstance()),
+ packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
+ / MICROSECONDS_PER_MILLISECOND);
+ } else {
+ // For audio clips mode.
+ return AudioClassifierResult.createFromProtoList(
+ PacketGetter.getProtoVector(
+ packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
+ ClassificationsProto.ClassificationResult.parser()),
+ -1);
+ }
+ } catch (IOException e) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
+ }
+ }
+
+ @Override
+ public Void convertToTaskInput(List<Packet> packets) {
+ return null;
+ }
+ });
+ if (options.resultListener().isPresent()) {
+ ResultListener<AudioClassifierResult, Void> resultListener =
+ new ResultListener<>() {
+ @Override
+ public void run(AudioClassifierResult audioClassifierResult, Void input) {
+ options.resultListener().get().run(audioClassifierResult);
+ }
+ };
+ handler.setResultListener(resultListener);
+ }
+ options.errorListener().ifPresent(handler::setErrorListener);
+ // Audio tasks should not drop input audio due to flow limiting, which may cause data
+ // inconsistency.
+ TaskRunner runner =
+ TaskRunner.create(
+ context,
+ TaskInfo.builder()
+ .setTaskGraphName(TASK_GRAPH_NAME)
+ .setInputStreams(INPUT_STREAMS)
+ .setOutputStreams(OUTPUT_STREAMS)
+ .setTaskOptions(options)
+ .setEnableFlowLimiting(false)
+ .build(),
+ handler);
+ return new AudioClassifier(runner, options.runningMode());
+ }
+
+ /**
+ * Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
+ * RunningMode}.
+ *
+ * @param taskRunner a {@link TaskRunner}.
+ * @param runningMode a mediapipe audio task {@link RunningMode}.
+ */
+ private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
+ super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
+ }
+
+ /*
+ * Performs audio classification on the provided audio clip. Only use this method when the
+ * AudioClassifier is created with the audio clips mode.
+ *
+ * The audio clip is represented as a MediaPipe {@link AudioData} object The method accepts
+ * audio clips with various length and audio sample rate. It's required to provide the
+ * corresponding audio sample rate within the {@link AudioData} object.
+ *
+ *
+ * <p>The input audio clip may be longer than what the model is able to process in a single
+ * inference. When this occurs, the input audio clip is split into multiple chunks starting at
+ * different timestamps. For this reason, this function returns a vector of ClassificationResult
+ * objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
+ * chunk data that was classified, e.g:
+ *
+ * ClassificationResult #0 (first chunk of data):
+ * timestamp_ms: 0 (starts at 0ms)
+ * classifications #0 (single head model):
+ * category #0:
+ * category_name: "Speech"
+ * score: 0.6
+ * category #1:
+ * category_name: "Music"
+ * score: 0.2
+ * ClassificationResult #1 (second chunk of data):
+ * timestamp_ms: 800 (starts at 800ms)
+ * classifications #0 (single head model):
+ * category #0:
+ * category_name: "Speech"
+ * score: 0.5
+ * category #1:
+ * category_name: "Silence"
+ * score: 0.1
+ *
+ * @param audioClip a MediaPipe {@link AudioData} object for processing.
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public AudioClassifierResult classify(AudioData audioClip) {
+ return (AudioClassifierResult) processAudioClip(audioClip);
+ }
+
+ /*
+ * Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
+ * use this method when the AudioClassifier is created with the audio stream mode.
+ *
+ *
+ * <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
+ * be resampled, accumulated, and framed to the proper size for the underlying model to consume.
+ * It's required to provide the corresponding audio sample rate within {@link AudioData} object as
+ * well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
+ * timestamps must be monotonically increasing. This method will return immediately after
+ * the input audio data is accepted. The results will be available in the `resultListener`
+ * provided in the `AudioClassifierOptions`. The `classifyAsync` method is designed to process
+ * audio stream data such as microphone input.
+ *
+ *
+ * <p>The input audio block may be longer than what the model is able to process in a single
+ * inference. When this occurs, the input audio block is split into multiple chunks. For this
+ * reason, the callback may be called multiple times (once per chunk) for each call to this
+ * function.
+ *
+ * @param audioBlock a MediaPipe {@link AudioData} object for processing.
+ * @param timestampMs the input timestamp (in milliseconds).
+ * @throws MediaPipeException if there is an internal error.
+ */
+ public void classifyAsync(AudioData audioBlock, long timestampMs) {
+ checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
+ sendAudioStreamData(audioBlock, timestampMs);
+ }
+
+ /** Options for setting up an {@link AudioClassifier}. */
+ @AutoValue
+ public abstract static class AudioClassifierOptions extends TaskOptions {
+
+ /** Builder for {@link AudioClassifierOptions}. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ /** Sets the {@link BaseOptions} for the audio classifier task. */
+ public abstract Builder setBaseOptions(BaseOptions baseOptions);
+
+ /**
+ * Sets the {@link RunningMode} for the audio classifier task. Default to the audio clips
+ * mode. Audio classifier has two modes:
+ *
+ *
+ * - AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
+ * audio clips to the `classify` method, and will receive the classification results as
+ * the return value.
+ *
+ * - AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
+ * from microphone. Users call `classifyAsync` to push the audio data into the
+ * AudioClassifier, the classification results will be available in the result callback
+ * when the audio classifier finishes the work.
+ *
+ */
+ public abstract Builder setRunningMode(RunningMode runningMode);
+
+ /**
+ * Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
+ * score threshold, number of results, etc.
+ */
+ public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
+
+ /**
+ * Sets the {@link ResultListener} to receive the classification results asynchronously when
+ * the audio classifier is in the audio stream mode.
+ */
+ public abstract Builder setResultListener(
+ PureResultListener<AudioClassifierResult> resultListener);
+
+ /** Sets an optional {@link ErrorListener}. */
+ public abstract Builder setErrorListener(ErrorListener errorListener);
+
+ abstract AudioClassifierOptions autoBuild();
+
+ /**
+ * Validates and builds the {@link AudioClassifierOptions} instance.
+ *
+ * @throws IllegalArgumentException if the result listener and the running mode are not
+ * properly configured. The result listener should only be set when the audio classifier
+ * is in the audio stream mode.
+ */
+ public final AudioClassifierOptions build() {
+ AudioClassifierOptions options = autoBuild();
+ if (options.runningMode() == RunningMode.AUDIO_STREAM) {
+ if (!options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The audio classifier is in the audio stream mode, a user-defined result listener"
+ + " must be provided in the AudioClassifierOptions.");
+ }
+ } else if (options.resultListener().isPresent()) {
+ throw new IllegalArgumentException(
+ "The audio classifier is in the audio clips mode, a user-defined result listener"
+ + " shouldn't be provided in AudioClassifierOptions.");
+ }
+ return options;
+ }
+ }
+
+ abstract BaseOptions baseOptions();
+
+ abstract RunningMode runningMode();
+
+ abstract Optional<ClassifierOptions> classifierOptions();
+
+ abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
+
+ abstract Optional<ErrorListener> errorListener();
+
+ public static Builder builder() {
+ return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
+ .setRunningMode(RunningMode.AUDIO_CLIPS);
+ }
+
+ /**
+ * Converts a {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
+ */
+ @Override
+ public CalculatorOptions convertToCalculatorOptionsProto() {
+ BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
+ BaseOptionsProto.BaseOptions.newBuilder();
+ baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
+ baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
+ AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
+ AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
+ .setBaseOptions(baseOptionsBuilder);
+ if (classifierOptions().isPresent()) {
+ taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
+ }
+ return CalculatorOptions.newBuilder()
+ .setExtension(
+ AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
+ taskOptionsBuilder.build())
+ .build();
+ }
+ }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java
new file mode 100644
index 000000000..fcc3c6e22
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java
@@ -0,0 +1,76 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.audio.audioclassifier;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.tasks.components.containers.ClassificationResult;
+import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
+import com.google.mediapipe.tasks.core.TaskResult;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/** Represents the classification results generated by {@link AudioClassifier}. */
+@AutoValue
+public abstract class AudioClassifierResult implements TaskResult {
+
+ /**
+ * Creates an {@link AudioClassifierResult} instance from a list of {@link
+ * ClassificationsProto.ClassificationResult} protobuf messages.
+ *
+ * @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf message
+ * to convert.
+ * @param timestampMs a timestamp for this result.
+ */
+ static AudioClassifierResult createFromProtoList(
+ List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
+ List<ClassificationResult> classificationResultList = new ArrayList<>();
+ for (ClassificationsProto.ClassificationResult proto : protoList) {
+ classificationResultList.add(ClassificationResult.createFromProto(proto));
+ }
+ return new AutoValue_AudioClassifierResult(
+ Optional.of(classificationResultList), Optional.empty(), timestampMs);
+ }
+
+ /**
+ * Creates an {@link AudioClassifierResult} instance from a {@link
+ * ClassificationsProto.ClassificationResult} protobuf message.
+ *
+ * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
+ * @param timestampMs a timestamp for this result.
+ */
+ static AudioClassifierResult createFromProto(
+ ClassificationsProto.ClassificationResult proto, long timestampMs) {
+ return new AutoValue_AudioClassifierResult(
+ Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
+ }
+
+ /**
+ * A list of timestamped {@link ClassificationResult} objects, each containing one set of results
+ * per classifier head. The list represents the audio classification result of an audio clip, and
+ * is only available when running with the audio clips mode.
+ */
+ public abstract Optional<List<ClassificationResult>> classificationResultList();
+
+ /**
+ * Contains one set of results per classifier head. A {@link ClassificationResult} usually
+ * represents one audio classification result in an audio stream, and is only available when
+ * running with the audio stream mode.
+ */
+ public abstract Optional<ClassificationResult> classificationResult();
+
+ @Override
+ public abstract long timestampMs();
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java
new file mode 100644
index 000000000..affe43559
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java
@@ -0,0 +1,151 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.audio.core;
+
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.tasks.components.containers.AudioData;
+import com.google.mediapipe.tasks.core.TaskResult;
+import com.google.mediapipe.tasks.core.TaskRunner;
+import java.util.HashMap;
+import java.util.Map;
+
+/** The base class of MediaPipe audio tasks. */
+public class BaseAudioTaskApi implements AutoCloseable {
+ private static final long MICROSECONDS_PER_MILLISECOND = 1000;
+ private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;
+
+ private final TaskRunner runner;
+ private final RunningMode runningMode;
+ private final String audioStreamName;
+ private final String sampleRateStreamName;
+ private double defaultSampleRate;
+
+ static {
+ System.loadLibrary("mediapipe_tasks_audio_jni");
+ }
+
+ /**
+ * Constructor to initialize a {@link BaseAudioTaskApi}.
+ *
+ * @param runner a {@link TaskRunner}.
+ * @param runningMode a mediapipe audio task {@link RunningMode}.
+ * @param audioStreamName the name of the input audio stream.
+ * @param sampleRateStreamName the name of the audio sample rate stream.
+ */
+ public BaseAudioTaskApi(
+ TaskRunner runner,
+ RunningMode runningMode,
+ String audioStreamName,
+ String sampleRateStreamName) {
+ this.runner = runner;
+ this.runningMode = runningMode;
+ this.audioStreamName = audioStreamName;
+ this.sampleRateStreamName = sampleRateStreamName;
+ this.defaultSampleRate = -1.0;
+ }
+
+ /**
+ * A synchronous method to process audio clips. The call blocks the current thread until a failure
+ * status or a successful result is returned.
+ *
+ * @param audioClip a MediaPipe {@link AudioData} object for processing.
+ * @throws MediaPipeException if the task is not in the audio clips mode.
+ */
+ protected TaskResult processAudioClip(AudioData audioClip) {
+ if (runningMode != RunningMode.AUDIO_CLIPS) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+ "Task is not initialized with the audio clips mode. Current running mode:"
+ + runningMode.name());
+ }
+ Map<String, Packet> inputPackets = new HashMap<>();
+ inputPackets.put(
+ audioStreamName,
+ runner
+ .getPacketCreator()
+ .createMatrix(
+ audioClip.getFormat().getNumOfChannels(),
+ audioClip.getBufferLength(),
+ audioClip.getBuffer()));
+ inputPackets.put(
+ sampleRateStreamName,
+ runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
+ return runner.process(inputPackets);
+ }
+
+ /**
+ * Checks or sets the audio sample rate in the audio stream mode.
+ *
+ * @param sampleRate the audio sample rate.
+ * @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
+ * rate is inconsistent with the previously received.
+ */
+ protected void checkOrSetSampleRate(double sampleRate) {
+ if (runningMode != RunningMode.AUDIO_STREAM) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+ "Task is not initialized with the audio stream mode. Current running mode:"
+ + runningMode.name());
+ }
+ if (defaultSampleRate > 0) {
+ if (Double.compare(sampleRate, defaultSampleRate) != 0) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
+ "The input audio sample rate: "
+ + sampleRate
+ + " is inconsistent with the previously provided: "
+ + defaultSampleRate);
+ }
+ } else {
+ Map<String, Packet> inputPackets = new HashMap<>();
+ inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
+ runner.send(inputPackets, PRESTREAM_TIMESTAMP);
+ defaultSampleRate = sampleRate;
+ }
+ }
+ /**
+ * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
+ * available in the user-defined result listener.
+ *
+ * @param audioClip a MediaPipe {@link AudioData} object for processing.
+ * @param timestampMs the corresponding timestamp of the input image in milliseconds.
+ * @throws MediaPipeException if the task is not in the stream mode.
+ */
+ protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
+ if (runningMode != RunningMode.AUDIO_STREAM) {
+ throw new MediaPipeException(
+ MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
+ "Task is not initialized with the audio stream mode. Current running mode:"
+ + runningMode.name());
+ }
+ Map<String, Packet> inputPackets = new HashMap<>();
+ inputPackets.put(
+ audioStreamName,
+ runner
+ .getPacketCreator()
+ .createMatrix(
+ audioClip.getFormat().getNumOfChannels(),
+ audioClip.getBufferLength(),
+ audioClip.getBuffer()));
+ runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
+ }
+
+ /** Closes and cleans up the MediaPipe audio task. */
+ @Override
+ public void close() {
+ runner.close();
+ }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java
new file mode 100644
index 000000000..f0a123810
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java
@@ -0,0 +1,29 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.audio.core;
+
+/**
+ * MediaPipe audio task running mode. A MediaPipe audio task can be run with two different modes:
+ *
+ *
+ * - AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips.
+ *
+ * - AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from
+ * microphone.
+ *
+ */
+public enum RunningMode {
+ AUDIO_CLIPS,
+ AUDIO_STREAM,
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java
new file mode 100644
index 000000000..40a05b1e4
--- /dev/null
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java
@@ -0,0 +1,334 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.components.containers;
+
+import static java.lang.System.arraycopy;
+
+import android.media.AudioFormat;
+import android.media.AudioRecord;
+import com.google.auto.value.AutoValue;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+
+/**
+ * Defines a ring buffer and some utility functions to prepare the input audio samples.
+ *
+ * It maintains a Ring Buffer to hold
+ * input audio data. Clients could feed input audio data via `load` methods and access the
+ * aggregated audio samples via `getBuffer` method.
+ *
+ *
+ * Note that this class can only handle input audio in Float (in {@link
+ * android.media.AudioFormat#ENCODING_PCM_FLOAT}) or Short (in {@link
+ * android.media.AudioFormat#ENCODING_PCM_16BIT}). Internally it converts and stores all the audio
+ * samples in PCM Float encoding.
+ *
+ *
+ * Typical usage in Kotlin
+ *
+ *
+ * val audioData = AudioData.create(format, modelInputLength)
+ * audioData.load(newData)
+ *
+ *
+ * Another sample usage with {@link android.media.AudioRecord}
+ *
+ *
+ * val audioData = AudioData.create(format, modelInputLength)
+ * Timer().scheduleAtFixedRate(delay, period) {
+ * audioData.load(audioRecord)
+ * }
+ *
+ */
+public class AudioData {
+
+ private static final String TAG = AudioData.class.getSimpleName();
+ private final FloatRingBuffer buffer;
+ private final AudioDataFormat format;
+
+ /**
+ * Creates a {@link AudioData} instance with a ring buffer whose size is {@code
+ * sampleCounts} * {@code format.getNumOfChannels()}.
+ *
+ * @param format the expected {@link AudioDataFormat} of audio data loaded into this class.
+ * @param sampleCounts the number of samples.
+ */
+ public static AudioData create(AudioDataFormat format, int sampleCounts) {
+ return new AudioData(format, sampleCounts);
+ }
+
+ /**
+ * Creates a {@link AudioData} instance with a ring buffer whose size is {@code sampleCounts} *
+ * {@code format.getChannelCount()}.
+ *
+ * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
+ * the number of channels and sample rate.
+ * @param sampleCounts the number of samples to be fed into the model
+ */
+ public static AudioData create(AudioFormat format, int sampleCounts) {
+ return new AudioData(AudioDataFormat.create(format), sampleCounts);
+ }
+
+ /**
+ * Wraps a few constants describing the format of the incoming audio samples, namely number of
+ * channels and the sample rate. By default, num of channels is set to 1.
+ */
+ @AutoValue
+ public abstract static class AudioDataFormat {
+ private static final int DEFAULT_NUM_OF_CHANNELS = 1;
+
+ /** Creates a {@link AudioFormat} instance from Android AudioFormat class. */
+ public static AudioDataFormat create(AudioFormat format) {
+ return AudioDataFormat.builder()
+ .setNumOfChannels(format.getChannelCount())
+ .setSampleRate(format.getSampleRate())
+ .build();
+ }
+
+ public abstract int getNumOfChannels();
+
+ public abstract float getSampleRate();
+
+ public static Builder builder() {
+ return new AutoValue_AudioData_AudioDataFormat.Builder()
+ .setNumOfChannels(DEFAULT_NUM_OF_CHANNELS);
+ }
+
+ /** Builder for {@link AudioDataFormat} */
+ @AutoValue.Builder
+ public abstract static class Builder {
+
+ /* By default, it's set to have 1 channel. */
+ public abstract Builder setNumOfChannels(int value);
+
+ public abstract Builder setSampleRate(float value);
+
+ abstract AudioDataFormat autoBuild();
+
+ public AudioDataFormat build() {
+ AudioDataFormat format = autoBuild();
+ if (format.getNumOfChannels() <= 0) {
+ throw new IllegalArgumentException("Number of channels should be greater than 0");
+ }
+ if (format.getSampleRate() <= 0) {
+ throw new IllegalArgumentException("Sample rate should be greater than 0");
+ }
+ return format;
+ }
+ }
+ }
+
+ /**
+ * Stores the input audio samples {@code src} in the ring buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
+ * multi-channel input, the array is interleaved.
+ */
+ public void load(float[] src) {
+ load(src, 0, src.length);
+ }
+
+ /**
+ * Stores the input audio samples {@code src} in the ring buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
+ * multi-channel input, the array is interleaved.
+ * @param offsetInFloat starting position in the {@code src} array
+ * @param sizeInFloat the number of float values to be copied
+ * @throws IllegalArgumentException for incompatible audio format or incorrect input size
+ */
+ public void load(float[] src, int offsetInFloat, int sizeInFloat) {
+ if (sizeInFloat % format.getNumOfChannels() != 0) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Size (%d) needs to be a multiplier of the number of channels (%d)",
+ sizeInFloat, format.getNumOfChannels()));
+ }
+ buffer.load(src, offsetInFloat, sizeInFloat);
+ }
+
+ /**
+ * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
+ * buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
+ * multi-channel input, the array is interleaved.
+ */
+ public void load(short[] src) {
+ load(src, 0, src.length);
+ }
+
+ /**
+ * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
+ * buffer.
+ *
+ * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
+ * multi-channel input, the array is interleaved.
+ * @param offsetInShort starting position in the src array
+ * @param sizeInShort the number of short values to be copied
+ * @throws IllegalArgumentException if the source array can't be copied
+ */
+ public void load(short[] src, int offsetInShort, int sizeInShort) {
+ if (offsetInShort + sizeInShort > src.length) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
+ offsetInShort, sizeInShort, src.length));
+ }
+ float[] floatData = new float[sizeInShort];
+ for (int i = 0; i < sizeInShort; i++) {
+ // Convert the data to PCM Float encoding i.e. values between -1 and 1
+ floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
+ }
+ load(floatData);
+ }
+
+ /**
+ * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
+ * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
+ *
+ * @param record an instance of {@link android.media.AudioRecord}
+ * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
+ * there was no new data in the AudioRecord or an error occurred, this method will return 0.
+ * @throws IllegalArgumentException for unsupported audio encoding format
+ * @throws IllegalStateException if reading from AudioRecord failed
+ */
+ public int load(AudioRecord record) {
+ if (!this.format.equals(AudioDataFormat.create(record.getFormat()))) {
+ throw new IllegalArgumentException("Incompatible audio format.");
+ }
+ int loadedValues = 0;
+ if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
+ float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
+ loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
+ if (loadedValues > 0) {
+ load(newData, 0, loadedValues);
+ return loadedValues;
+ }
+ } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
+ short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
+ loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
+ if (loadedValues > 0) {
+ load(newData, 0, loadedValues);
+ return loadedValues;
+ }
+ } else {
+ throw new IllegalArgumentException(
+ "Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
+ }
+
+ switch (loadedValues) {
+ case AudioRecord.ERROR_INVALID_OPERATION:
+ throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
+
+ case AudioRecord.ERROR_BAD_VALUE:
+ throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
+
+ case AudioRecord.ERROR_DEAD_OBJECT:
+ throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
+
+ case AudioRecord.ERROR:
+ throw new IllegalStateException("AudioRecord.ERROR");
+
+ default:
+ return 0;
+ }
+ }
+
+ /**
+ * Returns a float array holding all the available audio samples in {@link
+ * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
+ */
+ public float[] getBuffer() {
+ float[] bufferData = new float[buffer.getCapacity()];
+ ByteBuffer byteBuffer = buffer.getBuffer();
+ byteBuffer.asFloatBuffer().get(bufferData);
+ return bufferData;
+ }
+
+ /* Returns the {@link AudioDataFormat} associated with the audio data. */
+ public AudioDataFormat getFormat() {
+ return format;
+ }
+
+ /* Returns the audio buffer length. */
+ public int getBufferLength() {
+ return buffer.getCapacity() / format.getNumOfChannels();
+ }
+
+ private AudioData(AudioDataFormat format, int sampleCounts) {
+ this.format = format;
+ this.buffer = new FloatRingBuffer(sampleCounts * format.getNumOfChannels());
+ }
+
+ /** Actual implementation of the ring buffer. */
+ private static class FloatRingBuffer {
+
+ private final float[] buffer;
+ private int nextIndex = 0;
+
+ public FloatRingBuffer(int flatSize) {
+ buffer = new float[flatSize];
+ }
+
+ /**
+ * Loads a slice of the float array to the ring buffer. If the float array is longer than ring
+ * buffer's capacity, samples with lower indices in the array will be ignored.
+ */
+ public void load(float[] newData, int offset, int size) {
+ if (offset + size > newData.length) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
+ offset, size, newData.length));
+ }
+ // If buffer can't hold all the data, only keep the most recent data of size buffer.length
+ if (size > buffer.length) {
+ offset += (size - buffer.length);
+ size = buffer.length;
+ }
+ if (nextIndex + size < buffer.length) {
+ // No need to wrap nextIndex, just copy newData[offset:offset + size]
+ // to buffer[nextIndex:nextIndex+size]
+ arraycopy(newData, offset, buffer, nextIndex, size);
+ } else {
+ // Need to wrap nextIndex, perform copy in two chunks.
+ int firstChunkSize = buffer.length - nextIndex;
+ // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
+ arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
+ // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
+ arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
+ }
+
+ nextIndex = (nextIndex + size) % buffer.length;
+ }
+
+ public ByteBuffer getBuffer() {
+ // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which
+ // can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so
+ // generally we don't create direct buffer for every invocation.
+ ByteBuffer byteBuffer = ByteBuffer.allocate(Float.SIZE / 8 * buffer.length);
+ byteBuffer.order(ByteOrder.nativeOrder());
+ FloatBuffer result = byteBuffer.asFloatBuffer();
+ result.put(buffer, nextIndex, buffer.length - nextIndex);
+ result.put(buffer, 0, nextIndex);
+ byteBuffer.rewind();
+ return byteBuffer;
+ }
+
+ public int getCapacity() {
+ return buffer.length;
+ }
+ }
+}
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD
index 63697229f..5c41edbad 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD
@@ -16,6 +16,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
+android_library(
+ name = "audio_data",
+ srcs = ["AudioData.java"],
+ deps = [
+ "//third_party:autovalue",
+ "@maven//:com_google_guava_guava",
+ ],
+)
+
android_library(
name = "category",
srcs = ["Category.java"],
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java
index d18f0e41e..49c459ef1 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java
@@ -31,11 +31,22 @@ public class OutputHandler {
InputT convertToTaskInput(List packets);
}
- /** Interface for the customizable MediaPipe task result listener. */
+ /**
+ * Interface for the customizable MediaPipe task result listener that can retrieve both task result
+ * objects and the corresponding input data.
+ */
public interface ResultListener {
void run(OutputT result, InputT input);
}
+ /**
+ * Interface for the customizable MediaPipe task result listener that can only retrieve task result
+ * objects.
+ */
+ public interface PureResultListener {
+ void run(OutputT result);
+ }
+
private static final String TAG = "OutputHandler";
// A task-specific graph output packet converter that should be implemented per task.
private OutputPacketConverter outputPacketConverter;
@@ -45,6 +56,8 @@ public class OutputHandler {
protected ErrorListener errorListener;
// The cached task result for non latency sensitive use cases.
protected OutputT cachedTaskResult;
+ // The latest output timestamp.
+ protected long latestOutputTimestamp = -1;
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
private boolean handleTimestampBoundChanges = false;
@@ -98,6 +111,11 @@ public class OutputHandler {
return taskResult;
}
+ /* Returns the latest output timestamp. */
+ public long getLatestOutputTimestamp() {
+ return latestOutputTimestamp;
+ }
+
/**
* Handles a list of output {@link Packet}s. Invoked when a packet list become available.
*
@@ -109,6 +127,7 @@ public class OutputHandler {
taskResult = outputPacketConverter.convertToTaskResult(packets);
if (resultListener == null) {
cachedTaskResult = taskResult;
+ latestOutputTimestamp = packets.get(0).getTimestamp();
} else {
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
resultListener.run(taskResult, taskInput);
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
index 5739edebe..e6fc91cf6 100644
--- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
+++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java
@@ -93,6 +93,7 @@ public class TaskRunner implements AutoCloseable {
public synchronized TaskResult process(Map inputs) {
addPackets(inputs, generateSyntheticTimestamp());
graph.waitUntilGraphIdle();
+ lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
return outputHandler.retrieveCachedTaskResult();
}