From da4d455d0cb8acaf860abc40371456a498c448d0 Mon Sep 17 00:00:00 2001 From: Jiuqiang Tang Date: Thu, 10 Nov 2022 10:08:02 -0800 Subject: [PATCH] Implement MediaPipe Tasks AudioClassifier Java API. PiperOrigin-RevId: 487570148 --- .../audio_classifier_graph_options.proto | 3 + .../com/google/mediapipe/tasks/audio/BUILD | 76 ++++ .../audio/audioclassifier/AndroidManifest.xml | 8 + .../audioclassifier/AudioClassifier.java | 399 ++++++++++++++++++ .../AudioClassifierResult.java | 76 ++++ .../tasks/audio/core/BaseAudioTaskApi.java | 151 +++++++ .../tasks/audio/core/RunningMode.java | 29 ++ .../components/containers/AudioData.java | 334 +++++++++++++++ .../tasks/components/containers/BUILD | 9 + .../mediapipe/tasks/core/OutputHandler.java | 21 +- .../mediapipe/tasks/core/TaskRunner.java | 1 + 11 files changed, 1106 insertions(+), 1 deletion(-) create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java create mode 100644 mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 16aa86aeb..5d4ba3296 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -21,6 +21,9 @@ import 
"mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto"; +option java_outer_classname = "AudioClassifierGraphOptionsProto"; + message AudioClassifierGraphOptions { extend mediapipe.CalculatorOptions { optional AudioClassifierGraphOptions ext = 451755788; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD new file mode 100644 index 000000000..c5cb6f8a3 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD @@ -0,0 +1,76 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +android_library( + name = "core", + srcs = glob(["core/*.java"]), + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + ":libmediapipe_tasks_audio_jni_lib", + "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "@maven//:com_google_guava_guava", + ], +) + +# The native library of all MediaPipe audio tasks. 
+cc_binary( + name = "libmediapipe_tasks_audio_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + ], +) + +cc_library( + name = "libmediapipe_tasks_audio_jni_lib", + srcs = [":libmediapipe_tasks_audio_jni.so"], + alwayslink = 1, +) + +android_library( + name = "audioclassifier", + srcs = [ + "audioclassifier/AudioClassifier.java", + "audioclassifier/AudioClassifierResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "audioclassifier/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml new file mode 100644 index 000000000..1e9817efd --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git 
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.audio.audioclassifier;

import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
 * Performs audio classification on audio clips or audio stream.
 *
 * <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
 * mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
 * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
 *
 * <p>Input tensor: (kTfLiteFloat32)
 *
 * <p>At least one output tensor with: (kTfLiteFloat32)
 */
public final class AudioClassifier extends BaseAudioTaskApi {
  private static final String TAG = AudioClassifier.class.getSimpleName();
  private static final String AUDIO_IN_STREAM_NAME = "audio_in";
  private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
  private static final List<String> INPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList(
              "AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
  private static final List<String> OUTPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList(
              "CLASSIFICATIONS:classifications_out",
              "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
  private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
  private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
  private static final String TASK_GRAPH_NAME =
      "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
  private static final long MICROSECONDS_PER_MILLISECOND = 1000;

  static {
    // Register the proto type name so PacketGetter can deserialize the graph's output packets.
    ProtoUtil.registerTypeName(
        ClassificationsProto.ClassificationResult.class,
        "mediapipe.tasks.components.containers.proto.ClassificationResult");
  }

  /**
   * Creates an {@link AudioClassifier} instance from a model file and default {@link
   * AudioClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelPath path to the classification model in the assets.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromFile(Context context, String modelPath) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
    return createFromOptions(
        context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates an {@link AudioClassifier} instance from a model file and default {@link
   * AudioClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelFile the classification model {@link File} instance.
   * @throws IOException if an I/O error occurs when opening the tflite model file.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      BaseOptions baseOptions =
          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
      return createFromOptions(
          context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
    }
  }

  /**
   * Creates an {@link AudioClassifier} instance from a model buffer and default {@link
   * AudioClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link java.nio.MappedByteBuffer} of the
   *     classification model.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
    return createFromOptions(
        context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
   *
   * @param context an Android {@link Context}.
   * @param options an {@link AudioClassifierOptions} instance.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
    OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
          @Override
          public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
            try {
              if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
                // For audio stream mode: a single ClassificationResult per invocation, stamped
                // with the packet timestamp converted from microseconds to milliseconds.
                return AudioClassifierResult.createFromProto(
                    PacketGetter.getProto(
                        packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
                        ClassificationsProto.ClassificationResult.getDefaultInstance()),
                    packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
                        / MICROSECONDS_PER_MILLISECOND);
              } else {
                // For audio clips mode: one ClassificationResult per processed chunk, each
                // carrying its own timestamp, so the top-level result timestamp is unused (-1).
                return AudioClassifierResult.createFromProtoList(
                    PacketGetter.getProtoVector(
                        packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
                        ClassificationsProto.ClassificationResult.parser()),
                    -1);
              }
            } catch (IOException e) {
              throw new MediaPipeException(
                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
            }
          }

          @Override
          public Void convertToTaskInput(List<Packet> packets) {
            return null;
          }
        });
    if (options.resultListener().isPresent()) {
      ResultListener<AudioClassifierResult, Void> resultListener =
          new ResultListener<AudioClassifierResult, Void>() {
            @Override
            public void run(AudioClassifierResult audioClassifierResult, Void input) {
              options.resultListener().get().run(audioClassifierResult);
            }
          };
      handler.setResultListener(resultListener);
    }
    options.errorListener().ifPresent(handler::setErrorListener);
    // Audio tasks should not drop input audio due to flow limiting, which may cause data
    // inconsistency.
    TaskRunner runner =
        TaskRunner.create(
            context,
            TaskInfo.<AudioClassifierOptions>builder()
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(INPUT_STREAMS)
                .setOutputStreams(OUTPUT_STREAMS)
                .setTaskOptions(options)
                .setEnableFlowLimiting(false)
                .build(),
            handler);
    return new AudioClassifier(runner, options.runningMode());
  }

  /**
   * Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
   * RunningMode}.
   *
   * @param taskRunner a {@link TaskRunner}.
   * @param runningMode a mediapipe audio task {@link RunningMode}.
   */
  private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
    super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
  }

  /**
   * Performs audio classification on the provided audio clip. Only use this method when the
   * AudioClassifier is created with the audio clips mode.
   *
   * <p>The audio clip is represented as a MediaPipe {@link AudioData} object. The method accepts
   * audio clips with various length and audio sample rate. It's required to provide the
   * corresponding audio sample rate within the {@link AudioData} object.
   *
   * <p>The input audio clip may be longer than what the model is able to process in a single
   * inference. When this occurs, the input audio clip is split into multiple chunks starting at
   * different timestamps. For this reason, this function returns a vector of ClassificationResult
   * objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
   * chunk data that was classified, e.g:
   *
   * <pre>
   * ClassificationResult #0 (first chunk of data):
   *   timestamp_ms: 0 (starts at 0ms)
   *   classifications #0 (single head model):
   *     category #0:
   *       category_name: "Speech"
   *       score: 0.6
   *     category #1:
   *       category_name: "Music"
   *       score: 0.2
   * ClassificationResult #1 (second chunk of data):
   *   timestamp_ms: 800 (starts at 800ms)
   *   classifications #0 (single head model):
   *     category #0:
   *       category_name: "Speech"
   *       score: 0.5
   *     category #1:
   *       category_name: "Silence"
   *       score: 0.1
   * </pre>
   *
   * @param audioClip a MediaPipe {@link AudioData} object for processing.
   * @throws MediaPipeException if there is an internal error.
   */
  public AudioClassifierResult classify(AudioData audioClip) {
    return (AudioClassifierResult) processAudioClip(audioClip);
  }

  /**
   * Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
   * use this method when the AudioClassifier is created with the audio stream mode.
   *
   * <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
   * be resampled, accumulated, and framed to the proper size for the underlying model to consume.
   * It's required to provide the corresponding audio sample rate within {@link AudioData} object as
   * well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
   * timestamps must be monotonically increasing. This method will return immediately after the
   * input audio data is accepted. The results will be available in the `resultListener` provided in
   * the `AudioClassifierOptions`. The `classifyAsync` method is designed to process audio stream
   * data such as microphone input.
   *
   * <p>The input audio block may be longer than what the model is able to process in a single
   * inference. When this occurs, the input audio block is split into multiple chunks. For this
   * reason, the callback may be called multiple times (once per chunk) for each call to this
   * function.
   *
   * @param audioBlock a MediaPipe {@link AudioData} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void classifyAsync(AudioData audioBlock, long timestampMs) {
    checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
    sendAudioStreamData(audioBlock, timestampMs);
  }

  /** Options for setting up an {@link AudioClassifier}. */
  @AutoValue
  public abstract static class AudioClassifierOptions extends TaskOptions {

    /** Builder for {@link AudioClassifierOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the {@link BaseOptions} for the audio classifier task. */
      public abstract Builder setBaseOptions(BaseOptions baseOptions);

      /**
       * Sets the {@link RunningMode} for the audio classifier task. Default to the audio clips
       * mode. The audio classifier has two modes:
       *
       * <ul>
       *   <li>AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
       *       audio clips to the `classify` method, and will receive the classification results as
       *       the return value.
       *   <li>AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
       *       from microphone. Users call `classifyAsync` to push the audio data into the
       *       AudioClassifier, the classification results will be available in the result callback
       *       when the audio classifier finishes the work.
       * </ul>
       */
      public abstract Builder setRunningMode(RunningMode runningMode);

      /**
       * Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
       * score threshold, number of results, etc.
       */
      public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);

      /**
       * Sets the {@link ResultListener} to receive the classification results asynchronously when
       * the audio classifier is in the audio stream mode.
       */
      public abstract Builder setResultListener(
          PureResultListener<AudioClassifierResult> resultListener);

      /** Sets an optional {@link ErrorListener}. */
      public abstract Builder setErrorListener(ErrorListener errorListener);

      abstract AudioClassifierOptions autoBuild();

      /**
       * Validates and builds the {@link AudioClassifierOptions} instance.
       *
       * @throws IllegalArgumentException if the result listener and the running mode are not
       *     properly configured. The result listener should only be set when the audio classifier
       *     is in the audio stream mode.
       */
      public final AudioClassifierOptions build() {
        AudioClassifierOptions options = autoBuild();
        if (options.runningMode() == RunningMode.AUDIO_STREAM) {
          if (!options.resultListener().isPresent()) {
            throw new IllegalArgumentException(
                "The audio classifier is in the audio stream mode, a user-defined result listener"
                    + " must be provided in the AudioClassifierOptions.");
          }
        } else if (options.resultListener().isPresent()) {
          throw new IllegalArgumentException(
              "The audio classifier is in the audio clips mode, a user-defined result listener"
                  + " shouldn't be provided in AudioClassifierOptions.");
        }
        return options;
      }
    }

    abstract BaseOptions baseOptions();

    abstract RunningMode runningMode();

    abstract Optional<ClassifierOptions> classifierOptions();

    abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();

    abstract Optional<ErrorListener> errorListener();

    public static Builder builder() {
      return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
          .setRunningMode(RunningMode.AUDIO_CLIPS);
    }

    /**
     * Converts a {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
     */
    @Override
    public CalculatorOptions convertToCalculatorOptionsProto() {
      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
          BaseOptionsProto.BaseOptions.newBuilder();
      // Stream mode requires the underlying graph to run in streaming fashion.
      baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
      AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
          AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
              .setBaseOptions(baseOptionsBuilder);
      if (classifierOptions().isPresent()) {
        taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
      }
      return CalculatorOptions.newBuilder()
          .setExtension(
              AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
              taskOptionsBuilder.build())
          .build();
    }
  }
}
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.audio.audioclassifier;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Represents the classification results generated by {@link AudioClassifier}. */
@AutoValue
public abstract class AudioClassifierResult implements TaskResult {

  /**
   * Creates an {@link AudioClassifierResult} instance from a list of {@link
   * ClassificationsProto.ClassificationResult} protobuf messages.
   *
   * @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf message
   *     to convert.
   * @param timestampMs a timestamp for this result.
   */
  static AudioClassifierResult createFromProtoList(
      List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
    List<ClassificationResult> classificationResultList = new ArrayList<>();
    for (ClassificationsProto.ClassificationResult proto : protoList) {
      classificationResultList.add(ClassificationResult.createFromProto(proto));
    }
    return new AutoValue_AudioClassifierResult(
        Optional.of(classificationResultList), Optional.empty(), timestampMs);
  }

  /**
   * Creates an {@link AudioClassifierResult} instance from a {@link
   * ClassificationsProto.ClassificationResult} protobuf message.
   *
   * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
   * @param timestampMs a timestamp for this result.
   */
  static AudioClassifierResult createFromProto(
      ClassificationsProto.ClassificationResult proto, long timestampMs) {
    return new AutoValue_AudioClassifierResult(
        Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
  }

  /**
   * A list of timestamped {@link ClassificationResult} objects, each contains one set of results
   * per classifier head. The list represents the audio classification result of an audio clip, and
   * is only available when running with the audio clips mode.
   */
  public abstract Optional<List<ClassificationResult>> classificationResultList();

  /**
   * Contains one set of results per classifier head. A {@link ClassificationResult} usually
   * represents one audio classification result in an audio stream, and is only available when
   * running with the audio stream mode.
   */
  public abstract Optional<ClassificationResult> classificationResult();

  @Override
  public abstract long timestampMs();
}
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.audio.core;

import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;

/** The base class of MediaPipe audio tasks. */
public class BaseAudioTaskApi implements AutoCloseable {
  private static final long MICROSECONDS_PER_MILLISECOND = 1000;
  // Special pre-stream timestamp used to deliver the sample rate packet before any audio data.
  private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;

  private final TaskRunner runner;
  private final RunningMode runningMode;
  private final String audioStreamName;
  private final String sampleRateStreamName;
  // Sample rate locked in by the first audio block in stream mode; -1.0 means "not set yet".
  private double defaultSampleRate;

  static {
    System.loadLibrary("mediapipe_tasks_audio_jni");
  }

  /**
   * Constructor to initialize a {@link BaseAudioTaskApi}.
   *
   * @param runner a {@link TaskRunner}.
   * @param runningMode a mediapipe audio task {@link RunningMode}.
   * @param audioStreamName the name of the input audio stream.
   * @param sampleRateStreamName the name of the audio sample rate stream.
   */
  public BaseAudioTaskApi(
      TaskRunner runner,
      RunningMode runningMode,
      String audioStreamName,
      String sampleRateStreamName) {
    this.runner = runner;
    this.runningMode = runningMode;
    this.audioStreamName = audioStreamName;
    this.sampleRateStreamName = sampleRateStreamName;
    this.defaultSampleRate = -1.0;
  }

  /**
   * A synchronous method to process audio clips. The call blocks the current thread until a failure
   * status or a successful result is returned.
   *
   * @param audioClip a MediaPipe {@link AudioData} object for processing.
   * @throws MediaPipeException if the task is not in the audio clips mode.
   */
  protected TaskResult processAudioClip(AudioData audioClip) {
    if (runningMode != RunningMode.AUDIO_CLIPS) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the audio clips mode. Current running mode:"
              + runningMode.name());
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(
        audioStreamName,
        runner
            .getPacketCreator()
            .createMatrix(
                audioClip.getFormat().getNumOfChannels(),
                audioClip.getBufferLength(),
                audioClip.getBuffer()));
    inputPackets.put(
        sampleRateStreamName,
        runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
    return runner.process(inputPackets);
  }

  /**
   * Checks or sets the audio sample rate in the audio stream mode.
   *
   * @param sampleRate the audio sample rate.
   * @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
   *     rate is inconsistent with the previously received.
   */
  protected void checkOrSetSampleRate(double sampleRate) {
    if (runningMode != RunningMode.AUDIO_STREAM) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the audio stream mode. Current running mode:"
              + runningMode.name());
    }
    if (defaultSampleRate > 0) {
      // A sample rate was already established; all subsequent blocks must match it exactly.
      if (Double.compare(sampleRate, defaultSampleRate) != 0) {
        throw new MediaPipeException(
            MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
            "The input audio sample rate: "
                + sampleRate
                + " is inconsistent with the previously provided: "
                + defaultSampleRate);
      }
    } else {
      // First block: deliver the sample rate to the graph at the pre-stream timestamp.
      Map<String, Packet> inputPackets = new HashMap<>();
      inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
      runner.send(inputPackets, PRESTREAM_TIMESTAMP);
      defaultSampleRate = sampleRate;
    }
  }

  /**
   * An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
   * available in the user-defined result listener.
   *
   * @param audioClip a MediaPipe {@link AudioData} object for processing.
   * @param timestampMs the corresponding timestamp of the input audio block in milliseconds.
   * @throws MediaPipeException if the task is not in the stream mode.
   */
  protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
    if (runningMode != RunningMode.AUDIO_STREAM) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the audio stream mode. Current running mode:"
              + runningMode.name());
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(
        audioStreamName,
        runner
            .getPacketCreator()
            .createMatrix(
                audioClip.getFormat().getNumOfChannels(),
                audioClip.getBufferLength(),
                audioClip.getBuffer()));
    runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /** Closes and cleans up the MediaPipe audio task. */
  @Override
  public void close() {
    runner.close();
  }
}
+// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.audio.core; + +/** + * MediaPipe audio task running mode. A MediaPipe audio task can be run with two different modes: + * + *
    + *
  • AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips. + *
  • AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from + * microphone. + *
+ */ +public enum RunningMode { + AUDIO_CLIPS, + AUDIO_STREAM, +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java new file mode 100644 index 000000000..40a05b1e4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java @@ -0,0 +1,334 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import static java.lang.System.arraycopy; + +import android.media.AudioFormat; +import android.media.AudioRecord; +import com.google.auto.value.AutoValue; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; + +/** + * Defines a ring buffer and some utility functions to prepare the input audio samples. + * + *

It maintains a Ring Buffer to hold
+ * input audio data. Clients could feed input audio data via `load` methods and access the
+ * aggregated audio samples via `getBuffer` method.
+ *
+ *

Note that this class can only handle input audio in Float (in {@link
+ * android.media.AudioFormat#ENCODING_PCM_FLOAT}) or Short (in {@link
+ * android.media.AudioFormat#ENCODING_PCM_16BIT}). Internally it converts and stores all the audio
+ * samples in PCM Float encoding.
+ *
+ *

Typical usage in Kotlin + * + *

+ *   val audioData = AudioData.create(format, modelInputLength)
+ *   audioData.load(newData)
+ * 
+ * + *

Another sample usage with {@link android.media.AudioRecord} + * + *

+ *   val audioData = AudioData.create(format, modelInputLength)
+ *   Timer().scheduleAtFixedRate(delay, period) {
+ *     audioData.load(audioRecord)
+ *   }
+ * 
+ */
+public class AudioData {
+
+  private static final String TAG = AudioData.class.getSimpleName();
+  private final FloatRingBuffer buffer;
+  private final AudioDataFormat format;
+
+  /**
+   * Creates a {@link AudioData} instance with a ring buffer whose size is {@code
+   * sampleCounts} * {@code format.getNumOfChannels()}.
+   *
+   * @param format the expected {@link AudioDataFormat} of audio data loaded into this class.
+   * @param sampleCounts the number of samples.
+   */
+  public static AudioData create(AudioDataFormat format, int sampleCounts) {
+    return new AudioData(format, sampleCounts);
+  }
+
+  /**
+   * Creates a {@link AudioData} instance with a ring buffer whose size is {@code sampleCounts} *
+   * {@code format.getChannelCount()}.
+   *
+   * @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
+   *     the number of channels and sample rate.
+   * @param sampleCounts the number of samples to be fed into the model
+   */
+  public static AudioData create(AudioFormat format, int sampleCounts) {
+    return new AudioData(AudioDataFormat.create(format), sampleCounts);
+  }
+
+  /**
+   * Wraps a few constants describing the format of the incoming audio samples, namely number of
+   * channels and the sample rate. By default, num of channels is set to 1.
+   */
+  @AutoValue
+  public abstract static class AudioDataFormat {
+    private static final int DEFAULT_NUM_OF_CHANNELS = 1;
+
+    /** Creates a {@link AudioDataFormat} instance from Android AudioFormat class.
*/ + public static AudioDataFormat create(AudioFormat format) { + return AudioDataFormat.builder() + .setNumOfChannels(format.getChannelCount()) + .setSampleRate(format.getSampleRate()) + .build(); + } + + public abstract int getNumOfChannels(); + + public abstract float getSampleRate(); + + public static Builder builder() { + return new AutoValue_AudioData_AudioDataFormat.Builder() + .setNumOfChannels(DEFAULT_NUM_OF_CHANNELS); + } + + /** Builder for {@link AudioDataFormat} */ + @AutoValue.Builder + public abstract static class Builder { + + /* By default, it's set to have 1 channel. */ + public abstract Builder setNumOfChannels(int value); + + public abstract Builder setSampleRate(float value); + + abstract AudioDataFormat autoBuild(); + + public AudioDataFormat build() { + AudioDataFormat format = autoBuild(); + if (format.getNumOfChannels() <= 0) { + throw new IllegalArgumentException("Number of channels should be greater than 0"); + } + if (format.getSampleRate() <= 0) { + throw new IllegalArgumentException("Sample rate should be greater than 0"); + } + return format; + } + } + } + + /** + * Stores the input audio samples {@code src} in the ring buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For + * multi-channel input, the array is interleaved. + */ + public void load(float[] src) { + load(src, 0, src.length); + } + + /** + * Stores the input audio samples {@code src} in the ring buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For + * multi-channel input, the array is interleaved. 
+ * @param offsetInFloat starting position in the {@code src} array + * @param sizeInFloat the number of float values to be copied + * @throws IllegalArgumentException for incompatible audio format or incorrect input size + */ + public void load(float[] src, int offsetInFloat, int sizeInFloat) { + if (sizeInFloat % format.getNumOfChannels() != 0) { + throw new IllegalArgumentException( + String.format( + "Size (%d) needs to be a multiplier of the number of channels (%d)", + sizeInFloat, format.getNumOfChannels())); + } + buffer.load(src, offsetInFloat, sizeInFloat); + } + + /** + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring + * buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For + * multi-channel input, the array is interleaved. + */ + public void load(short[] src) { + load(src, 0, src.length); + } + + /** + * Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring + * buffer. + * + * @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For + * multi-channel input, the array is interleaved. + * @param offsetInShort starting position in the src array + * @param sizeInShort the number of short values to be copied + * @throws IllegalArgumentException if the source array can't be copied + */ + public void load(short[] src, int offsetInShort, int sizeInShort) { + if (offsetInShort + sizeInShort > src.length) { + throw new IllegalArgumentException( + String.format( + "Index out of range. offset (%d) + size (%d) should <= newData.length (%d)", + offsetInShort, sizeInShort, src.length)); + } + float[] floatData = new float[sizeInShort]; + for (int i = 0; i < sizeInShort; i++) { + // Convert the data to PCM Float encoding i.e. 
values between -1 and 1 + floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE; + } + load(floatData); + } + + /** + * Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only + * supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT. + * + * @param record an instance of {@link android.media.AudioRecord} + * @return number of captured audio values whose size is {@code channelCount * sampleCount}. If + * there was no new data in the AudioRecord or an error occurred, this method will return 0. + * @throws IllegalArgumentException for unsupported audio encoding format + * @throws IllegalStateException if reading from AudioRecord failed + */ + public int load(AudioRecord record) { + if (!this.format.equals(AudioDataFormat.create(record.getFormat()))) { + throw new IllegalArgumentException("Incompatible audio format."); + } + int loadedValues = 0; + if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) { + float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()]; + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); + if (loadedValues > 0) { + load(newData, 0, loadedValues); + return loadedValues; + } + } else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) { + short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()]; + loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING); + if (loadedValues > 0) { + load(newData, 0, loadedValues); + return loadedValues; + } + } else { + throw new IllegalArgumentException( + "Unsupported encoding. 
Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT."); + } + + switch (loadedValues) { + case AudioRecord.ERROR_INVALID_OPERATION: + throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION"); + + case AudioRecord.ERROR_BAD_VALUE: + throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE"); + + case AudioRecord.ERROR_DEAD_OBJECT: + throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT"); + + case AudioRecord.ERROR: + throw new IllegalStateException("AudioRecord.ERROR"); + + default: + return 0; + } + } + + /** + * Returns a float array holding all the available audio samples in {@link + * android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1]. + */ + public float[] getBuffer() { + float[] bufferData = new float[buffer.getCapacity()]; + ByteBuffer byteBuffer = buffer.getBuffer(); + byteBuffer.asFloatBuffer().get(bufferData); + return bufferData; + } + + /* Returns the {@link AudioDataFormat} associated with the tensor. */ + public AudioDataFormat getFormat() { + return format; + } + + /* Returns the audio buffer length. */ + public int getBufferLength() { + return buffer.getCapacity() / format.getNumOfChannels(); + } + + private AudioData(AudioDataFormat format, int sampleCounts) { + this.format = format; + this.buffer = new FloatRingBuffer(sampleCounts * format.getNumOfChannels()); + } + + /** Actual implementation of the ring buffer. */ + private static class FloatRingBuffer { + + private final float[] buffer; + private int nextIndex = 0; + + public FloatRingBuffer(int flatSize) { + buffer = new float[flatSize]; + } + + /** + * Loads a slice of the float array to the ring buffer. If the float array is longer than ring + * buffer's capacity, samples with lower indices in the array will be ignored. + */ + public void load(float[] newData, int offset, int size) { + if (offset + size > newData.length) { + throw new IllegalArgumentException( + String.format( + "Index out of range. 
offset (%d) + size (%d) should <= newData.length (%d)", + offset, size, newData.length)); + } + // If buffer can't hold all the data, only keep the most recent data of size buffer.length + if (size > buffer.length) { + offset += (size - buffer.length); + size = buffer.length; + } + if (nextIndex + size < buffer.length) { + // No need to wrap nextIndex, just copy newData[offset:offset + size] + // to buffer[nextIndex:nextIndex+size] + arraycopy(newData, offset, buffer, nextIndex, size); + } else { + // Need to wrap nextIndex, perform copy in two chunks. + int firstChunkSize = buffer.length - nextIndex; + // First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length] + arraycopy(newData, offset, buffer, nextIndex, firstChunkSize); + // Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize] + arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize); + } + + nextIndex = (nextIndex + size) % buffer.length; + } + + public ByteBuffer getBuffer() { + // Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which + // can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so + // generally we don't create direct buffer for every invocation. 
+ ByteBuffer byteBuffer = ByteBuffer.allocate(Float.SIZE / 8 * buffer.length); + byteBuffer.order(ByteOrder.nativeOrder()); + FloatBuffer result = byteBuffer.asFloatBuffer(); + result.put(buffer, nextIndex, buffer.length - nextIndex); + result.put(buffer, 0, nextIndex); + byteBuffer.rewind(); + return byteBuffer; + } + + public int getCapacity() { + return buffer.length; + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 63697229f..5c41edbad 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -16,6 +16,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +android_library( + name = "audio_data", + srcs = ["AudioData.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + android_library( name = "category", srcs = ["Category.java"], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java index d18f0e41e..49c459ef1 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java @@ -31,11 +31,22 @@ public class OutputHandler { InputT convertToTaskInput(List packets); } - /** Interface for the customizable MediaPipe task result listener. */ + /** + * Interface for the customizable MediaPipe task result listener that can reteive both task result + * objects and the correpsonding input data. + */ public interface ResultListener { void run(OutputT result, InputT input); } + /** + * Interface for the customizable MediaPipe task result listener that can only reteive task result + * objects. 
+ */ + public interface PureResultListener { + void run(OutputT result); + } + private static final String TAG = "OutputHandler"; // A task-specific graph output packet converter that should be implemented per task. private OutputPacketConverter outputPacketConverter; @@ -45,6 +56,8 @@ public class OutputHandler { protected ErrorListener errorListener; // The cached task result for non latency sensitive use cases. protected OutputT cachedTaskResult; + // The latest output timestamp. + protected long latestOutputTimestamp = -1; // Whether the output handler should react to timestamp-bound changes by outputting empty packets. private boolean handleTimestampBoundChanges = false; @@ -98,6 +111,11 @@ public class OutputHandler { return taskResult; } + /* Returns the latest output timestamp. */ + public long getLatestOutputTimestamp() { + return latestOutputTimestamp; + } + /** * Handles a list of output {@link Packet}s. Invoked when a packet list become available. * @@ -109,6 +127,7 @@ public class OutputHandler { taskResult = outputPacketConverter.convertToTaskResult(packets); if (resultListener == null) { cachedTaskResult = taskResult; + latestOutputTimestamp = packets.get(0).getTimestamp(); } else { InputT taskInput = outputPacketConverter.convertToTaskInput(packets); resultListener.run(taskResult, taskInput); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java index 5739edebe..e6fc91cf6 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -93,6 +93,7 @@ public class TaskRunner implements AutoCloseable { public synchronized TaskResult process(Map inputs) { addPackets(inputs, generateSyntheticTimestamp()); graph.waitUntilGraphIdle(); + lastSeenTimestamp = outputHandler.getLatestOutputTimestamp(); return outputHandler.retrieveCachedTaskResult(); }