MediaPipe Tasks AudioEmbedder Java API

PiperOrigin-RevId: 488456442
Jiuqiang Tang 2022-11-14 13:40:49 -08:00 committed by Copybara-Service
parent ca7b5e9d8b
commit b4fba6fe61
7 changed files with 511 additions and 4 deletions

View File

@@ -58,9 +58,12 @@ struct AudioEmbedderOptions {
nullptr;
};
// Performs embedding extraction on audio clips or audio stream.
// Performs audio embedding extraction on audio clips or audio streams.
//
// The API expects a TFLite model with TFLite Model Metadata.
// This API expects a TFLite model with mandatory TFLite Model Metadata that
// contains the mandatory AudioProperties of the solo input audio tensor and the
// optional (but recommended) label items as AssociatedFiles with type
// TENSOR_AXIS_LABELS per output embedding tensor.
//
// Input tensor:
// (kTfLiteFloat32)

View File

@@ -39,6 +39,7 @@ cc_binary(
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/audio/audio_classifier:audio_classifier_graph",
"//mediapipe/tasks/cc/audio/audio_embedder:audio_embedder_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
@@ -75,6 +76,35 @@ android_library(
],
)
android_library(
name = "audioembedder",
srcs = [
"audioembedder/AudioEmbedder.java",
"audioembedder/AudioEmbedderResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "audioembedder/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/audio:libmediapipe_tasks_audio_jni_lib",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:audiodata",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:embedderoptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_audio_aar")
mediapipe_tasks_audio_aar(

View File

@@ -265,8 +265,10 @@ public final class AudioClassifier extends BaseAudioTaskApi {
}
/**
* Sends audio data (a block in a continuous audio stream) to perform audio classification. Only
* use this method when the AudioClassifier is created with the audio stream mode.
* Sends audio data (a block in a continuous audio stream) to perform audio classification, and
* the results will be available via the {@link ResultListener} provided in the
* {@link AudioClassifierOptions}. Only use this method when the AudioClassifier is created with
* the audio stream mode.
*
* <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
* be resampled, accumulated, and framed to the proper size for the underlying model to consume.
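
The stream-mode pattern this Javadoc describes can be sketched as follows (hedged: the model asset name and the listener body are placeholders, and AudioClassifierOptions is assumed to mirror the AudioEmbedderOptions builder added later in this commit):

import android.content.Context;
import com.google.mediapipe.tasks.audio.audioclassifier.AudioClassifier;
import com.google.mediapipe.tasks.audio.audioclassifier.AudioClassifier.AudioClassifierOptions;
import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.core.BaseOptions;

final class AudioClassifierStreamSketch {
  // Builds a stream-mode classifier whose results arrive asynchronously via
  // the result listener set in the options.
  static AudioClassifier create(Context context) {
    AudioClassifierOptions options =
        AudioClassifierOptions.builder()
            .setBaseOptions(
                // "classifier.tflite" is a placeholder asset path.
                BaseOptions.builder().setModelAssetPath("classifier.tflite").build())
            .setRunningMode(RunningMode.AUDIO_STREAM)
            .setResultListener(result -> { /* consume classification results */ })
            .build();
    return AudioClassifier.createFromOptions(context, options);
  }

  // Feeds one block of a continuous stream; timestamps (ms) must be
  // monotonically increasing across calls.
  static void feed(AudioClassifier classifier, AudioData block, long timestampMs) {
    classifier.classifyAsync(block, timestampMs);
  }
}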

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.audio.audioembedder">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@@ -0,0 +1,388 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.audioembedder;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.audio.audioembedder.proto.AudioEmbedderGraphOptionsProto;
import com.google.mediapipe.tasks.audio.core.BaseAudioTaskApi;
import com.google.mediapipe.tasks.audio.core.RunningMode;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.Embedding;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import com.google.mediapipe.tasks.components.processors.EmbedderOptions;
import com.google.mediapipe.tasks.components.utils.CosineSimilarity;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.PureResultListener;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs audio embedding extraction on audio clips or audio streams.
*
* <p>This API expects a TFLite model with mandatory TFLite Model Metadata that contains the
* mandatory AudioProperties of the solo input audio tensor and the optional (but recommended) label
* items as AssociatedFiles with type TENSOR_AXIS_LABELS per output embedding tensor.
*
* <p>Input tensor: (kTfLiteFloat32)
*
* <ul>
* <li>input audio buffer of size `[batch * samples]`.
* <li>batch inference is not supported (`batch` is required to be 1).
* <li>for multi-channel models, the channels need to be interleaved.
* </ul>
*
* <p>At least one output tensor with: (kTfLiteFloat32)
*
* <ul>
* <li>`N` components corresponding to the `N` dimensions of the returned feature vector for this
* output layer.
* <li>Either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]`.
* </ul>
*/
public final class AudioEmbedder extends BaseAudioTaskApi {
private static final String TAG = AudioEmbedder.class.getSimpleName();
private static final String AUDIO_IN_STREAM_NAME = "audio_in";
private static final String SAMPLE_RATE_IN_STREAM_NAME = "sample_rate_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"AUDIO:" + AUDIO_IN_STREAM_NAME, "SAMPLE_RATE:" + SAMPLE_RATE_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"EMBEDDINGS:embeddings_out", "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"));
private static final int EMBEDDINGS_OUT_STREAM_INDEX = 0;
private static final int TIMESTAMPED_EMBEDDINGS_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.audio.audio_embedder.AudioEmbedderGraph";
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
static {
ProtoUtil.registerTypeName(
EmbeddingsProto.EmbeddingResult.class,
"mediapipe.tasks.components.containers.proto.EmbeddingResult");
}
/**
* Creates an {@link AudioEmbedder} instance from a model file and default {@link
* AudioEmbedderOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the embedding model in the assets.
* @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation.
*/
public static AudioEmbedder createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link AudioEmbedder} instance from a model file and default {@link
* AudioEmbedderOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the embedding model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation.
*/
public static AudioEmbedder createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates an {@link AudioEmbedder} instance from a model buffer and default {@link
* AudioEmbedderOptions}.
*
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the embedding
* model.
* @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation.
*/
public static AudioEmbedder createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, AudioEmbedderOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link AudioEmbedder} instance from an {@link AudioEmbedderOptions} instance.
*
* @param context an Android {@link Context}.
* @param options an {@link AudioEmbedderOptions} instance.
* @throws MediaPipeException if there is an error during {@link AudioEmbedder} creation.
*/
public static AudioEmbedder createFromOptions(Context context, AudioEmbedderOptions options) {
OutputHandler<AudioEmbedderResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<AudioEmbedderResult, Void>() {
@Override
public AudioEmbedderResult convertToTaskResult(List<Packet> packets) {
try {
if (!packets.get(EMBEDDINGS_OUT_STREAM_INDEX).isEmpty()) {
// For audio stream mode.
return AudioEmbedderResult.createFromProto(
PacketGetter.getProto(
packets.get(EMBEDDINGS_OUT_STREAM_INDEX),
EmbeddingsProto.EmbeddingResult.getDefaultInstance()),
packets.get(EMBEDDINGS_OUT_STREAM_INDEX).getTimestamp()
/ MICROSECONDS_PER_MILLISECOND);
} else {
// For audio clips mode.
return AudioEmbedderResult.createFromProtoList(
PacketGetter.getProtoVector(
packets.get(TIMESTAMPED_EMBEDDINGS_OUT_STREAM_INDEX),
EmbeddingsProto.EmbeddingResult.parser()),
-1);
}
} catch (IOException e) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
}
}
@Override
public Void convertToTaskInput(List<Packet> packets) {
return null;
}
});
if (options.resultListener().isPresent()) {
ResultListener<AudioEmbedderResult, Void> resultListener =
new ResultListener<AudioEmbedderResult, Void>() {
@Override
public void run(AudioEmbedderResult audioEmbedderResult, Void input) {
options.resultListener().get().run(audioEmbedderResult);
}
};
handler.setResultListener(resultListener);
}
options.errorListener().ifPresent(handler::setErrorListener);
// Audio tasks should not drop input audio due to flow limiting, which may cause data
// inconsistency.
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<AudioEmbedderOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(options)
.setEnableFlowLimiting(false)
.build(),
handler);
return new AudioEmbedder(runner, options.runningMode());
}
/**
* Constructor to initialize an {@link AudioEmbedder} from a {@link TaskRunner} and {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe audio task {@link RunningMode}.
*/
private AudioEmbedder(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, AUDIO_IN_STREAM_NAME, SAMPLE_RATE_IN_STREAM_NAME);
}
/**
* Performs embedding extraction on the provided audio clip. Only use this method when the
* AudioEmbedder is created with the audio clips mode.
*
* <p>The audio clip is represented as a MediaPipe {@link AudioData} object. The method accepts
* audio clips with various lengths and audio sample rates. It's required to provide the
* corresponding audio sample rate within the {@link AudioData} object.
*
* <p>The input audio clip may be longer than what the model is able to process in a single
* inference. When this occurs, the input audio clip is split into multiple chunks starting at
* different timestamps. For this reason, this method returns an {@link AudioEmbedderResult} with
* a list of timestamped {@link EmbeddingResult} objects, each associated with the start time (in
* milliseconds) of the chunk from which the embedding was extracted.
*
* @param audioClip a MediaPipe {@link AudioData} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public AudioEmbedderResult embed(AudioData audioClip) {
return (AudioEmbedderResult) processAudioClip(audioClip);
}
/**
* Sends audio data (a block in a continuous audio stream) to perform audio embedding, and
* the results will be available via the {@link ResultListener} provided in the
* {@link AudioEmbedderOptions}. Only use this method when the AudioEmbedder is created with
* the audio stream mode.
*
* <p>The audio block is represented as a MediaPipe {@link AudioData} object. The audio data will
* be resampled, accumulated, and framed to the proper size for the underlying model to consume.
* It's required to provide the corresponding audio sample rate within the {@link AudioData} object
* well as a timestamp (in milliseconds) to indicate the start time of the input audio block. The
* timestamps must be monotonically increasing. This method will return immediately after
* the input audio data is accepted. The results will be available in the `resultListener`
* provided in the `AudioEmbedderOptions`. The `embedAsync` method is designed to process
* audio stream data such as microphone input.
*
* <p>The input audio block may be longer than what the model is able to process in a single
* inference. When this occurs, the input audio block is split into multiple chunks. For this
* reason, the callback may be called multiple times (once per chunk) for each call to this
* function.
*
* @param audioBlock a MediaPipe {@link AudioData} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void embedAsync(AudioData audioBlock, long timestampMs) {
checkOrSetSampleRate(audioBlock.getFormat().getSampleRate());
sendAudioStreamData(audioBlock, timestampMs);
}
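// A minimal stream-mode sketch (hedged: "embedder.tflite" is a placeholder
// asset name and the listener body is elided):
//
//   AudioEmbedderOptions options =
//       AudioEmbedderOptions.builder()
//           .setBaseOptions(
//               BaseOptions.builder().setModelAssetPath("embedder.tflite").build())
//           .setRunningMode(RunningMode.AUDIO_STREAM)
//           .setResultListener(result -> { /* consume embeddings */ })
//           .build();
//   AudioEmbedder embedder = AudioEmbedder.createFromOptions(context, options);
//   embedder.embedAsync(audioBlock, /* timestampMs= */ blockStartTimeMs);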
/**
* Utility function to compute <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine
* similarity</a> between two {@link Embedding} objects.
*
* @throws IllegalArgumentException if the embeddings are of different types (float vs.
* quantized), have different sizes, or have an L2-norm of 0.
*/
public static double cosineSimilarity(Embedding u, Embedding v) {
return CosineSimilarity.compute(u, v);
}
/** Options for setting up an {@link AudioEmbedder}. */
@AutoValue
public abstract static class AudioEmbedderOptions extends TaskOptions {
/** Builder for {@link AudioEmbedderOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the {@link BaseOptions} for the audio embedder task. */
public abstract Builder setBaseOptions(BaseOptions baseOptions);
/**
* Sets the {@link RunningMode} for the audio embedder task. Defaults to the audio clips mode.
* The audio embedder has two modes:
*
* <ul>
* <li>AUDIO_CLIPS: The mode for running audio embedding on audio clips. Users feed audio
* clips to the `embed` method, and will receive the embedding results as the return
* value.
* <li>AUDIO_STREAM: The mode for running audio embedding on an audio stream, such as from
* a microphone. Users call `embedAsync` to push the audio data into the AudioEmbedder;
* the embedding results will be available in the result callback when the audio
* embedder finishes the work.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode runningMode);
/**
* Sets the optional {@link EmbedderOptions} controlling embedding behavior, such as
* L2 normalization and scalar quantization.
*/
public abstract Builder setEmbedderOptions(EmbedderOptions embedderOptions);
/**
* Sets the {@link ResultListener} to receive the embedding results asynchronously when the
* audio embedder is in the audio stream mode.
*/
public abstract Builder setResultListener(
PureResultListener<AudioEmbedderResult> resultListener);
/** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener errorListener);
abstract AudioEmbedderOptions autoBuild();
/**
* Validates and builds the {@link AudioEmbedderOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the audio embedder is
* in the audio stream mode.
*/
public final AudioEmbedderOptions build() {
AudioEmbedderOptions options = autoBuild();
if (options.runningMode() == RunningMode.AUDIO_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The audio embedder is in the audio stream mode, a user-defined result listener"
+ " must be provided in the AudioEmbedderOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The audio embedder is in the audio clips mode, a user-defined result listener"
+ " shouldn't be provided in AudioEmbedderOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<EmbedderOptions> embedderOptions();
abstract Optional<PureResultListener<AudioEmbedderResult>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_AudioEmbedder_AudioEmbedderOptions.Builder()
.setRunningMode(RunningMode.AUDIO_CLIPS);
}
/** Converts an {@link AudioEmbedderOptions} to a {@link CalculatorOptions} protobuf message. */
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() == RunningMode.AUDIO_STREAM);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.Builder taskOptionsBuilder =
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (embedderOptions().isPresent()) {
taskOptionsBuilder.setEmbedderOptions(embedderOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder()
.setExtension(
AudioEmbedderGraphOptionsProto.AudioEmbedderGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
}
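
A clips-mode sketch of the API above (hedged: the asset name is a placeholder, and EmbeddingResult is assumed to expose its per-head embeddings via embeddings(), which is defined outside this diff):

import android.content.Context;
import com.google.mediapipe.tasks.audio.audioembedder.AudioEmbedder;
import com.google.mediapipe.tasks.audio.audioembedder.AudioEmbedderResult;
import com.google.mediapipe.tasks.components.containers.AudioData;
import com.google.mediapipe.tasks.components.containers.Embedding;

final class ClipSimilaritySketch {
  // Embeds two clips with one embedder and compares the first chunk's
  // first-head embeddings; in the audio clips mode, embeddingResultList()
  // holds one timestamped EmbeddingResult per model-sized chunk.
  static double similarity(Context context, AudioData clipA, AudioData clipB) {
    AudioEmbedder embedder = AudioEmbedder.createFromFile(context, "embedder.tflite");
    AudioEmbedderResult resultA = embedder.embed(clipA);
    AudioEmbedderResult resultB = embedder.embed(clipB);
    Embedding u = resultA.embeddingResultList().get().get(0).embeddings().get(0);
    Embedding v = resultB.embeddingResultList().get().get(0).embeddings().get(0);
    return AudioEmbedder.cosineSimilarity(u, v);
  }
}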

View File

@@ -0,0 +1,75 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.audio.audioembedder;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;
import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/** Represents the embedding results generated by {@link AudioEmbedder}. */
@AutoValue
public abstract class AudioEmbedderResult implements TaskResult {
/**
* Creates an {@link AudioEmbedderResult} instance from a list of {@link
* EmbeddingsProto.EmbeddingResult} protobuf messages.
*
* @param protoList a list of {@link EmbeddingsProto.EmbeddingResult} protobuf messages to convert.
* @param timestampMs a timestamp for this result.
*/
static AudioEmbedderResult createFromProtoList(
List<EmbeddingsProto.EmbeddingResult> protoList, long timestampMs) {
List<EmbeddingResult> embeddingResultList = new ArrayList<>();
for (EmbeddingsProto.EmbeddingResult proto : protoList) {
embeddingResultList.add(EmbeddingResult.createFromProto(proto));
}
return new AutoValue_AudioEmbedderResult(
Optional.of(embeddingResultList), Optional.empty(), timestampMs);
}
/**
* Creates an {@link AudioEmbedderResult} instance from an {@link EmbeddingsProto.EmbeddingResult}
* protobuf message.
*
* @param proto the {@link EmbeddingsProto.EmbeddingResult} protobuf message to convert.
* @param timestampMs a timestamp for this result.
*/
static AudioEmbedderResult createFromProto(
EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
return new AutoValue_AudioEmbedderResult(
Optional.empty(), Optional.of(EmbeddingResult.createFromProto(proto)), timestampMs);
}
/**
* A list of timestamped {@link EmbeddingResult} objects, each containing one set of results per
* embedder head. The list represents the audio embedding result of an audio clip, and is only
* available when running with the audio clips mode.
*/
public abstract Optional<List<EmbeddingResult>> embeddingResultList();
/**
* Contains one set of results per embedder head. An {@link EmbeddingResult} usually represents
* one audio embedding result in an audio stream, and is only available when running with the audio
* stream mode.
*/
public abstract Optional<EmbeddingResult> embeddingResult();
@Override
public abstract long timestampMs();
}
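
A small sketch of consuming this result (hedged: Embedding is assumed to expose the raw float vector via floatEmbedding(), which is defined outside this diff):

import com.google.mediapipe.tasks.audio.audioembedder.AudioEmbedderResult;
import com.google.mediapipe.tasks.components.containers.Embedding;
import com.google.mediapipe.tasks.components.containers.EmbeddingResult;

final class ResultConsumptionSketch {
  // Stream mode populates embeddingResult(); clips mode populates
  // embeddingResultList(), as documented above.
  static float[] firstHeadVector(AudioEmbedderResult result) {
    EmbeddingResult streamResult = result.embeddingResult().get();
    Embedding first = streamResult.embeddings().get(0);
    // Float models produce floatEmbedding(); a quantized model would use
    // quantizedEmbedding() instead (assumption).
    return first.floatEmbedding();
  }
}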

View File

@@ -32,6 +32,7 @@ _CORE_TASKS_JAVA_PROTO_LITE_TARGETS = [
_AUDIO_TASKS_JAVA_PROTO_LITE_TARGETS = [
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/audio/audio_embedder/proto:audio_embedder_graph_options_java_proto_lite",
]
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [