Implement MediaPipe Tasks AudioClassifier Java API.

PiperOrigin-RevId: 487570148
Author: Jiuqiang Tang, 2022-11-10 10:08:02 -08:00; committed by Copybara-Service
parent 2ce3a9719e
commit da4d455d0c
11 changed files with 1106 additions and 1 deletion

mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto

@@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto";
option java_outer_classname = "AudioClassifierGraphOptionsProto";
message AudioClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional AudioClassifierGraphOptions ext = 451755788;

mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD

@@ -0,0 +1,76 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["core/*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
":libmediapipe_tasks_audio_jni_lib",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"@maven//:com_google_guava_guava",
],
)
# The native library of all MediaPipe audio tasks.
cc_binary(
name = "libmediapipe_tasks_audio_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
cc_library(
name = "libmediapipe_tasks_audio_jni_lib",
srcs = [":libmediapipe_tasks_audio_jni.so"],
alwayslink = 1,
)
android_library(
name = "audioclassifier",
srcs = [
"audioclassifier/AudioClassifier.java",
"audioclassifier/AudioClassifierResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "audioclassifier/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AndroidManifest.xml

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.audio.audioclassifier">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifier.java

@@ -0,0 +1,399 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.audioclassifier;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs audio classification on audio clips or audio streams.
*
* <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
* mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
* items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
*
* <p>Input tensor: (kTfLiteFloat32)
*
* <ul>
* <li>input audio buffer of size `[batch * samples]`.
* <li>batch inference is not supported (`batch` is required to be 1).
* <li>for multi-channel models, the channels need to be interleaved.
* </ul>
*
* <p>At least one output tensor with: (kTfLiteFloat32)
*
* <ul>
* <li>`[1 x N]` array where `N` represents the number of categories.
* <li>optional (but recommended) label items as AssociatedFiles with type TENSOR_AXIS_LABELS,
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
* `category_name` field of the results. The `display_name` field is filled from the
* AssociatedFile (if any) whose locale matches the `display_names_locale` field of the
* `AudioClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
* these are available, only the `index` field of the results will be filled.
* </ul>
*/
public final class AudioClassifier extends BaseAudioTaskApi {
private static final String TAG = AudioClassifier.class.getSimpleName();
private static final String AUDIO_IN_STREAM_NAME = "audio_in";
private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"CLASSIFICATIONS:classifications_out",
"TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
static {
ProtoUtil.registerTypeName(
ClassificationsProto.ClassificationResult.class,
"mediapipe.tasks.components.containers.proto.ClassificationResult");
}
/**
* Creates an {@link AudioClassifier} instance from a model file and default {@link
* AudioClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the classification model in the assets.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link AudioClassifier} instance from a model file and default {@link
* AudioClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the classification model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates an {@link AudioClassifier} instance from a model buffer and default {@link
* AudioClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
* classification model.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
*
* @param context an Android {@link Context}.
* @param options an {@link AudioClassifierOptions} instance.
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
*/
public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
@Override
public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
try {
if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
// For audio stream mode.
return AudioClassifierResult.createFromProto(
PacketGetter.getProto(
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.getDefaultInstance()),
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
/ MICROSECONDS_PER_MILLISECOND);
} else {
// For audio clips mode.
return AudioClassifierResult.createFromProtoList(
PacketGetter.getProtoVector(
packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.parser()),
-1);
}
} catch (IOException e) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
}
}
@Override
public Void convertToTaskInput(List<Packet> packets) {
return null;
}
});
if (options.resultListener().isPresent()) {
ResultListener<AudioClassifierResult, Void> resultListener =
new ResultListener<>() {
@Override
public void run(AudioClassifierResult audioClassifierResult, Void input) {
options.resultListener().get().run(audioClassifierResult);
}
};
handler.setResultListener(resultListener);
}
options.errorListener().ifPresent(handler::setErrorListener);
// Audio tasks should not drop input audio due to flow limiting, which may cause data
// inconsistency.
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<AudioClassifierOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(options)
.setEnableFlowLimiting(false)
.build(),
handler);
return new AudioClassifier(runner, options.runningMode());
}
/**
* Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe audio task {@link RunningMode}.
*/
private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
}
/**
* Performs audio classification on the provided audio clip. Only use this method when the
* AudioClassifier is created with the audio clips mode.
*
* <p>The audio clip is represented as a MediaPipe {@link AudioData} object. The method accepts
* audio clips of various lengths and sample rates. It's required to provide the
* corresponding audio sample rate within the {@link AudioData} object.
*
* <p>The input audio clip may be longer than what the model is able to process in a single
* inference. When this occurs, the input audio clip is split into multiple chunks starting at
* different timestamps. For this reason, this function returns a vector of ClassificationResult
* objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
* chunk data that was classified, e.g:
*
* ClassificationResult #0 (first chunk of data):
* timestamp_ms: 0 (starts at 0ms)
* classifications #0 (single head model):
* category #0:
* category_name: "Speech"
* score: 0.6
* category #1:
* category_name: "Music"
* score: 0.2
* ClassificationResult #1 (second chunk of data):
* timestamp_ms: 800 (starts at 800ms)
* classifications #0 (single head model):
* category #0:
* category_name: "Speech"
* score: 0.5
* category #1:
* category_name: "Silence"
* score: 0.1
*
* @param audioClip a MediaPipe {@link AudioData} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public AudioClassifierResult classify(AudioData audioClip) {
return (AudioClassifierResult) processAudioClip(audioClip);
}
/**
* Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
* use this method when the AudioClassifier is created with the audio stream mode.
*
* <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
* be resampled, accumulated, and framed to the proper size for the underlying model to consume.
* It's required to provide the corresponding audio sample rate within {@link AudioData} object as
* well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
* timestamps must be monotonically increasing. This method will return immediately after
* the input audio data is accepted. The results will be available in the `resultListener`
* provided in the `AudioClassifierOptions`. The `classifyAsync` method is designed to process
* audio stream data such as microphone input.
*
* <p>The input audio block may be longer than what the model is able to process in a single
* inference. When this occurs, the input audio block is split into multiple chunks. For this
* reason, the callback may be called multiple times (once per chunk) for each call to this
* function.
*
* @param audioBlock a MediaPipe {@link AudioData} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void classifyAsync(AudioData audioBlock, long timestampMs) {
checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
sendAudioStreamData(audioBlock, timestampMs);
}
/** Options for setting up an {@link AudioClassifier}. */
@AutoValue
public abstract static class AudioClassifierOptions extends TaskOptions {
/** Builder for {@link AudioClassifierOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the {@link BaseOptions} for the audio classifier task. */
public abstract Builder setBaseOptions(BaseOptions baseOptions);
/**
* Sets the {@link RunningMode} for the audio classifier task. Defaults to the audio clips
* mode. The audio classifier has two modes:
*
* <ul>
* <li>AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
* audio clips to the `classify` method, and will receive the classification results as
* the return value.
* <li>AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
* from a microphone. Users call `classifyAsync` to push the audio data into the
* AudioClassifier, the classification results will be available in the result callback
* when the audio classifier finishes the work.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode runningMode);
/**
* Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
* score threshold, number of results, etc.
*/
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
/**
* Sets the {@link PureResultListener} to receive the classification results asynchronously when
* the audio classifier is in the audio stream mode.
*/
public abstract Builder setResultListener(
PureResultListener<AudioClassifierResult> resultListener);
/** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener errorListener);
abstract AudioClassifierOptions autoBuild();
/**
* Validates and builds the {@link AudioClassifierOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the audio classifier
* is in the audio stream mode.
*/
public final AudioClassifierOptions build() {
AudioClassifierOptions options = autoBuild();
if (options.runningMode() == RunningMode.AUDIO_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The audio classifier is in the audio stream mode, a user-defined result listener"
+ " must be provided in the AudioClassifierOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The audio classifier is in the audio clips mode, a user-defined result listener"
+ " shouldn't be provided in AudioClassifierOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<ClassifierOptions> classifierOptions();
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
.setRunningMode(RunningMode.AUDIO_CLIPS);
}
/**
* Converts an {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
*/
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (classifierOptions().isPresent()) {
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder()
.setExtension(
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
}
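
A minimal usage sketch of the two running modes wired up above. The asset name "yamnet.tflite", the `context`, and the `samples` array are placeholders for illustration, not part of this change:

// Audio clips mode (the default): classify() returns results synchronously.
AudioClassifier clipClassifier = AudioClassifier.createFromFile(context, "yamnet.tflite");
AudioData clip = AudioData.create(
    AudioData.AudioDataFormat.builder().setNumOfChannels(1).setSampleRate(16000f).build(),
    /* sampleCounts= */ 15600);
clip.load(samples); // float[] of PCM samples in [-1, 1], interleaved if multi-channel.
AudioClassifierResult clipResult = clipClassifier.classify(clip);

// Audio stream mode: results arrive through the mandatory result listener.
AudioClassifier streamClassifier =
    AudioClassifier.createFromOptions(
        context,
        AudioClassifier.AudioClassifierOptions.builder()
            .setBaseOptions(BaseOptions.builder().setModelAssetPath("yamnet.tflite").build())
            .setRunningMode(RunningMode.AUDIO_STREAM)
            .setResultListener(result -> { /* consume each AudioClassifierResult */ })
            .build());
streamClassifier.classifyAsync(clip, /* timestampMs= */ android.os.SystemClock.uptimeMillis());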

mediapipe/tasks/java/com/google/mediapipe/tasks/audio/audioclassifier/AudioClassifierResult.java

@@ -0,0 +1,76 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.audioclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/** Represents the classification results generated by {@link AudioClassifier}. */
@AutoValue
public abstract class AudioClassifierResult implements TaskResult {
/**
* Creates an {@link AudioClassifierResult} instance from a list of {@link
* ClassificationsProto.ClassificationResult} protobuf messages.
*
* @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf messages
* to convert.
* @param timestampMs a timestamp for this result.
*/
static AudioClassifierResult createFromProtoList(
List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
List<ClassificationResult> classificationResultList = new ArrayList<>();
for (ClassificationsProto.ClassificationResult proto : protoList) {
classificationResultList.add(ClassificationResult.createFromProto(proto));
}
return new AutoValue_AudioClassifierResult(
Optional.of(classificationResultList), Optional.empty(), timestampMs);
}
/**
* Creates an {@link AudioClassifierResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
* @param timestampMs a timestamp for this result.
*/
static AudioClassifierResult createFromProto(
ClassificationsProto.ClassificationResult proto, long timestampMs) {
return new AutoValue_AudioClassifierResult(
Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
}
/**
* A list of timestamped {@link ClassificationResult} objects, each containing one set of results
* per classifier head. The list represents the audio classification result of an audio clip, and
* is only available when running with the audio clips mode.
*/
public abstract Optional<List<ClassificationResult>> classificationResultList();
/**
* Contains one set of results per classifier head. A {@link ClassificationResult} usually
* represents one audio classification result in an audio stream, and is only available when
* running with the audio stream mode.
*/
public abstract Optional<ClassificationResult> classificationResult();
@Override
public abstract long timestampMs();
}
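
Given the accessors above, a caller can distinguish the two variants like this (a hedged sketch; the `result` variable is hypothetical):

// Clips mode: one ClassificationResult per classified chunk, each stamped with the
// chunk's start time.
if (result.classificationResultList().isPresent()) {
  for (ClassificationResult chunk : result.classificationResultList().get()) {
    // Inspect the per-chunk classifications here.
  }
}
// Stream mode: a single ClassificationResult; result.timestampMs() carries the output time.
result.classificationResult().ifPresent(r -> { /* handle the streaming result */ });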

mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java

@@ -0,0 +1,151 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.core;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
/** The base class of MediaPipe audio tasks. */
public class BaseAudioTaskApi implements AutoCloseable {
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;
private final TaskRunner runner;
private final RunningMode runningMode;
private final String audioStreamName;
private final String sampleRateStreamName;
private double defaultSampleRate;
static {
System.loadLibrary("mediapipe_tasks_audio_jni");
}
/**
* Constructor to initialize a {@link BaseAudioTaskApi}.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe audio task {@link RunningMode}.
* @param audioStreamName the name of the input audio stream.
* @param sampleRateStreamName the name of the audio sample rate stream.
*/
public BaseAudioTaskApi(
TaskRunner runner,
RunningMode runningMode,
String audioStreamName,
String sampleRateStreamName) {
this.runner = runner;
this.runningMode = runningMode;
this.audioStreamName = audioStreamName;
this.sampleRateStreamName = sampleRateStreamName;
this.defaultSampleRate = -1.0;
}
/**
* A synchronous method to process audio clips. The call blocks the current thread until a failure
* status or a successful result is returned.
*
* @param audioClip a MediaPipe {@link AudioData} object for processing.
* @throws MediaPipeException if the task is not in the audio clips mode.
*/
protected TaskResult processAudioClip(AudioData audioClip) {
if (runningMode != RunningMode.AUDIO_CLIPS) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the audio clips mode. Current running mode:"
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(
audioStreamName,
runner
.getPacketCreator()
.createMatrix(
audioClip.getFormat().getNumOfChannels(),
audioClip.getBufferLength(),
audioClip.getBuffer()));
inputPackets.put(
sampleRateStreamName,
runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
return runner.process(inputPackets);
}
/**
* Checks or sets the audio sample rate in the audio stream mode.
*
* @param sampleRate the audio sample rate.
* @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
* rate is inconsistent with the previously received one.
*/
protected void checkOrSetSampleRate(double sampleRate) {
if (runningMode != RunningMode.AUDIO_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the audio stream mode. Current running mode:"
+ runningMode.name());
}
if (defaultSampleRate > 0) {
if (Double.compare(sampleRate, defaultSampleRate) != 0) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
"The input audio sample rate: "
+ sampleRate
+ " is inconsistent with the previously provided: "
+ defaultSampleRate);
}
} else {
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
runner.send(inputPackets, PRESTREAM_TIMESTAMP);
defaultSampleRate = sampleRate;
}
}
/**
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param audioClip a MediaPipe {@link AudioData} object for processing.
* @param timestampMs the corresponding timestamp of the input audio block in milliseconds.
* @throws MediaPipeException if the task is not in the stream mode.
*/
protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
if (runningMode != RunningMode.AUDIO_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the audio stream mode. Current running mode:"
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(
audioStreamName,
runner
.getPacketCreator()
.createMatrix(
audioClip.getFormat().getNumOfChannels(),
audioClip.getBufferLength(),
audioClip.getBuffer()));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/** Closes and cleans up the MediaPipe audio task. */
@Override
public void close() {
runner.close();
}
}
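
Concrete tasks extend this base class, as AudioClassifier does above. A minimal hypothetical subclass, to illustrate the contract (class and method names are examples only):

// Sketch of a task built on BaseAudioTaskApi; stream names mirror the ones used here.
final class MyAudioTask extends BaseAudioTaskApi {
  MyAudioTask(TaskRunner runner, RunningMode mode) {
    super(runner, mode, "audio_in", "sample_rate_in");
  }

  TaskResult run(AudioData clip) {
    return processAudioClip(clip); // Valid only in AUDIO_CLIPS mode.
  }

  void runAsync(AudioData block, long timestampMs) {
    checkOrSetSampleRate(block.getFormat().getSampleRate()); // AUDIO_STREAM mode only.
    sendAudioStreamData(block, timestampMs);
  }
}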

mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/RunningMode.java

@@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.core;
/**
* MediaPipe audio task running mode. A MediaPipe audio task can be run with two different modes:
*
* <ul>
* <li>AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips.
* <li>AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from
* a microphone.
* </ul>
*/
public enum RunningMode {
AUDIO_CLIPS,
AUDIO_STREAM,
}

mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/AudioData.java

@@ -0,0 +1,334 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import static java.lang.System.arraycopy;
import android.media.AudioFormat;
import android.media.AudioRecord;
import com.google.auto.value.AutoValue;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
/**
* Defines a ring buffer and some utility functions to prepare the input audio samples.
*
* <p>It maintains a <a href="https://en.wikipedia.org/wiki/Circular_buffer">Ring Buffer</a> to hold
* input audio data. Clients could feed input audio data via the `load` methods and access the
* aggregated audio samples via the `getBuffer` method.
*
* <p>Note that this class can only handle input audio in Float (in {@link
* android.media.AudioFormat#ENCODING_PCM_FLOAT}) or Short (in {@link
* android.media.AudioFormat#ENCODING_PCM_16BIT}). Internally it converts and stores all the audio
* samples in PCM Float encoding.
*
* <p>Typical usage in Kotlin
*
* <pre>
* val audioData = AudioData.create(format, modelInputLength)
* audioData.load(newData)
* </pre>
*
* <p>Another sample usage with {@link android.media.AudioRecord}
*
* <pre>
* val audioData = AudioData.create(format, modelInputLength)
* Timer().scheduleAtFixedRate(delay, period) {
* audioData.load(audioRecord)
* }
* </pre>
*/
public class AudioData {
private static final String TAG = AudioData.class.getSimpleName();
private final FloatRingBuffer buffer;
private final AudioDataFormat format;
/**
* Creates an {@link AudioData} instance with a ring buffer whose size is {@code
* sampleCounts} * {@code format.getNumOfChannels()}.
*
* @param format the expected {@link AudioDataFormat} of audio data loaded into this class.
* @param sampleCounts the number of samples.
*/
public static AudioData create(AudioDataFormat format, int sampleCounts) {
return new AudioData(format, sampleCounts);
}
/**
* Creates an {@link AudioData} instance with a ring buffer whose size is {@code sampleCounts} *
* {@code format.getChannelCount()}.
*
* @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
* the number of channels and sample rate.
* @param sampleCounts the number of samples to be fed into the model
*/
public static AudioData create(AudioFormat format, int sampleCounts) {
return new AudioData(AudioDataFormat.create(format), sampleCounts);
}
/**
* Wraps a few constants describing the format of the incoming audio samples, namely number of
* channels and the sample rate. By default, the number of channels is set to 1.
*/
@AutoValue
public abstract static class AudioDataFormat {
private static final int DEFAULT_NUM_OF_CHANNELS = 1;
/** Creates an {@link AudioDataFormat} instance from an Android {@link AudioFormat}. */
public static AudioDataFormat create(AudioFormat format) {
return AudioDataFormat.builder()
.setNumOfChannels(format.getChannelCount())
.setSampleRate(format.getSampleRate())
.build();
}
public abstract int getNumOfChannels();
public abstract float getSampleRate();
public static Builder builder() {
return new AutoValue_AudioData_AudioDataFormat.Builder()
.setNumOfChannels(DEFAULT_NUM_OF_CHANNELS);
}
/** Builder for {@link AudioDataFormat} */
@AutoValue.Builder
public abstract static class Builder {
/* By default, it's set to have 1 channel. */
public abstract Builder setNumOfChannels(int value);
public abstract Builder setSampleRate(float value);
abstract AudioDataFormat autoBuild();
public AudioDataFormat build() {
AudioDataFormat format = autoBuild();
if (format.getNumOfChannels() <= 0) {
throw new IllegalArgumentException("Number of channels should be greater than 0");
}
if (format.getSampleRate() <= 0) {
throw new IllegalArgumentException("Sample rate should be greater than 0");
}
return format;
}
}
}
/**
* Stores the input audio samples {@code src} in the ring buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
* multi-channel input, the array is interleaved.
*/
public void load(float[] src) {
load(src, 0, src.length);
}
/**
* Stores the input audio samples {@code src} in the ring buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
* multi-channel input, the array is interleaved.
* @param offsetInFloat starting position in the {@code src} array
* @param sizeInFloat the number of float values to be copied
* @throws IllegalArgumentException for incompatible audio format or incorrect input size
*/
public void load(float[] src, int offsetInFloat, int sizeInFloat) {
if (sizeInFloat % format.getNumOfChannels() != 0) {
throw new IllegalArgumentException(
String.format(
"Size (%d) needs to be a multiplier of the number of channels (%d)",
sizeInFloat, format.getNumOfChannels()));
}
buffer.load(src, offsetInFloat, sizeInFloat);
}
/**
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores them in the ring
* buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
* multi-channel input, the array is interleaved.
*/
public void load(short[] src) {
load(src, 0, src.length);
}
/**
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores them in the ring
* buffer.
*
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
* multi-channel input, the array is interleaved.
* @param offsetInShort starting position in the src array
* @param sizeInShort the number of short values to be copied
* @throws IllegalArgumentException if the source array can't be copied
*/
public void load(short[] src, int offsetInShort, int sizeInShort) {
if (offsetInShort + sizeInShort > src.length) {
throw new IllegalArgumentException(
String.format(
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
offsetInShort, sizeInShort, src.length));
}
float[] floatData = new float[sizeInShort];
for (int i = 0; i < sizeInShort; i++) {
// Convert the data to PCM Float encoding i.e. values between -1 and 1
floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
}
load(floatData);
}
/**
* Loads the latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
* ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT are supported.
*
* @param record an instance of {@link android.media.AudioRecord}
* @return the number of captured audio values, at most {@code channelCount * sampleCount}. If
* there was no new data in the AudioRecord or an error occurred, this method will return 0.
* @throws IllegalArgumentException for unsupported audio encoding format
* @throws IllegalStateException if reading from AudioRecord failed
*/
public int load(AudioRecord record) {
if (!this.format.equals(AudioDataFormat.create(record.getFormat()))) {
throw new IllegalArgumentException("Incompatible audio format.");
}
int loadedValues = 0;
if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
if (loadedValues > 0) {
load(newData, 0, loadedValues);
return loadedValues;
}
} else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
if (loadedValues > 0) {
load(newData, 0, loadedValues);
return loadedValues;
}
} else {
throw new IllegalArgumentException(
"Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
}
switch (loadedValues) {
case AudioRecord.ERROR_INVALID_OPERATION:
throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
case AudioRecord.ERROR_BAD_VALUE:
throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
case AudioRecord.ERROR_DEAD_OBJECT:
throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
case AudioRecord.ERROR:
throw new IllegalStateException("AudioRecord.ERROR");
default:
return 0;
}
}
/**
* Returns a float array holding all the available audio samples in {@link
* android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
*/
public float[] getBuffer() {
float[] bufferData = new float[buffer.getCapacity()];
ByteBuffer byteBuffer = buffer.getBuffer();
byteBuffer.asFloatBuffer().get(bufferData);
return bufferData;
}
/** Returns the {@link AudioDataFormat} associated with the audio data. */
public AudioDataFormat getFormat() {
return format;
}
/** Returns the audio buffer length, in number of samples per channel. */
public int getBufferLength() {
return buffer.getCapacity() / format.getNumOfChannels();
}
private AudioData(AudioDataFormat format, int sampleCounts) {
this.format = format;
this.buffer = new FloatRingBuffer(sampleCounts * format.getNumOfChannels());
}
/** Actual implementation of the ring buffer. */
private static class FloatRingBuffer {
private final float[] buffer;
private int nextIndex = 0;
public FloatRingBuffer(int flatSize) {
buffer = new float[flatSize];
}
/**
* Loads a slice of the float array to the ring buffer. If the float array is longer than the
* ring buffer's capacity, samples with lower indices in the array will be ignored.
*/
public void load(float[] newData, int offset, int size) {
if (offset + size > newData.length) {
throw new IllegalArgumentException(
String.format(
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
offset, size, newData.length));
}
// If buffer can't hold all the data, only keep the most recent data of size buffer.length
if (size > buffer.length) {
offset += (size - buffer.length);
size = buffer.length;
}
if (nextIndex + size < buffer.length) {
// No need to wrap nextIndex, just copy newData[offset:offset + size]
// to buffer[nextIndex:nextIndex+size]
arraycopy(newData, offset, buffer, nextIndex, size);
} else {
// Need to wrap nextIndex, perform copy in two chunks.
int firstChunkSize = buffer.length - nextIndex;
// First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
// Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
}
nextIndex = (nextIndex + size) % buffer.length;
}
public ByteBuffer getBuffer() {
// Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which
// can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so
// generally we don't create direct buffer for every invocation.
ByteBuffer byteBuffer = ByteBuffer.allocate(Float.SIZE / 8 * buffer.length);
byteBuffer.order(ByteOrder.nativeOrder());
FloatBuffer result = byteBuffer.asFloatBuffer();
result.put(buffer, nextIndex, buffer.length - nextIndex);
result.put(buffer, 0, nextIndex);
byteBuffer.rewind();
return byteBuffer;
}
public int getCapacity() {
return buffer.length;
}
}
}
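
A Java counterpart to the Kotlin snippet in the class javadoc, showing the AudioRecord path (a sketch; the sample rate and buffer sizing are illustrative, and permission handling is omitted):

AudioFormat audioFormat =
    new AudioFormat.Builder()
        .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
        .setSampleRate(16000)
        .setChannelMask(AudioFormat.CHANNEL_IN_MONO)
        .build();
AudioRecord record = new AudioRecord.Builder().setAudioFormat(audioFormat).build();
AudioData audioData = AudioData.create(audioFormat, /* sampleCounts= */ 15600);
record.startRecording();
int loaded = audioData.load(record);    // Non-blocking read into the ring buffer.
float[] window = audioData.getBuffer(); // Latest samples in [-1, 1], oldest first.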

mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD

@@ -16,6 +16,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
android_library(
name = "audio_data",
srcs = ["AudioData.java"],
deps = [
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "category",
srcs = ["Category.java"],

mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java

@@ -31,11 +31,22 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
InputT convertToTaskInput(List<Packet> packets);
}
/**
* Interface for the customizable MediaPipe task result listener that can retrieve both task result
* objects and the corresponding input data.
*/
public interface ResultListener<OutputT extends TaskResult, InputT> {
void run(OutputT result, InputT input);
}
/**
* Interface for the customizable MediaPipe task result listener that can only retrieve task result
* objects.
*/
public interface PureResultListener<OutputT extends TaskResult> {
void run(OutputT result);
}
private static final String TAG = "OutputHandler";
// A task-specific graph output packet converter that should be implemented per task.
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
@@ -45,6 +56,8 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
protected ErrorListener errorListener;
// The cached task result for non latency sensitive use cases.
protected OutputT cachedTaskResult;
// The latest output timestamp.
protected long latestOutputTimestamp = -1;
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
private boolean handleTimestampBoundChanges = false;
@@ -98,6 +111,11 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
return taskResult;
}
/* Returns the latest output timestamp. */
public long getLatestOutputTimestamp() {
return latestOutputTimestamp;
}
/**
* Handles a list of output {@link Packet}s. Invoked when a packet list becomes available.
*
@@ -109,6 +127,7 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
taskResult = outputPacketConverter.convertToTaskResult(packets);
if (resultListener == null) {
cachedTaskResult = taskResult;
latestOutputTimestamp = packets.get(0).getTimestamp();
} else {
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
resultListener.run(taskResult, taskInput);
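
Both listener flavors are single-method interfaces, so callers can supply lambdas; a hypothetical sketch:

// ResultListener also receives the converted input; PureResultListener gets the result only.
OutputHandler.PureResultListener<AudioClassifierResult> pure =
    result -> { /* result only */ };
OutputHandler.ResultListener<AudioClassifierResult, Void> full =
    (result, input) -> { /* result plus converted input */ };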

mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java

@@ -93,6 +93,7 @@ public class TaskRunner implements AutoCloseable {
public synchronized TaskResult process(Map<String, Packet> inputs) {
addPackets(inputs, generateSyntheticTimestamp());
graph.waitUntilGraphIdle();
lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
return outputHandler.retrieveCachedTaskResult();
}