Implement MediaPipe Tasks AudioClassifier Java API.
PiperOrigin-RevId: 487570148
This commit is contained in:
parent
2ce3a9719e
commit
da4d455d0c
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto";
|
||||||
|
option java_outer_classname = "AudioClassifierGraphOptionsProto";
|
||||||
|
|
||||||
message AudioClassifierGraphOptions {
|
message AudioClassifierGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional AudioClassifierGraphOptions ext = 451755788;
|
optional AudioClassifierGraphOptions ext = 451755788;
|
||||||
|
|
76
mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
Normal file
76
mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "core",
|
||||||
|
srcs = glob(["core/*.java"]),
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":libmediapipe_tasks_audio_jni_lib",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# The native library of all MediaPipe audio tasks.
|
||||||
|
cc_binary(
|
||||||
|
name = "libmediapipe_tasks_audio_jni.so",
|
||||||
|
linkshared = 1,
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
|
||||||
|
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "libmediapipe_tasks_audio_jni_lib",
|
||||||
|
srcs = [":libmediapipe_tasks_audio_jni.so"],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "audioclassifier",
|
||||||
|
srcs = [
|
||||||
|
"audioclassifier/AudioClassifier.java",
|
||||||
|
"audioclassifier/AudioClassifierResult.java",
|
||||||
|
],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
manifest = "audioclassifier/AndroidManifest.xml",
|
||||||
|
deps = [
|
||||||
|
":core",
|
||||||
|
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||||
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
|
||||||
|
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
package="com.google.mediapipe.tasks.audio.audioclassifier">
|
||||||
|
|
||||||
|
<uses-sdk android:minSdkVersion="24"
|
||||||
|
android:targetSdkVersion="30" />
|
||||||
|
|
||||||
|
</manifest>
|
|
@ -0,0 +1,399 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||||
|
|
||||||
|
import android.content.Context;
|
||||||
|
import android.os.ParcelFileDescriptor;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
import com.google.mediapipe.framework.PacketGetter;
|
||||||
|
import com.google.mediapipe.framework.ProtoUtil;
|
||||||
|
import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
|
||||||
|
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
|
||||||
|
import com.google.mediapipe.tasks.audio.core.RunningMode;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
|
||||||
|
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
|
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performs audio classification on audio clips or audio stream.
|
||||||
|
*
|
||||||
|
* <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
|
||||||
|
* mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
|
||||||
|
* items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
|
||||||
|
*
|
||||||
|
* <p>Input tensor: (kTfLiteFloat32)
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>input audio buffer of size `[batch * samples]`.
|
||||||
|
* <li>batch inference is not supported (`batch` is required to be 1).
|
||||||
|
* <li>for multi-channel models, the channels need be interleaved.
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* <p>At least one output tensor with: (kTfLiteFloat32)
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>`[1 x N]` array with `N` represents the number of categories.
|
||||||
|
* <li>optional (but recommended) label items as AssociatedFiles with type TENSOR_AXIS_LABELS,
|
||||||
|
* containing one label per line. The first such AssociatedFile (if any) is used to fill the
|
||||||
|
* `category_name` field of the results. The `display_name` field is filled from the
|
||||||
|
* AssociatedFile (if any) whose locale matches the `display_names_locale` field of the
|
||||||
|
* `AudioClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
|
||||||
|
* these are available, only the `index` field of the results will be filled.
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public final class AudioClassifier extends BaseAudioTaskApi {
|
||||||
|
private static final String TAG = AudioClassifier.class.getSimpleName();
|
||||||
|
private static final String AUDIO_IN_STREAM_NAME = "audio_in";
|
||||||
|
private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
|
||||||
|
private static final List<String> INPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(
|
||||||
|
Arrays.asList(
|
||||||
|
"AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
|
||||||
|
private static final List<String> OUTPUT_STREAMS =
|
||||||
|
Collections.unmodifiableList(
|
||||||
|
Arrays.asList(
|
||||||
|
"CLASSIFICATIONS:classifications_out",
|
||||||
|
"TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
|
||||||
|
private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
|
||||||
|
private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
|
||||||
|
private static final String TASK_GRAPH_NAME =
|
||||||
|
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
|
||||||
|
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
|
||||||
|
|
||||||
|
static {
|
||||||
|
ProtoUtil.registerTypeName(
|
||||||
|
ClassificationsProto.ClassificationResult.class,
|
||||||
|
"mediapipe.tasks.components.containers.proto.ClassificationResult");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from a model file and default {@link
|
||||||
|
* AudioClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelPath path to the classification model in the assets.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromFile(Context context, String modelPath) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from a model file and default {@link
|
||||||
|
* AudioClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelFile the classification model {@link File} instance.
|
||||||
|
* @throws IOException if an I/O error occurs when opening the tflite model file.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
|
||||||
|
try (ParcelFileDescriptor descriptor =
|
||||||
|
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
|
||||||
|
BaseOptions baseOptions =
|
||||||
|
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from a model buffer and default {@link
|
||||||
|
* AudioClassifierOptions}.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
|
||||||
|
* classification model.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
|
||||||
|
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
|
||||||
|
return createFromOptions(
|
||||||
|
context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
|
||||||
|
*
|
||||||
|
* @param context an Android {@link Context}.
|
||||||
|
* @param options an {@link AudioClassifierOptions} instance.
|
||||||
|
* @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
|
||||||
|
*/
|
||||||
|
public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
|
||||||
|
OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
|
||||||
|
handler.setOutputPacketConverter(
|
||||||
|
new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
|
||||||
|
@Override
|
||||||
|
public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
|
||||||
|
try {
|
||||||
|
if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
|
// For audio stream mode.
|
||||||
|
return AudioClassifierResult.createFromProto(
|
||||||
|
PacketGetter.getProto(
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||||
|
ClassificationsProto.ClassificationResult.getDefaultInstance()),
|
||||||
|
packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
|
||||||
|
/ MICROSECONDS_PER_MILLISECOND);
|
||||||
|
} else {
|
||||||
|
// For audio clips mode.
|
||||||
|
return AudioClassifierResult.createFromProtoList(
|
||||||
|
PacketGetter.getProtoVector(
|
||||||
|
packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
|
||||||
|
ClassificationsProto.ClassificationResult.parser()),
|
||||||
|
-1);
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Void convertToTaskInput(List<Packet> packets) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if (options.resultListener().isPresent()) {
|
||||||
|
ResultListener<AudioClassifierResult, Void> resultListener =
|
||||||
|
new ResultListener<>() {
|
||||||
|
@Override
|
||||||
|
public void run(AudioClassifierResult audioClassifierResult, Void input) {
|
||||||
|
options.resultListener().get().run(audioClassifierResult);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
handler.setResultListener(resultListener);
|
||||||
|
}
|
||||||
|
options.errorListener().ifPresent(handler::setErrorListener);
|
||||||
|
// Audio tasks should not drop input audio due to flow limiting, which may cause data
|
||||||
|
// inconsistency.
|
||||||
|
TaskRunner runner =
|
||||||
|
TaskRunner.create(
|
||||||
|
context,
|
||||||
|
TaskInfo.<AudioClassifierOptions>builder()
|
||||||
|
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||||
|
.setInputStreams(INPUT_STREAMS)
|
||||||
|
.setOutputStreams(OUTPUT_STREAMS)
|
||||||
|
.setTaskOptions(options)
|
||||||
|
.setEnableFlowLimiting(false)
|
||||||
|
.build(),
|
||||||
|
handler);
|
||||||
|
return new AudioClassifier(runner, options.runningMode());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
|
||||||
|
* RunningMode}.
|
||||||
|
*
|
||||||
|
* @param taskRunner a {@link TaskRunner}.
|
||||||
|
* @param runningMode a mediapipe audio task {@link RunningMode}.
|
||||||
|
*/
|
||||||
|
private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
|
||||||
|
super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Performs audio classification on the provided audio clip. Only use this method when the
|
||||||
|
* AudioClassifier is created with the audio clips mode.
|
||||||
|
*
|
||||||
|
* <p>The audio clip is represented as a MediaPipe {@link AudioData} object The method accepts
|
||||||
|
* audio clips with various length and audio sample rate. It's required to provide the
|
||||||
|
* corresponding audio sample rate within the {@link AudioData} object.
|
||||||
|
*
|
||||||
|
* <p>The input audio clip may be longer than what the model is able to process in a single
|
||||||
|
* inference. When this occurs, the input audio clip is split into multiple chunks starting at
|
||||||
|
* different timestamps. For this reason, this function returns a vector of ClassificationResult
|
||||||
|
* objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
|
||||||
|
* chunk data that was classified, e.g:
|
||||||
|
*
|
||||||
|
* ClassificationResult #0 (first chunk of data):
|
||||||
|
* timestamp_ms: 0 (starts at 0ms)
|
||||||
|
* classifications #0 (single head model):
|
||||||
|
* category #0:
|
||||||
|
* category_name: "Speech"
|
||||||
|
* score: 0.6
|
||||||
|
* category #1:
|
||||||
|
* category_name: "Music"
|
||||||
|
* score: 0.2
|
||||||
|
* ClassificationResult #1 (second chunk of data):
|
||||||
|
* timestamp_ms: 800 (starts at 800ms)
|
||||||
|
* classifications #0 (single head model):
|
||||||
|
* category #0:
|
||||||
|
* category_name: "Speech"
|
||||||
|
* score: 0.5
|
||||||
|
* category #1:
|
||||||
|
* category_name: "Silence"
|
||||||
|
* score: 0.1
|
||||||
|
*
|
||||||
|
* @param audioClip a MediaPipe {@link AudioData} object for processing.
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public AudioClassifierResult classify(AudioData audioClip) {
|
||||||
|
return (AudioClassifierResult) processAudioClip(audioClip);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
|
||||||
|
* use this method when the AudioClassifier is created with the audio stream mode.
|
||||||
|
*
|
||||||
|
* <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
|
||||||
|
* be resampled, accumulated, and framed to the proper size for the underlying model to consume.
|
||||||
|
* It's required to provide the corresponding audio sample rate within {@link AudioData} object as
|
||||||
|
* well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
|
||||||
|
* timestamps must be monotonically increasing. This method will return immediately after
|
||||||
|
* the input audio data is accepted. The results will be available in the `resultListener`
|
||||||
|
* provided in the `AudioClassifierOptions`. The `classifyAsync` method is designed to process
|
||||||
|
* auido stream data such as microphone input.
|
||||||
|
*
|
||||||
|
* <p>The input audio block may be longer than what the model is able to process in a single
|
||||||
|
* inference. When this occurs, the input audio block is split into multiple chunks. For this
|
||||||
|
* reason, the callback may be called multiple times (once per chunk) for each call to this
|
||||||
|
* function.
|
||||||
|
*
|
||||||
|
* @param audioBlock a MediaPipe {@link AudioData} object for processing.
|
||||||
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
* @throws MediaPipeException if there is an internal error.
|
||||||
|
*/
|
||||||
|
public void classifyAsync(AudioData audioBlock, long timestampMs) {
|
||||||
|
checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
|
||||||
|
sendAudioStreamData(audioBlock, timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Options for setting up and {@link AudioClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract static class AudioClassifierOptions extends TaskOptions {
|
||||||
|
|
||||||
|
/** Builder for {@link AudioClassifierOptions}. */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/** Sets the {@link BaseOptions} for the audio classifier task. */
|
||||||
|
public abstract Builder setBaseOptions(BaseOptions baseOptions);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link RunningMode} for the audio classifier task. Default to the audio clips
|
||||||
|
* mode. Image classifier has two modes:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
|
||||||
|
* audio clips to the `classify` method, and will receive the classification results as
|
||||||
|
* the return value.
|
||||||
|
* <li>AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
|
||||||
|
* from microphone. Users call `classifyAsync` to push the audio data into the
|
||||||
|
* AudioClassifier, the classification results will be available in the result callback
|
||||||
|
* when the audio classifier finishes the work.
|
||||||
|
* </ul>
|
||||||
|
*/
|
||||||
|
public abstract Builder setRunningMode(RunningMode runningMode);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional {@link ClassifierOptions} controling classification behavior, such as
|
||||||
|
* score threshold, number of results, etc.
|
||||||
|
*/
|
||||||
|
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link ResultListener} to receive the classification results asynchronously when
|
||||||
|
* the audio classifier is in the audio stream mode.
|
||||||
|
*/
|
||||||
|
public abstract Builder setResultListener(
|
||||||
|
PureResultListener<AudioClassifierResult> resultListener);
|
||||||
|
|
||||||
|
/** Sets an optional {@link ErrorListener}. */
|
||||||
|
public abstract Builder setErrorListener(ErrorListener errorListener);
|
||||||
|
|
||||||
|
abstract AudioClassifierOptions autoBuild();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates and builds the {@link AudioClassifierOptions} instance.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||||
|
* properly configured. The result listener should only be set when the audio classifier
|
||||||
|
* is in the audio stream mode.
|
||||||
|
*/
|
||||||
|
public final AudioClassifierOptions build() {
|
||||||
|
AudioClassifierOptions options = autoBuild();
|
||||||
|
if (options.runningMode() == RunningMode.AUDIO_STREAM) {
|
||||||
|
if (!options.resultListener().isPresent()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"The audio classifier is in the audio stream mode, a user-defined result listener"
|
||||||
|
+ " must be provided in the AudioClassifierOptions.");
|
||||||
|
}
|
||||||
|
} else if (options.resultListener().isPresent()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"The audio classifier is in the audio clips mode, a user-defined result listener"
|
||||||
|
+ " shouldn't be provided in AudioClassifierOptions.");
|
||||||
|
}
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract BaseOptions baseOptions();
|
||||||
|
|
||||||
|
abstract RunningMode runningMode();
|
||||||
|
|
||||||
|
abstract Optional<ClassifierOptions> classifierOptions();
|
||||||
|
|
||||||
|
abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();
|
||||||
|
|
||||||
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
|
||||||
|
.setRunningMode(RunningMode.AUDIO_CLIPS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||||
|
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
|
||||||
|
BaseOptionsProto.BaseOptions.newBuilder();
|
||||||
|
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
|
||||||
|
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
|
||||||
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
|
||||||
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
|
||||||
|
.setBaseOptions(baseOptionsBuilder);
|
||||||
|
if (classifierOptions().isPresent()) {
|
||||||
|
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
|
||||||
|
}
|
||||||
|
return CalculatorOptions.newBuilder()
|
||||||
|
.setExtension(
|
||||||
|
AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
|
||||||
|
taskOptionsBuilder.build())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/** Represents the classification results generated by {@link AudioClassifier}. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class AudioClassifierResult implements TaskResult {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifierResult} instance from a list of {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf messages.
|
||||||
|
*
|
||||||
|
* @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf message
|
||||||
|
* to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static AudioClassifierResult createFromProtoList(
|
||||||
|
List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
|
||||||
|
List<ClassificationResult> classificationResultList = new ArrayList<>();
|
||||||
|
for (ClassificationsProto.ClassificationResult proto : protoList) {
|
||||||
|
classificationResultList.add(ClassificationResult.createFromProto(proto));
|
||||||
|
}
|
||||||
|
return new AutoValue_AudioClassifierResult(
|
||||||
|
Optional.of(classificationResultList), Optional.empty(), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an {@link AudioClassifierResult} instance from a {@link
|
||||||
|
* ClassificationsProto.ClassificationResult} protobuf message.
|
||||||
|
*
|
||||||
|
* @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
|
||||||
|
* @param timestampMs a timestamp for this result.
|
||||||
|
*/
|
||||||
|
static AudioClassifierResult createFromProto(
|
||||||
|
ClassificationsProto.ClassificationResult proto, long timestampMs) {
|
||||||
|
return new AutoValue_AudioClassifierResult(
|
||||||
|
Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A list of of timpstamed {@link ClassificationResult} objects, each contains one set of results
|
||||||
|
* per classifier head. The list represents the audio classification result of an audio clip, and
|
||||||
|
* is only available when running with the audio clips mode.
|
||||||
|
*/
|
||||||
|
public abstract Optional<List<ClassificationResult>> classificationResultList();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contains one set of results per classifier head. A {@link ClassificationResult} usually
|
||||||
|
* represents one audio classification result in an audio stream, and s only available when
|
||||||
|
* running with the audio stream mode.
|
||||||
|
*/
|
||||||
|
public abstract Optional<ClassificationResult> classificationResult();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract long timestampMs();
|
||||||
|
}
|
|
@ -0,0 +1,151 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.audio.core;
|
||||||
|
|
||||||
|
import com.google.mediapipe.framework.MediaPipeException;
|
||||||
|
import com.google.mediapipe.framework.Packet;
|
||||||
|
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskResult;
|
||||||
|
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/** The base class of MediaPipe audio tasks. */
|
||||||
|
public class BaseAudioTaskApi implements AutoCloseable {
|
||||||
|
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
|
||||||
|
private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;
|
||||||
|
|
||||||
|
private final TaskRunner runner;
|
||||||
|
private final RunningMode runningMode;
|
||||||
|
private final String audioStreamName;
|
||||||
|
private final String sampleRateStreamName;
|
||||||
|
private double defaultSampleRate;
|
||||||
|
|
||||||
|
static {
|
||||||
|
System.loadLibrary("mediapipe_tasks_audio_jni");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor to initialize a {@link BaseAudioTaskApi}.
|
||||||
|
*
|
||||||
|
* @param runner a {@link TaskRunner}.
|
||||||
|
* @param runningMode a mediapipe audio task {@link RunningMode}.
|
||||||
|
* @param audioStreamName the name of the input audio stream.
|
||||||
|
* @param sampleRateStreamName the name of the audio sample rate stream.
|
||||||
|
*/
|
||||||
|
public BaseAudioTaskApi(
|
||||||
|
TaskRunner runner,
|
||||||
|
RunningMode runningMode,
|
||||||
|
String audioStreamName,
|
||||||
|
String sampleRateStreamName) {
|
||||||
|
this.runner = runner;
|
||||||
|
this.runningMode = runningMode;
|
||||||
|
this.audioStreamName = audioStreamName;
|
||||||
|
this.sampleRateStreamName = sampleRateStreamName;
|
||||||
|
this.defaultSampleRate = -1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A synchronous method to process audio clips. The call blocks the current thread until a failure
|
||||||
|
* status or a successful result is returned.
|
||||||
|
*
|
||||||
|
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
|
||||||
|
* @throws MediaPipeException if the task is not in the audio clips mode.
|
||||||
|
*/
|
||||||
|
protected TaskResult processAudioClip(AudioData audioClip) {
|
||||||
|
if (runningMode != RunningMode.AUDIO_CLIPS) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
|
"Task is not initialized with the audio clips mode. Current running mode:"
|
||||||
|
+ runningMode.name());
|
||||||
|
}
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(
|
||||||
|
audioStreamName,
|
||||||
|
runner
|
||||||
|
.getPacketCreator()
|
||||||
|
.createMatrix(
|
||||||
|
audioClip.getFormat().getNumOfChannels(),
|
||||||
|
audioClip.getBufferLength(),
|
||||||
|
audioClip.getBuffer()));
|
||||||
|
inputPackets.put(
|
||||||
|
sampleRateStreamName,
|
||||||
|
runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
|
||||||
|
return runner.process(inputPackets);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks or sets the audio sample rate in the audio stream mode.
|
||||||
|
*
|
||||||
|
* @param sampleRate the audio sample rate.
|
||||||
|
* @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
|
||||||
|
* rate is inconsisent with the previously recevied.
|
||||||
|
*/
|
||||||
|
protected void checkOrSetSampleRate(double sampleRate) {
|
||||||
|
if (runningMode != RunningMode.AUDIO_STREAM) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
|
"Task is not initialized with the audio stream mode. Current running mode:"
|
||||||
|
+ runningMode.name());
|
||||||
|
}
|
||||||
|
if (defaultSampleRate > 0) {
|
||||||
|
if (Double.compare(sampleRate, defaultSampleRate) != 0) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
|
||||||
|
"The input audio sample rate: "
|
||||||
|
+ sampleRate
|
||||||
|
+ " is inconsistent with the previously provided: "
|
||||||
|
+ defaultSampleRate);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
|
||||||
|
runner.send(inputPackets, PRESTREAM_TIMESTAMP);
|
||||||
|
defaultSampleRate = sampleRate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
|
||||||
|
* available in the user-defined result listener.
|
||||||
|
*
|
||||||
|
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
|
||||||
|
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||||
|
* @throws MediaPipeException if the task is not in the stream mode.
|
||||||
|
*/
|
||||||
|
protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
|
||||||
|
if (runningMode != RunningMode.AUDIO_STREAM) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||||
|
"Task is not initialized with the audio stream mode. Current running mode:"
|
||||||
|
+ runningMode.name());
|
||||||
|
}
|
||||||
|
Map<String, Packet> inputPackets = new HashMap<>();
|
||||||
|
inputPackets.put(
|
||||||
|
audioStreamName,
|
||||||
|
runner
|
||||||
|
.getPacketCreator()
|
||||||
|
.createMatrix(
|
||||||
|
audioClip.getFormat().getNumOfChannels(),
|
||||||
|
audioClip.getBufferLength(),
|
||||||
|
audioClip.getBuffer()));
|
||||||
|
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Closes and cleans up the MediaPipe audio task. */
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
runner.close();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.audio.core;
|
||||||
|
|
||||||
|
/**
 * MediaPipe audio task running mode. A MediaPipe audio task can be run with two different modes:
 *
 * <ul>
 *   <li>AUDIO_CLIPS: The mode for running a mediapipe audio task on independent audio clips.
 *   <li>AUDIO_STREAM: The mode for running a mediapipe audio task on an audio stream, such as from
 *       microphone.
 * </ul>
 */
public enum RunningMode {
  // Processes independent, pre-recorded audio clips.
  AUDIO_CLIPS,
  // Processes a continuous audio stream, e.g. audio captured from a microphone.
  AUDIO_STREAM,
}
|
|
@ -0,0 +1,334 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
|
import static java.lang.System.arraycopy;
|
||||||
|
|
||||||
|
import android.media.AudioFormat;
|
||||||
|
import android.media.AudioRecord;
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.nio.FloatBuffer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defines a ring buffer and some utility functions to prepare the input audio samples.
|
||||||
|
*
|
||||||
|
* <p>It maintains a <a href="https://en.wikipedia.org/wiki/Circular_buffer">Ring Buffer</a> to hold
|
||||||
|
* input audio data. Clients could feed input audio data via `load` methods and access the
|
||||||
|
* aggregated audio samples via `getTensorBuffer` method.
|
||||||
|
*
|
||||||
|
* <p>Note that this class can only handle input audio in Float (in {@link
|
||||||
|
* android.media.AudioFormat#ENCODING_PCM_16BIT}) or Short (in {@link
|
||||||
|
* android.media.AudioFormat#ENCODING_PCM_FLOAT}). Internally it converts and stores all the audio
|
||||||
|
* samples in PCM Float encoding.
|
||||||
|
*
|
||||||
|
* <p>Typical usage in Kotlin
|
||||||
|
*
|
||||||
|
* <pre>
|
||||||
|
* val audioData = AudioData.create(format, modelInputLength)
|
||||||
|
* audioData.load(newData)
|
||||||
|
* </pre>
|
||||||
|
*
|
||||||
|
* <p>Another sample usage with {@link android.media.AudioRecord}
|
||||||
|
*
|
||||||
|
* <pre>
|
||||||
|
* val audioData = AudioData.create(format, modelInputLength)
|
||||||
|
* Timer().scheduleAtFixedRate(delay, period) {
|
||||||
|
* audioData.load(audioRecord)
|
||||||
|
* }
|
||||||
|
* </pre>
|
||||||
|
*/
|
||||||
|
public class AudioData {
|
||||||
|
|
||||||
|
private static final String TAG = AudioData.class.getSimpleName();
|
||||||
|
private final FloatRingBuffer buffer;
|
||||||
|
private final AudioDataFormat format;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
|
||||||
|
* sampleCounts} * {@code format.getNumOfChannels()}.
|
||||||
|
*
|
||||||
|
* @param format the expected {@link AudioDataFormat} of audio data loaded into this class.
|
||||||
|
* @param sampleCounts the number of samples.
|
||||||
|
*/
|
||||||
|
public static AudioData create(AudioDataFormat format, int sampleCounts) {
|
||||||
|
return new AudioData(format, sampleCounts);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link AudioData} instance with a ring buffer whose size is {@code sampleCounts} *
|
||||||
|
* {@code format.getChannelCount()}.
|
||||||
|
*
|
||||||
|
* @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
|
||||||
|
* the number of channels and sample rate.
|
||||||
|
* @param sampleCounts the number of samples to be fed into the model
|
||||||
|
*/
|
||||||
|
public static AudioData create(AudioFormat format, int sampleCounts) {
|
||||||
|
return new AudioData(AudioDataFormat.create(format), sampleCounts);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps a few constants describing the format of the incoming audio samples, namely number of
|
||||||
|
* channels and the sample rate. By default, num of channels is set to 1.
|
||||||
|
*/
|
||||||
|
@AutoValue
|
||||||
|
public abstract static class AudioDataFormat {
|
||||||
|
private static final int DEFAULT_NUM_OF_CHANNELS = 1;
|
||||||
|
|
||||||
|
/** Creates a {@link AudioFormat} instance from Android AudioFormat class. */
|
||||||
|
public static AudioDataFormat create(AudioFormat format) {
|
||||||
|
return AudioDataFormat.builder()
|
||||||
|
.setNumOfChannels(format.getChannelCount())
|
||||||
|
.setSampleRate(format.getSampleRate())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract int getNumOfChannels();
|
||||||
|
|
||||||
|
public abstract float getSampleRate();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_AudioData_AudioDataFormat.Builder()
|
||||||
|
.setNumOfChannels(DEFAULT_NUM_OF_CHANNELS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Builder for {@link AudioDataFormat} */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
|
||||||
|
/* By default, it's set to have 1 channel. */
|
||||||
|
public abstract Builder setNumOfChannels(int value);
|
||||||
|
|
||||||
|
public abstract Builder setSampleRate(float value);
|
||||||
|
|
||||||
|
abstract AudioDataFormat autoBuild();
|
||||||
|
|
||||||
|
public AudioDataFormat build() {
|
||||||
|
AudioDataFormat format = autoBuild();
|
||||||
|
if (format.getNumOfChannels() <= 0) {
|
||||||
|
throw new IllegalArgumentException("Number of channels should be greater than 0");
|
||||||
|
}
|
||||||
|
if (format.getSampleRate() <= 0) {
|
||||||
|
throw new IllegalArgumentException("Sample rate should be greater than 0");
|
||||||
|
}
|
||||||
|
return format;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stores the input audio samples {@code src} in the ring buffer.
|
||||||
|
*
|
||||||
|
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
|
||||||
|
* multi-channel input, the array is interleaved.
|
||||||
|
*/
|
||||||
|
public void load(float[] src) {
|
||||||
|
load(src, 0, src.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stores the input audio samples {@code src} in the ring buffer.
|
||||||
|
*
|
||||||
|
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
|
||||||
|
* multi-channel input, the array is interleaved.
|
||||||
|
* @param offsetInFloat starting position in the {@code src} array
|
||||||
|
* @param sizeInFloat the number of float values to be copied
|
||||||
|
* @throws IllegalArgumentException for incompatible audio format or incorrect input size
|
||||||
|
*/
|
||||||
|
public void load(float[] src, int offsetInFloat, int sizeInFloat) {
|
||||||
|
if (sizeInFloat % format.getNumOfChannels() != 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Size (%d) needs to be a multiplier of the number of channels (%d)",
|
||||||
|
sizeInFloat, format.getNumOfChannels()));
|
||||||
|
}
|
||||||
|
buffer.load(src, offsetInFloat, sizeInFloat);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
|
||||||
|
* buffer.
|
||||||
|
*
|
||||||
|
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
|
||||||
|
* multi-channel input, the array is interleaved.
|
||||||
|
*/
|
||||||
|
public void load(short[] src) {
|
||||||
|
load(src, 0, src.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
|
||||||
|
* buffer.
|
||||||
|
*
|
||||||
|
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
|
||||||
|
* multi-channel input, the array is interleaved.
|
||||||
|
* @param offsetInShort starting position in the src array
|
||||||
|
* @param sizeInShort the number of short values to be copied
|
||||||
|
* @throws IllegalArgumentException if the source array can't be copied
|
||||||
|
*/
|
||||||
|
public void load(short[] src, int offsetInShort, int sizeInShort) {
|
||||||
|
if (offsetInShort + sizeInShort > src.length) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
|
||||||
|
offsetInShort, sizeInShort, src.length));
|
||||||
|
}
|
||||||
|
float[] floatData = new float[sizeInShort];
|
||||||
|
for (int i = 0; i < sizeInShort; i++) {
|
||||||
|
// Convert the data to PCM Float encoding i.e. values between -1 and 1
|
||||||
|
floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
|
||||||
|
}
|
||||||
|
load(floatData);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
|
||||||
|
* supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
|
||||||
|
*
|
||||||
|
* @param record an instance of {@link android.media.AudioRecord}
|
||||||
|
* @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
|
||||||
|
* there was no new data in the AudioRecord or an error occurred, this method will return 0.
|
||||||
|
* @throws IllegalArgumentException for unsupported audio encoding format
|
||||||
|
* @throws IllegalStateException if reading from AudioRecord failed
|
||||||
|
*/
|
||||||
|
public int load(AudioRecord record) {
|
||||||
|
if (!this.format.equals(AudioDataFormat.create(record.getFormat()))) {
|
||||||
|
throw new IllegalArgumentException("Incompatible audio format.");
|
||||||
|
}
|
||||||
|
int loadedValues = 0;
|
||||||
|
if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
|
||||||
|
float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
|
||||||
|
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
|
||||||
|
if (loadedValues > 0) {
|
||||||
|
load(newData, 0, loadedValues);
|
||||||
|
return loadedValues;
|
||||||
|
}
|
||||||
|
} else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
|
||||||
|
short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
|
||||||
|
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
|
||||||
|
if (loadedValues > 0) {
|
||||||
|
load(newData, 0, loadedValues);
|
||||||
|
return loadedValues;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (loadedValues) {
|
||||||
|
case AudioRecord.ERROR_INVALID_OPERATION:
|
||||||
|
throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
|
||||||
|
|
||||||
|
case AudioRecord.ERROR_BAD_VALUE:
|
||||||
|
throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
|
||||||
|
|
||||||
|
case AudioRecord.ERROR_DEAD_OBJECT:
|
||||||
|
throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
|
||||||
|
|
||||||
|
case AudioRecord.ERROR:
|
||||||
|
throw new IllegalStateException("AudioRecord.ERROR");
|
||||||
|
|
||||||
|
default:
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a float array holding all the available audio samples in {@link
|
||||||
|
* android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
|
||||||
|
*/
|
||||||
|
public float[] getBuffer() {
|
||||||
|
float[] bufferData = new float[buffer.getCapacity()];
|
||||||
|
ByteBuffer byteBuffer = buffer.getBuffer();
|
||||||
|
byteBuffer.asFloatBuffer().get(bufferData);
|
||||||
|
return bufferData;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Returns the {@link AudioDataFormat} associated with the tensor. */
|
||||||
|
public AudioDataFormat getFormat() {
|
||||||
|
return format;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Returns the audio buffer length. */
|
||||||
|
public int getBufferLength() {
|
||||||
|
return buffer.getCapacity() / format.getNumOfChannels();
|
||||||
|
}
|
||||||
|
|
||||||
|
private AudioData(AudioDataFormat format, int sampleCounts) {
|
||||||
|
this.format = format;
|
||||||
|
this.buffer = new FloatRingBuffer(sampleCounts * format.getNumOfChannels());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Actual implementation of the ring buffer. */
|
||||||
|
private static class FloatRingBuffer {
|
||||||
|
|
||||||
|
private final float[] buffer;
|
||||||
|
private int nextIndex = 0;
|
||||||
|
|
||||||
|
public FloatRingBuffer(int flatSize) {
|
||||||
|
buffer = new float[flatSize];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads a slice of the float array to the ring buffer. If the float array is longer than ring
|
||||||
|
* buffer's capacity, samples with lower indices in the array will be ignored.
|
||||||
|
*/
|
||||||
|
public void load(float[] newData, int offset, int size) {
|
||||||
|
if (offset + size > newData.length) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format(
|
||||||
|
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
|
||||||
|
offset, size, newData.length));
|
||||||
|
}
|
||||||
|
// If buffer can't hold all the data, only keep the most recent data of size buffer.length
|
||||||
|
if (size > buffer.length) {
|
||||||
|
offset += (size - buffer.length);
|
||||||
|
size = buffer.length;
|
||||||
|
}
|
||||||
|
if (nextIndex + size < buffer.length) {
|
||||||
|
// No need to wrap nextIndex, just copy newData[offset:offset + size]
|
||||||
|
// to buffer[nextIndex:nextIndex+size]
|
||||||
|
arraycopy(newData, offset, buffer, nextIndex, size);
|
||||||
|
} else {
|
||||||
|
// Need to wrap nextIndex, perform copy in two chunks.
|
||||||
|
int firstChunkSize = buffer.length - nextIndex;
|
||||||
|
// First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
|
||||||
|
arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
|
||||||
|
// Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
|
||||||
|
arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
nextIndex = (nextIndex + size) % buffer.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ByteBuffer getBuffer() {
|
||||||
|
// Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which
|
||||||
|
// can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so
|
||||||
|
// generally we don't create direct buffer for every invocation.
|
||||||
|
ByteBuffer byteBuffer = ByteBuffer.allocate(Float.SIZE / 8 * buffer.length);
|
||||||
|
byteBuffer.order(ByteOrder.nativeOrder());
|
||||||
|
FloatBuffer result = byteBuffer.asFloatBuffer();
|
||||||
|
result.put(buffer, nextIndex, buffer.length - nextIndex);
|
||||||
|
result.put(buffer, 0, nextIndex);
|
||||||
|
byteBuffer.rewind();
|
||||||
|
return byteBuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getCapacity() {
|
||||||
|
return buffer.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,6 +16,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Container for input audio samples, backed by a float ring buffer; shared by
# all MediaPipe audio tasks.
android_library(
    name = "audio_data",
    srcs = ["AudioData.java"],
    deps = [
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
|
||||||
|
|
||||||
android_library(
|
android_library(
|
||||||
name = "category",
|
name = "category",
|
||||||
srcs = ["Category.java"],
|
srcs = ["Category.java"],
|
||||||
|
|
|
@ -31,11 +31,22 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
||||||
InputT convertToTaskInput(List<Packet> packets);
|
InputT convertToTaskInput(List<Packet> packets);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Interface for the customizable MediaPipe task result listener. */
|
/**
|
||||||
|
* Interface for the customizable MediaPipe task result listener that can retrieve both task result
|
||||||
|
* objects and the corresponding input data.
|
||||||
|
*/
|
||||||
public interface ResultListener<OutputT extends TaskResult, InputT> {
|
public interface ResultListener<OutputT extends TaskResult, InputT> {
|
||||||
void run(OutputT result, InputT input);
|
void run(OutputT result, InputT input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface for the customizable MediaPipe task result listener that can only retrieve task result
|
||||||
|
* objects.
|
||||||
|
*/
|
||||||
|
public interface PureResultListener<OutputT extends TaskResult> {
|
||||||
|
void run(OutputT result);
|
||||||
|
}
|
||||||
|
|
||||||
private static final String TAG = "OutputHandler";
|
private static final String TAG = "OutputHandler";
|
||||||
// A task-specific graph output packet converter that should be implemented per task.
|
// A task-specific graph output packet converter that should be implemented per task.
|
||||||
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
|
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
|
||||||
|
@ -45,6 +56,8 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
||||||
protected ErrorListener errorListener;
|
protected ErrorListener errorListener;
|
||||||
// The cached task result for non latency sensitive use cases.
|
// The cached task result for non latency sensitive use cases.
|
||||||
protected OutputT cachedTaskResult;
|
protected OutputT cachedTaskResult;
|
||||||
|
// The latest output timestamp.
|
||||||
|
protected long latestOutputTimestamp = -1;
|
||||||
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
|
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
|
||||||
private boolean handleTimestampBoundChanges = false;
|
private boolean handleTimestampBoundChanges = false;
|
||||||
|
|
||||||
|
@ -98,6 +111,11 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
||||||
return taskResult;
|
return taskResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Returns the latest output timestamp. */
|
||||||
|
public long getLatestOutputTimestamp() {
|
||||||
|
return latestOutputTimestamp;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handles a list of output {@link Packet}s. Invoked when a packet list become available.
|
* Handles a list of output {@link Packet}s. Invoked when a packet list become available.
|
||||||
*
|
*
|
||||||
|
@ -109,6 +127,7 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
||||||
taskResult = outputPacketConverter.convertToTaskResult(packets);
|
taskResult = outputPacketConverter.convertToTaskResult(packets);
|
||||||
if (resultListener == null) {
|
if (resultListener == null) {
|
||||||
cachedTaskResult = taskResult;
|
cachedTaskResult = taskResult;
|
||||||
|
latestOutputTimestamp = packets.get(0).getTimestamp();
|
||||||
} else {
|
} else {
|
||||||
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
|
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
|
||||||
resultListener.run(taskResult, taskInput);
|
resultListener.run(taskResult, taskInput);
|
||||||
|
|
|
@ -93,6 +93,7 @@ public class TaskRunner implements AutoCloseable {
|
||||||
public synchronized TaskResult process(Map<String, Packet> inputs) {
|
public synchronized TaskResult process(Map<String, Packet> inputs) {
|
||||||
addPackets(inputs, generateSyntheticTimestamp());
|
addPackets(inputs, generateSyntheticTimestamp());
|
||||||
graph.waitUntilGraphIdle();
|
graph.waitUntilGraphIdle();
|
||||||
|
lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
|
||||||
return outputHandler.retrieveCachedTaskResult();
|
return outputHandler.retrieveCachedTaskResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user