Merge classificationResultList() and classificationResult() to be classificationResults(), and similar for embeddingResults().

PiperOrigin-RevId: 502601043
This commit is contained in:
Jiuqiang Tang 2023-01-17 09:04:54 -08:00 committed by Copybara-Service
parent c1f5920ecf
commit 7974171c3d
2 changed files with 26 additions and 32 deletions

View File

@ -20,7 +20,6 @@ import com.google.mediapipe.tasks.components.containers.proto.ClassificationsPro
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/** Represents the classification results generated by {@link AudioClassifier}. */
@AutoValue
@ -40,8 +39,7 @@ public abstract class AudioClassifierResult implements TaskResult {
for (ClassificationsProto.ClassificationResult proto : protoList) {
classificationResultList.add(ClassificationResult.createFromProto(proto));
}
return new AutoValue_AudioClassifierResult(
Optional.of(classificationResultList), Optional.empty(), timestampMs);
return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs);
}
/**
@ -53,23 +51,22 @@ public abstract class AudioClassifierResult implements TaskResult {
*/
static AudioClassifierResult createFromProto(
ClassificationsProto.ClassificationResult proto, long timestampMs) {
return new AutoValue_AudioClassifierResult(
Optional.empty(), Optional.of(ClassificationResult.createFromProto(proto)), timestampMs);
List<ClassificationResult> classificationResultList = new ArrayList<>();
classificationResultList.add(ClassificationResult.createFromProto(proto));
return new AutoValue_AudioClassifierResult(classificationResultList, timestampMs);
}
/**
* A list of of timestamped {@link ClassificationResult} objects, each contains one set of results
* per classifier head. The list represents the audio classification result of an audio clip, and
* is only available when running with the audio clips mode.
* per classifier head.
*
* <p>In the "audio stream" mode, the list only contains one element, representing the
* classification result of the audio block that starts at {@link
* ClassificationResult.timestampMs} in the audio stream. Otherwise, in the "audio clips" mode,
* the list may include multiple {@link ClassificationResult} objects, each classifying an
* interval of the entire audio clip that starts at {@link ClassificationResult.timestampMs}.
*/
public abstract Optional<List<ClassificationResult>> classificationResultList();
/**
* Contains one set of results per classifier head. A {@link ClassificationResult} usually
* represents one audio classification result in an audio stream, and s only available when
* running with the audio stream mode.
*/
public abstract Optional<ClassificationResult> classificationResult();
public abstract List<ClassificationResult> classificationResults();
@Override
public abstract long timestampMs();

View File

@ -20,7 +20,6 @@ import com.google.mediapipe.tasks.components.containers.proto.EmbeddingsProto;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/** Represents the embedding results generated by {@link AudioEmbedder}. */
@AutoValue
@ -35,12 +34,11 @@ public abstract class AudioEmbedderResult implements TaskResult {
*/
static AudioEmbedderResult createFromProtoList(
List<EmbeddingsProto.EmbeddingResult> protoList, long timestampMs) {
List<EmbeddingResult> classificationResultList = new ArrayList<>();
List<EmbeddingResult> embeddingResultList = new ArrayList<>();
for (EmbeddingsProto.EmbeddingResult proto : protoList) {
classificationResultList.add(EmbeddingResult.createFromProto(proto));
embeddingResultList.add(EmbeddingResult.createFromProto(proto));
}
return new AutoValue_AudioEmbedderResult(
Optional.of(classificationResultList), Optional.empty(), timestampMs);
return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs);
}
/**
@ -52,23 +50,22 @@ public abstract class AudioEmbedderResult implements TaskResult {
*/
static AudioEmbedderResult createFromProto(
EmbeddingsProto.EmbeddingResult proto, long timestampMs) {
return new AutoValue_AudioEmbedderResult(
Optional.empty(), Optional.of(EmbeddingResult.createFromProto(proto)), timestampMs);
List<EmbeddingResult> embeddingResultList = new ArrayList<>();
embeddingResultList.add(EmbeddingResult.createFromProto(proto));
return new AutoValue_AudioEmbedderResult(embeddingResultList, timestampMs);
}
/**
* A list of of timpstamped {@link EmbeddingResult} objects, each contains one set of results per
* embedder head. The list represents the audio embedding result of an audio clip, and is only
* available when running with the audio clips mode.
* embedder head.
*
* <p>In the "audio stream" mode, the list only contains one element, representing the embedding
* result of the audio block that starts at {@link EmbeddingResult.timestampMs} in the audio
* stream. Otherwise, in the "audio clips" mode, the list may include multiple {@link
* EmbeddingResult} objects, each contains the embedding of an interval of the entire audio clip
* that starts at {@link EmbeddingResult.timestampMs}.
*/
public abstract Optional<List<EmbeddingResult>> embeddingResultList();
/**
* Contains one set of results per classifier head. A {@link EmbeddingResult} usually represents
* one audio embedding result in an audio stream, and is only available when running with the
* audio stream mode.
*/
public abstract Optional<EmbeddingResult> embeddingResult();
public abstract List<EmbeddingResult> embeddingResults();
@Override
public abstract long timestampMs();