Implement MediaPipe Tasks AudioClassifier Java API.
PiperOrigin-RevId: 487570148
This commit is contained in:
parent
2ce3a9719e
commit
da4d455d0c
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.audio.audioclassifier.proto";
|
||||
option java_outer_classname = "AudioClassifierGraphOptionsProto";
|
||||
|
||||
message AudioClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional AudioClassifierGraphOptions ext = 451755788;
|
||||
|
|
76
mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
Normal file
76
mediapipe/tasks/java/com/google/mediapipe/tasks/audio/BUILD
Normal file
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

package(default_visibility = ["//mediapipe/tasks:internal"])

# Base classes shared by all MediaPipe audio tasks (BaseAudioTaskApi,
# RunningMode, ...). Bundles the audio JNI library so dependents get the
# native code automatically.
android_library(
    name = "core",
    srcs = glob(["core/*.java"]),
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    deps = [
        ":libmediapipe_tasks_audio_jni_lib",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "@maven//:com_google_guava_guava",
    ],
)

# The native library of all MediaPipe audio tasks.
cc_binary(
    name = "libmediapipe_tasks_audio_jni.so",
    linkshared = 1,
    linkstatic = 1,
    deps = [
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
        "//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
    ],
)

# Wraps the .so so it can be listed as a regular dependency of Java targets.
cc_library(
    name = "libmediapipe_tasks_audio_jni_lib",
    srcs = [":libmediapipe_tasks_audio_jni.so"],
    alwayslink = 1,
)

# The public AudioClassifier task API.
android_library(
    name = "audioclassifier",
    srcs = [
        "audioclassifier/AudioClassifier.java",
        "audioclassifier/AudioClassifierResult.java",
    ],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = "audioclassifier/AndroidManifest.xml",
    deps = [
        ":core",
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audio_data",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
<!-- Manifest for the MediaPipe Tasks audio classifier Android library. -->
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.audio.audioclassifier">

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

</manifest>
|
|
@ -0,0 +1,399 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.ParcelFileDescriptor;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.ProtoUtil;
|
||||
import com.google.mediapipe.tasks.audio.audioclassifier.proto.AudioClassifierGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
|
||||
import com.google.mediapipe.tasks.audio.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
 * Performs audio classification on audio clips or audio stream.
 *
 * <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
 * mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
 * items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
 *
 * <p>Input tensor: (kTfLiteFloat32)
 *
 * <ul>
 *   <li>input audio buffer of size `[batch * samples]`.
 *   <li>batch inference is not supported (`batch` is required to be 1).
 *   <li>for multi-channel models, the channels need be interleaved.
 * </ul>
 *
 * <p>At least one output tensor with: (kTfLiteFloat32)
 *
 * <ul>
 *   <li>`[1 x N]` array with `N` represents the number of categories.
 *   <li>optional (but recommended) label items as AssociatedFiles with type TENSOR_AXIS_LABELS,
 *       containing one label per line. The first such AssociatedFile (if any) is used to fill the
 *       `category_name` field of the results. The `display_name` field is filled from the
 *       AssociatedFile (if any) whose locale matches the `display_names_locale` field of the
 *       `AudioClassifierOptions` used at creation time ("en" by default, i.e. English). If none of
 *       these are available, only the `index` field of the results will be filled.
 * </ul>
 */
public final class AudioClassifier extends BaseAudioTaskApi {
  private static final String TAG = AudioClassifier.class.getSimpleName();
  private static final String AUDIO_IN_STREAM_NAME = "audio_in";
  private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
  // Input streams of the underlying AudioClassifierGraph.
  private static final List<String> INPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList(
              "AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
  // Output streams: index 0 carries per-block results (stream mode), index 1 carries the
  // timestamped result vector (clips mode). Indices must match the constants below.
  private static final List<String> OUTPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList(
              "CLASSIFICATIONS:classifications_out",
              "TIMESTAMPED_CLASSIFICATIONS:timestamped_classifications_out"));
  private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
  private static final int TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX = 1;
  private static final String TASK_GRAPH_NAME =
      "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
  // Packet timestamps are in microseconds; results are reported in milliseconds.
  private static final long MICROSECONDS_PER_MILLISECOND = 1000;

  static {
    // Register the proto type name so PacketGetter can deserialize the output packets.
    ProtoUtil.registerTypeName(
        ClassificationsProto.ClassificationResult.class,
        "mediapipe.tasks.components.containers.proto.ClassificationResult");
  }

  /**
   * Creates an {@link AudioClassifier} instance from a model file and default {@link
   * AudioClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelPath path to the classification model in the assets.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromFile(Context context, String modelPath) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
    return createFromOptions(
        context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates an {@link AudioClassifier} instance from a model file and default {@link
   * AudioClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelFile the classification model {@link File} instance.
   * @throws IOException if an I/O error occurs when opening the tflite model file.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromFile(Context context, File modelFile) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      BaseOptions baseOptions =
          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
      return createFromOptions(
          context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
    }
  }

  /**
   * Creates an {@link AudioClassifier} instance from a model buffer and default {@link
   * AudioClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the
   *     classification model.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromBuffer(Context context, final ByteBuffer modelBuffer) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
    return createFromOptions(
        context, AudioClassifierOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates an {@link AudioClassifier} instance from an {@link AudioClassifierOptions} instance.
   *
   * @param context an Android {@link Context}.
   * @param options an {@link AudioClassifierOptions} instance.
   * @throws MediaPipeException if there is an error during {@link AudioClassifier} creation.
   */
  public static AudioClassifier createFromOptions(Context context, AudioClassifierOptions options) {
    OutputHandler<AudioClassifierResult, Void> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<AudioClassifierResult, Void>() {
          @Override
          public AudioClassifierResult convertToTaskResult(List<Packet> packets) {
            try {
              if (!packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).isEmpty()) {
                // For audio stream mode.
                return AudioClassifierResult.createFromProto(
                    PacketGetter.getProto(
                        packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
                        ClassificationsProto.ClassificationResult.getDefaultInstance()),
                    packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp()
                        / MICROSECONDS_PER_MILLISECOND);
              } else {
                // For audio clips mode.
                return AudioClassifierResult.createFromProtoList(
                    PacketGetter.getProtoVector(
                        packets.get(TIMESTAMPED_CLASSIFICATIONS_OUT_STREAM_INDEX),
                        ClassificationsProto.ClassificationResult.parser()),
                    -1);
              }
            } catch (IOException e) {
              throw new MediaPipeException(
                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
            }
          }

          @Override
          public Void convertToTaskInput(List<Packet> packets) {
            // Audio tasks do not feed the input back to the result listener.
            return null;
          }
        });
    if (options.resultListener().isPresent()) {
      // Adapt the user's single-argument listener to the two-argument handler interface.
      ResultListener<AudioClassifierResult, Void> resultListener =
          new ResultListener<>() {
            @Override
            public void run(AudioClassifierResult audioClassifierResult, Void input) {
              options.resultListener().get().run(audioClassifierResult);
            }
          };
      handler.setResultListener(resultListener);
    }
    options.errorListener().ifPresent(handler::setErrorListener);
    // Audio tasks should not drop input audio due to flow limiting, which may cause data
    // inconsistency.
    TaskRunner runner =
        TaskRunner.create(
            context,
            TaskInfo.<AudioClassifierOptions>builder()
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(INPUT_STREAMS)
                .setOutputStreams(OUTPUT_STREAMS)
                .setTaskOptions(options)
                .setEnableFlowLimiting(false)
                .build(),
            handler);
    return new AudioClassifier(runner, options.runningMode());
  }

  /**
   * Constructor to initialize an {@link AudioClassifier} from a {@link TaskRunner} and {@link
   * RunningMode}.
   *
   * @param taskRunner a {@link TaskRunner}.
   * @param runningMode a mediapipe audio task {@link RunningMode}.
   */
  private AudioClassifier(TaskRunner taskRunner, RunningMode runningMode) {
    super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
  }

  /**
   * Performs audio classification on the provided audio clip. Only use this method when the
   * AudioClassifier is created with the audio clips mode.
   *
   * <p>The audio clip is represented as a MediaPipe {@link AudioData} object. The method accepts
   * audio clips with various length and audio sample rate. It's required to provide the
   * corresponding audio sample rate within the {@link AudioData} object.
   *
   * <p>The input audio clip may be longer than what the model is able to process in a single
   * inference. When this occurs, the input audio clip is split into multiple chunks starting at
   * different timestamps. For this reason, this function returns a vector of ClassificationResult
   * objects, each associated with a timestamp corresponding to the start (in milliseconds) of the
   * chunk data that was classified, e.g:
   *
   * <pre>
   * ClassificationResult #0 (first chunk of data):
   *   timestamp_ms: 0 (starts at 0ms)
   *   classifications #0 (single head model):
   *     category #0:
   *       category_name: "Speech"
   *       score: 0.6
   *     category #1:
   *       category_name: "Music"
   *       score: 0.2
   * ClassificationResult #1 (second chunk of data):
   *   timestamp_ms: 800 (starts at 800ms)
   *   classifications #0 (single head model):
   *     category #0:
   *       category_name: "Speech"
   *       score: 0.5
   *     category #1:
   *       category_name: "Silence"
   *       score: 0.1
   * </pre>
   *
   * @param audioClip a MediaPipe {@link AudioData} object for processing.
   * @throws MediaPipeException if there is an internal error.
   */
  public AudioClassifierResult classify(AudioData audioClip) {
    return (AudioClassifierResult) processAudioClip(audioClip);
  }

  /**
   * Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
   * use this method when the AudioClassifier is created with the audio stream mode.
   *
   * <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
   * be resampled, accumulated, and framed to the proper size for the underlying model to consume.
   * It's required to provide the corresponding audio sample rate within {@link AudioData} object as
   * well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
   * timestamps must be monotonically increasing. This method will return immediately after
   * the input audio data is accepted. The results will be available in the `resultListener`
   * provided in the `AudioClassifierOptions`. The `classifyAsync` method is designed to process
   * audio stream data such as microphone input.
   *
   * <p>The input audio block may be longer than what the model is able to process in a single
   * inference. When this occurs, the input audio block is split into multiple chunks. For this
   * reason, the callback may be called multiple times (once per chunk) for each call to this
   * function.
   *
   * @param audioBlock a MediaPipe {@link AudioData} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void classifyAsync(AudioData audioBlock, long timestampMs) {
    checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
    sendAudioStreamData(audioBlock, timestampMs);
  }

  /** Options for setting up an {@link AudioClassifier}. */
  @AutoValue
  public abstract static class AudioClassifierOptions extends TaskOptions {

    /** Builder for {@link AudioClassifierOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the {@link BaseOptions} for the audio classifier task. */
      public abstract Builder setBaseOptions(BaseOptions baseOptions);

      /**
       * Sets the {@link RunningMode} for the audio classifier task. Default to the audio clips
       * mode. Audio classifier has two modes:
       *
       * <ul>
       *   <li>AUDIO_CLIPS: The mode for running audio classification on audio clips. Users feed
       *       audio clips to the `classify` method, and will receive the classification results as
       *       the return value.
       *   <li>AUDIO_STREAM: The mode for running audio classification on the audio stream, such as
       *       from microphone. Users call `classifyAsync` to push the audio data into the
       *       AudioClassifier, the classification results will be available in the result callback
       *       when the audio classifier finishes the work.
       * </ul>
       */
      public abstract Builder setRunningMode(RunningMode runningMode);

      /**
       * Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
       * score threshold, number of results, etc.
       */
      public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);

      /**
       * Sets the {@link ResultListener} to receive the classification results asynchronously when
       * the audio classifier is in the audio stream mode.
       */
      public abstract Builder setResultListener(
          PureResultListener<AudioClassifierResult> resultListener);

      /** Sets an optional {@link ErrorListener}. */
      public abstract Builder setErrorListener(ErrorListener errorListener);

      abstract AudioClassifierOptions autoBuild();

      /**
       * Validates and builds the {@link AudioClassifierOptions} instance.
       *
       * @throws IllegalArgumentException if the result listener and the running mode are not
       *     properly configured. The result listener should only be set when the audio classifier
       *     is in the audio stream mode.
       */
      public final AudioClassifierOptions build() {
        AudioClassifierOptions options = autoBuild();
        if (options.runningMode() == RunningMode.AUDIO_STREAM) {
          if (!options.resultListener().isPresent()) {
            throw new IllegalArgumentException(
                "The audio classifier is in the audio stream mode, a user-defined result listener"
                    + " must be provided in the AudioClassifierOptions.");
          }
        } else if (options.resultListener().isPresent()) {
          throw new IllegalArgumentException(
              "The audio classifier is in the audio clips mode, a user-defined result listener"
                  + " shouldn't be provided in AudioClassifierOptions.");
        }
        return options;
      }
    }

    abstract BaseOptions baseOptions();

    abstract RunningMode runningMode();

    abstract Optional<ClassifierOptions> classifierOptions();

    abstract Optional<PureResultListener<AudioClassifierResult>> resultListener();

    abstract Optional<ErrorListener> errorListener();

    public static Builder builder() {
      return new AutoValue_AudioClassifier_AudioClassifierOptions.Builder()
          .setRunningMode(RunningMode.AUDIO_CLIPS);
    }

    /**
     * Converts a {@link AudioClassifierOptions} to a {@link CalculatorOptions} protobuf message.
     */
    @Override
    public CalculatorOptions convertToCalculatorOptionsProto() {
      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
          BaseOptionsProto.BaseOptions.newBuilder();
      baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
      AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.Builder taskOptionsBuilder =
          AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.newBuilder()
              .setBaseOptions(baseOptionsBuilder);
      if (classifierOptions().isPresent()) {
        taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
      }
      return CalculatorOptions.newBuilder()
          .setExtension(
              AudioClassifierGraphOptionsProto.AudioClassifierGraphOptions.ext,
              taskOptionsBuilder.build())
          .build();
    }
  }
}
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.audio.audioclassifier;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.containers.ClassificationResult;
|
||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/** Represents the classification results generated by {@link AudioClassifier}. */
@AutoValue
public abstract class AudioClassifierResult implements TaskResult {

  /**
   * Creates an {@link AudioClassifierResult} instance from a list of {@link
   * ClassificationsProto.ClassificationResult} protobuf messages.
   *
   * @param protoList a list of {@link ClassificationsProto.ClassificationResult} protobuf messages
   *     to convert.
   * @param timestampMs a timestamp for this result.
   */
  static AudioClassifierResult createFromProtoList(
      List<ClassificationsProto.ClassificationResult> protoList, long timestampMs) {
    List<ClassificationResult> classificationResultList = new ArrayList<>();
    for (ClassificationsProto.ClassificationResult proto : protoList) {
      classificationResultList.add(ClassificationResult.createFromProto(proto));
    }
    // Clips-mode result: only the list field is populated.
    return new AutoValue_AudioClassifierResult(
        Optional.of(classificationResultList), Optional.empty(), timestampMs);
  }

  /**
   * Creates an {@link AudioClassifierResult} instance from a {@link
   * ClassificationsProto.ClassificationResult} protobuf message.
   *
   * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
   * @param timestampMs a timestamp for this result.
   */
  static AudioClassifierResult createFromProto(
      ClassificationsProto.ClassificationResult proto, long timestampMs) {
    // Stream-mode result: only the single-result field is populated.
    return new AutoValue_AudioClassifierResult(
        Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
  }

  /**
   * A list of timestamped {@link ClassificationResult} objects, each contains one set of results
   * per classifier head. The list represents the audio classification result of an audio clip, and
   * is only available when running with the audio clips mode.
   */
  public abstract Optional<List<ClassificationResult>> classificationResultList();

  /**
   * Contains one set of results per classifier head. A {@link ClassificationResult} usually
   * represents one audio classification result in an audio stream, and is only available when
   * running with the audio stream mode.
   */
  public abstract Optional<ClassificationResult> classificationResult();

  @Override
  public abstract long timestampMs();
}
|
|
@ -0,0 +1,151 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.audio.core;
|
||||
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.tasks.components.containers.AudioData;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/** The base class of MediaPipe audio tasks. */
public class BaseAudioTaskApi implements AutoCloseable {
  // Packet timestamps are in microseconds; the public API uses milliseconds.
  private static final long MICROSECONDS_PER_MILLISECOND = 1000;
  // NOTE(review): presumably matches MediaPipe's Timestamp::PreStream() sentinel value,
  // used to deliver the sample rate before any audio packets — confirm against the C++ side.
  private static final long PRESTREAM_TIMESTAMP = Long.MIN_VALUE + 2;

  private final TaskRunner runner;
  private final RunningMode runningMode;
  private final String audioStreamName;
  private final String sampleRateStreamName;
  // Sample rate pinned by the first audio block in stream mode; negative means not yet set.
  private double defaultSampleRate;

  static {
    // Load the shared JNI library backing all MediaPipe audio tasks.
    System.loadLibrary("mediapipe_tasks_audio_jni");
  }
|
||||
|
||||
  /**
   * Constructor to initialize a {@link BaseAudioTaskApi}.
   *
   * @param runner a {@link TaskRunner}.
   * @param runningMode a mediapipe audio task {@link RunningMode}.
   * @param audioStreamName the name of the input audio stream.
   * @param sampleRateStreamName the name of the audio sample rate stream.
   */
  public BaseAudioTaskApi(
      TaskRunner runner,
      RunningMode runningMode,
      String audioStreamName,
      String sampleRateStreamName) {
    this.runner = runner;
    this.runningMode = runningMode;
    this.audioStreamName = audioStreamName;
    this.sampleRateStreamName = sampleRateStreamName;
    // -1.0 marks the sample rate as unset until the first stream-mode audio block arrives.
    this.defaultSampleRate = -1.0;
  }
|
||||
|
||||
  /**
   * A synchronous method to process audio clips. The call blocks the current thread until a failure
   * status or a successful result is returned.
   *
   * @param audioClip a MediaPipe {@link AudioData} object for processing.
   * @return the {@link TaskResult} produced by the task graph for this clip.
   * @throws MediaPipeException if the task is not in the audio clips mode.
   */
  protected TaskResult processAudioClip(AudioData audioClip) {
    if (runningMode != RunningMode.AUDIO_CLIPS) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the audio clips mode. Current running mode:"
              + runningMode.name());
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    // The audio buffer is sent as a channels x samples matrix; interleaved channels are
    // expected by the underlying graph.
    inputPackets.put(
        audioStreamName,
        runner
            .getPacketCreator()
            .createMatrix(
                audioClip.getFormat().getNumOfChannels(),
                audioClip.getBufferLength(),
                audioClip.getBuffer()));
    inputPackets.put(
        sampleRateStreamName,
        runner.getPacketCreator().createFloat64(audioClip.getFormat().getSampleRate()));
    return runner.process(inputPackets);
  }
|
||||
|
||||
  /**
   * Checks or sets the audio sample rate in the audio stream mode.
   *
   * @param sampleRate the audio sample rate.
   * @throws MediaPipeException if the task is not in the audio stream mode or the provided sample
   *     rate is inconsistent with the previously received.
   */
  protected void checkOrSetSampleRate(double sampleRate) {
    if (runningMode != RunningMode.AUDIO_STREAM) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the audio stream mode. Current running mode:"
              + runningMode.name());
    }
    if (defaultSampleRate > 0) {
      // A sample rate was already established; every subsequent block must match it exactly.
      if (Double.compare(sampleRate, defaultSampleRate) != 0) {
        throw new MediaPipeException(
            MediaPipeException.StatusCode.INVALID_ARGUMENT.ordinal(),
            "The input audio sample rate: "
                + sampleRate
                + " is inconsistent with the previously provided: "
                + defaultSampleRate);
      }
    } else {
      // First block: deliver the sample rate ahead of any audio packets, then pin it.
      Map<String, Packet> inputPackets = new HashMap<>();
      inputPackets.put(sampleRateStreamName, runner.getPacketCreator().createFloat64(sampleRate));
      runner.send(inputPackets, PRESTREAM_TIMESTAMP);
      defaultSampleRate = sampleRate;
    }
  }
|
||||
/**
|
||||
* An asynchronous method to send audio stream data to the {@link TaskRunner}. The results will be
|
||||
* available in the user-defined result listener.
|
||||
*
|
||||
* @param audioClip a MediaPipe {@link AudioDatra} object for processing.
|
||||
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
|
||||
* @throws MediaPipeException if the task is not in the stream mode.
|
||||
*/
|
||||
protected void sendAudioStreamData(AudioData audioClip, long timestampMs) {
|
||||
if (runningMode != RunningMode.AUDIO_STREAM) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
|
||||
"Task is not initialized with the audio stream mode. Current running mode:"
|
||||
+ runningMode.name());
|
||||
}
|
||||
Map<String, Packet> inputPackets = new HashMap<>();
|
||||
inputPackets.put(
|
||||
audioStreamName,
|
||||
runner
|
||||
.getPacketCreator()
|
||||
.createMatrix(
|
||||
audioClip.getFormat().getNumOfChannels(),
|
||||
audioClip.getBufferLength(),
|
||||
audioClip.getBuffer()));
|
||||
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
|
||||
}
|
||||
|
||||
/** Closes and cleans up the MediaPipe audio task. */
|
||||
@Override
|
||||
public void close() {
|
||||
runner.close();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.audio.core;
|
||||
|
||||
/**
 * The running modes supported by a MediaPipe audio task.
 *
 * <ul>
 *   <li>{@code AUDIO_CLIPS}: runs a mediapipe audio task on independent audio clips.
 *   <li>{@code AUDIO_STREAM}: runs a mediapipe audio task on an audio stream, such as audio
 *       captured from a microphone.
 * </ul>
 */
public enum RunningMode {
  /** The mode for running the task on independent audio clips. */
  AUDIO_CLIPS,
  /** The mode for running the task on a continuous audio stream, e.g. from a microphone. */
  AUDIO_STREAM;
}
|
|
@ -0,0 +1,334 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import static java.lang.System.arraycopy;
|
||||
|
||||
import android.media.AudioFormat;
|
||||
import android.media.AudioRecord;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
|
||||
/**
|
||||
* Defines a ring buffer and some utility functions to prepare the input audio samples.
|
||||
*
|
||||
* <p>It maintains a <a href="https://en.wikipedia.org/wiki/Circular_buffer">Ring Buffer</a> to hold
|
||||
* input audio data. Clients could feed input audio data via `load` methods and access the
|
||||
* aggregated audio samples via `getTensorBuffer` method.
|
||||
*
|
||||
* <p>Note that this class can only handle input audio in Float (in {@link
|
||||
* android.media.AudioFormat#ENCODING_PCM_16BIT}) or Short (in {@link
|
||||
* android.media.AudioFormat#ENCODING_PCM_FLOAT}). Internally it converts and stores all the audio
|
||||
* samples in PCM Float encoding.
|
||||
*
|
||||
* <p>Typical usage in Kotlin
|
||||
*
|
||||
* <pre>
|
||||
* val audioData = AudioData.create(format, modelInputLength)
|
||||
* audioData.load(newData)
|
||||
* </pre>
|
||||
*
|
||||
* <p>Another sample usage with {@link android.media.AudioRecord}
|
||||
*
|
||||
* <pre>
|
||||
* val audioData = AudioData.create(format, modelInputLength)
|
||||
* Timer().scheduleAtFixedRate(delay, period) {
|
||||
* audioData.load(audioRecord)
|
||||
* }
|
||||
* </pre>
|
||||
*/
|
||||
public class AudioData {
|
||||
|
||||
private static final String TAG = AudioData.class.getSimpleName();
|
||||
private final FloatRingBuffer buffer;
|
||||
private final AudioDataFormat format;
|
||||
|
||||
/**
|
||||
* Creates a {@link android.media.AudioRecord} instance with a ring buffer whose size is {@code
|
||||
* sampleCounts} * {@code format.getNumOfChannels()}.
|
||||
*
|
||||
* @param format the expected {@link AudioDataFormat} of audio data loaded into this class.
|
||||
* @param sampleCounts the number of samples.
|
||||
*/
|
||||
public static AudioData create(AudioDataFormat format, int sampleCounts) {
|
||||
return new AudioData(format, sampleCounts);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link AudioData} instance with a ring buffer whose size is {@code sampleCounts} *
|
||||
* {@code format.getChannelCount()}.
|
||||
*
|
||||
* @param format the {@link android.media.AudioFormat} required by the TFLite model. It defines
|
||||
* the number of channels and sample rate.
|
||||
* @param sampleCounts the number of samples to be fed into the model
|
||||
*/
|
||||
public static AudioData create(AudioFormat format, int sampleCounts) {
|
||||
return new AudioData(AudioDataFormat.create(format), sampleCounts);
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps a few constants describing the format of the incoming audio samples, namely number of
|
||||
* channels and the sample rate. By default, num of channels is set to 1.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract static class AudioDataFormat {
|
||||
private static final int DEFAULT_NUM_OF_CHANNELS = 1;
|
||||
|
||||
/** Creates a {@link AudioFormat} instance from Android AudioFormat class. */
|
||||
public static AudioDataFormat create(AudioFormat format) {
|
||||
return AudioDataFormat.builder()
|
||||
.setNumOfChannels(format.getChannelCount())
|
||||
.setSampleRate(format.getSampleRate())
|
||||
.build();
|
||||
}
|
||||
|
||||
public abstract int getNumOfChannels();
|
||||
|
||||
public abstract float getSampleRate();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_AudioData_AudioDataFormat.Builder()
|
||||
.setNumOfChannels(DEFAULT_NUM_OF_CHANNELS);
|
||||
}
|
||||
|
||||
/** Builder for {@link AudioDataFormat} */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
|
||||
/* By default, it's set to have 1 channel. */
|
||||
public abstract Builder setNumOfChannels(int value);
|
||||
|
||||
public abstract Builder setSampleRate(float value);
|
||||
|
||||
abstract AudioDataFormat autoBuild();
|
||||
|
||||
public AudioDataFormat build() {
|
||||
AudioDataFormat format = autoBuild();
|
||||
if (format.getNumOfChannels() <= 0) {
|
||||
throw new IllegalArgumentException("Number of channels should be greater than 0");
|
||||
}
|
||||
if (format.getSampleRate() <= 0) {
|
||||
throw new IllegalArgumentException("Sample rate should be greater than 0");
|
||||
}
|
||||
return format;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores the input audio samples {@code src} in the ring buffer.
|
||||
*
|
||||
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
|
||||
* multi-channel input, the array is interleaved.
|
||||
*/
|
||||
public void load(float[] src) {
|
||||
load(src, 0, src.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores the input audio samples {@code src} in the ring buffer.
|
||||
*
|
||||
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_FLOAT}. For
|
||||
* multi-channel input, the array is interleaved.
|
||||
* @param offsetInFloat starting position in the {@code src} array
|
||||
* @param sizeInFloat the number of float values to be copied
|
||||
* @throws IllegalArgumentException for incompatible audio format or incorrect input size
|
||||
*/
|
||||
public void load(float[] src, int offsetInFloat, int sizeInFloat) {
|
||||
if (sizeInFloat % format.getNumOfChannels() != 0) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Size (%d) needs to be a multiplier of the number of channels (%d)",
|
||||
sizeInFloat, format.getNumOfChannels()));
|
||||
}
|
||||
buffer.load(src, offsetInFloat, sizeInFloat);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
|
||||
* buffer.
|
||||
*
|
||||
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
|
||||
* multi-channel input, the array is interleaved.
|
||||
*/
|
||||
public void load(short[] src) {
|
||||
load(src, 0, src.length);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts the input audio samples {@code src} to ENCODING_PCM_FLOAT, then stores it in the ring
|
||||
* buffer.
|
||||
*
|
||||
* @param src input audio samples in {@link android.media.AudioFormat#ENCODING_PCM_16BIT}. For
|
||||
* multi-channel input, the array is interleaved.
|
||||
* @param offsetInShort starting position in the src array
|
||||
* @param sizeInShort the number of short values to be copied
|
||||
* @throws IllegalArgumentException if the source array can't be copied
|
||||
*/
|
||||
public void load(short[] src, int offsetInShort, int sizeInShort) {
|
||||
if (offsetInShort + sizeInShort > src.length) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
|
||||
offsetInShort, sizeInShort, src.length));
|
||||
}
|
||||
float[] floatData = new float[sizeInShort];
|
||||
for (int i = 0; i < sizeInShort; i++) {
|
||||
// Convert the data to PCM Float encoding i.e. values between -1 and 1
|
||||
floatData[i] = src[i + offsetInShort] * 1.f / Short.MAX_VALUE;
|
||||
}
|
||||
load(floatData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads latest data from the {@link android.media.AudioRecord} in a non-blocking way. Only
|
||||
* supporting ENCODING_PCM_16BIT and ENCODING_PCM_FLOAT.
|
||||
*
|
||||
* @param record an instance of {@link android.media.AudioRecord}
|
||||
* @return number of captured audio values whose size is {@code channelCount * sampleCount}. If
|
||||
* there was no new data in the AudioRecord or an error occurred, this method will return 0.
|
||||
* @throws IllegalArgumentException for unsupported audio encoding format
|
||||
* @throws IllegalStateException if reading from AudioRecord failed
|
||||
*/
|
||||
public int load(AudioRecord record) {
|
||||
if (!this.format.equals(AudioDataFormat.create(record.getFormat()))) {
|
||||
throw new IllegalArgumentException("Incompatible audio format.");
|
||||
}
|
||||
int loadedValues = 0;
|
||||
if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_FLOAT) {
|
||||
float[] newData = new float[record.getChannelCount() * record.getBufferSizeInFrames()];
|
||||
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
|
||||
if (loadedValues > 0) {
|
||||
load(newData, 0, loadedValues);
|
||||
return loadedValues;
|
||||
}
|
||||
} else if (record.getAudioFormat() == AudioFormat.ENCODING_PCM_16BIT) {
|
||||
short[] newData = new short[record.getChannelCount() * record.getBufferSizeInFrames()];
|
||||
loadedValues = record.read(newData, 0, newData.length, AudioRecord.READ_NON_BLOCKING);
|
||||
if (loadedValues > 0) {
|
||||
load(newData, 0, loadedValues);
|
||||
return loadedValues;
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException(
|
||||
"Unsupported encoding. Requires ENCODING_PCM_16BIT or ENCODING_PCM_FLOAT.");
|
||||
}
|
||||
|
||||
switch (loadedValues) {
|
||||
case AudioRecord.ERROR_INVALID_OPERATION:
|
||||
throw new IllegalStateException("AudioRecord.ERROR_INVALID_OPERATION");
|
||||
|
||||
case AudioRecord.ERROR_BAD_VALUE:
|
||||
throw new IllegalStateException("AudioRecord.ERROR_BAD_VALUE");
|
||||
|
||||
case AudioRecord.ERROR_DEAD_OBJECT:
|
||||
throw new IllegalStateException("AudioRecord.ERROR_DEAD_OBJECT");
|
||||
|
||||
case AudioRecord.ERROR:
|
||||
throw new IllegalStateException("AudioRecord.ERROR");
|
||||
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a float array holding all the available audio samples in {@link
|
||||
* android.media.AudioFormat#ENCODING_PCM_FLOAT} i.e. values are in the range of [-1, 1].
|
||||
*/
|
||||
public float[] getBuffer() {
|
||||
float[] bufferData = new float[buffer.getCapacity()];
|
||||
ByteBuffer byteBuffer = buffer.getBuffer();
|
||||
byteBuffer.asFloatBuffer().get(bufferData);
|
||||
return bufferData;
|
||||
}
|
||||
|
||||
/* Returns the {@link AudioDataFormat} associated with the tensor. */
|
||||
public AudioDataFormat getFormat() {
|
||||
return format;
|
||||
}
|
||||
|
||||
/* Returns the audio buffer length. */
|
||||
public int getBufferLength() {
|
||||
return buffer.getCapacity() / format.getNumOfChannels();
|
||||
}
|
||||
|
||||
private AudioData(AudioDataFormat format, int sampleCounts) {
|
||||
this.format = format;
|
||||
this.buffer = new FloatRingBuffer(sampleCounts * format.getNumOfChannels());
|
||||
}
|
||||
|
||||
/** Actual implementation of the ring buffer. */
|
||||
private static class FloatRingBuffer {
|
||||
|
||||
private final float[] buffer;
|
||||
private int nextIndex = 0;
|
||||
|
||||
public FloatRingBuffer(int flatSize) {
|
||||
buffer = new float[flatSize];
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a slice of the float array to the ring buffer. If the float array is longer than ring
|
||||
* buffer's capacity, samples with lower indices in the array will be ignored.
|
||||
*/
|
||||
public void load(float[] newData, int offset, int size) {
|
||||
if (offset + size > newData.length) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format(
|
||||
"Index out of range. offset (%d) + size (%d) should <= newData.length (%d)",
|
||||
offset, size, newData.length));
|
||||
}
|
||||
// If buffer can't hold all the data, only keep the most recent data of size buffer.length
|
||||
if (size > buffer.length) {
|
||||
offset += (size - buffer.length);
|
||||
size = buffer.length;
|
||||
}
|
||||
if (nextIndex + size < buffer.length) {
|
||||
// No need to wrap nextIndex, just copy newData[offset:offset + size]
|
||||
// to buffer[nextIndex:nextIndex+size]
|
||||
arraycopy(newData, offset, buffer, nextIndex, size);
|
||||
} else {
|
||||
// Need to wrap nextIndex, perform copy in two chunks.
|
||||
int firstChunkSize = buffer.length - nextIndex;
|
||||
// First copy newData[offset:offset+firstChunkSize] to buffer[nextIndex:buffer.length]
|
||||
arraycopy(newData, offset, buffer, nextIndex, firstChunkSize);
|
||||
// Then copy newData[offset+firstChunkSize:offset+size] to buffer[0:size-firstChunkSize]
|
||||
arraycopy(newData, offset + firstChunkSize, buffer, 0, size - firstChunkSize);
|
||||
}
|
||||
|
||||
nextIndex = (nextIndex + size) % buffer.length;
|
||||
}
|
||||
|
||||
public ByteBuffer getBuffer() {
|
||||
// Create non-direct buffers. On Pixel 4, creating direct buffer costs around 0.1 ms, which
|
||||
// can be 5x ~ 10x longer compared to non-direct buffer backed by arrays (around 0.01ms), so
|
||||
// generally we don't create direct buffer for every invocation.
|
||||
ByteBuffer byteBuffer = ByteBuffer.allocate(Float.SIZE / 8 * buffer.length);
|
||||
byteBuffer.order(ByteOrder.nativeOrder());
|
||||
FloatBuffer result = byteBuffer.asFloatBuffer();
|
||||
result.put(buffer, nextIndex, buffer.length - nextIndex);
|
||||
result.put(buffer, 0, nextIndex);
|
||||
byteBuffer.rewind();
|
||||
return byteBuffer;
|
||||
}
|
||||
|
||||
public int getCapacity() {
|
||||
return buffer.length;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,6 +16,15 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
android_library(
|
||||
name = "audio_data",
|
||||
srcs = ["AudioData.java"],
|
||||
deps = [
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "category",
|
||||
srcs = ["Category.java"],
|
||||
|
|
|
@ -31,11 +31,22 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
|||
InputT convertToTaskInput(List<Packet> packets);
|
||||
}
|
||||
|
||||
/** Interface for the customizable MediaPipe task result listener. */
|
||||
/**
|
||||
* Interface for the customizable MediaPipe task result listener that can retrieve both task result
* objects and the corresponding input data.
|
||||
*/
|
||||
public interface ResultListener<OutputT extends TaskResult, InputT> {
|
||||
void run(OutputT result, InputT input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Interface for the customizable MediaPipe task result listener that can only retrieve task result
|
||||
* objects.
|
||||
*/
|
||||
public interface PureResultListener<OutputT extends TaskResult> {
|
||||
void run(OutputT result);
|
||||
}
|
||||
|
||||
private static final String TAG = "OutputHandler";
|
||||
// A task-specific graph output packet converter that should be implemented per task.
|
||||
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
|
||||
|
@ -45,6 +56,8 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
|||
protected ErrorListener errorListener;
|
||||
// The cached task result for non latency sensitive use cases.
|
||||
protected OutputT cachedTaskResult;
|
||||
// The latest output timestamp.
|
||||
protected long latestOutputTimestamp = -1;
|
||||
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
|
||||
private boolean handleTimestampBoundChanges = false;
|
||||
|
||||
|
@ -98,6 +111,11 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
|||
return taskResult;
|
||||
}
|
||||
|
||||
/** Returns the timestamp of the most recently handled output packet list. */
public long getLatestOutputTimestamp() {
  return latestOutputTimestamp;
}
|
||||
|
||||
/**
|
||||
* Handles a list of output {@link Packet}s. Invoked when a packet list become available.
|
||||
*
|
||||
|
@ -109,6 +127,7 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
|
|||
taskResult = outputPacketConverter.convertToTaskResult(packets);
|
||||
if (resultListener == null) {
|
||||
cachedTaskResult = taskResult;
|
||||
latestOutputTimestamp = packets.get(0).getTimestamp();
|
||||
} else {
|
||||
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
|
||||
resultListener.run(taskResult, taskInput);
|
||||
|
|
|
@ -93,6 +93,7 @@ public class TaskRunner implements AutoCloseable {
|
|||
public synchronized TaskResult process(Map<String, Packet> inputs) {
  // Feed all inputs at a single synthetic timestamp, then block until the graph is idle so the
  // output handler has seen the corresponding outputs.
  addPackets(inputs, generateSyntheticTimestamp());
  graph.waitUntilGraphIdle();
  // NOTE(review): read the output timestamp before retrieving the cached result — retrieval
  // presumably consumes/clears the cache; confirm against OutputHandler.
  lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
  return outputHandler.retrieveCachedTaskResult();
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user