Internal change

PiperOrigin-RevId: 487466061
This commit is contained in:
MediaPipe Team 2022-11-10 01:30:55 -08:00 committed by Copybara-Service
parent 0ac604d507
commit 0b12aa9435
13 changed files with 103 additions and 116 deletions

View File

@ -21,7 +21,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications",
"//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url
import {AudioClassifierOptions} from './audio_classifier_options';
import {Classifications} from './audio_classifier_result';
import {AudioClassifierResult} from './audio_classifier_result';
const MEDIAPIPE_GRAPH =
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH =
// implementation
const AUDIO_STREAM = 'input_audio';
const SAMPLE_RATE_STREAM = 'sample_rate';
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/** Performs audio classification. */
export class AudioClassifier extends TaskRunner {
private classifications: Classifications[] = [];
private classificationResults: AudioClassifierResult[] = [];
private defaultSampleRate = 48000;
private readonly options = new AudioClassifierGraphOptions();
@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner {
* `48000` if no custom default was set.
* @return The classification result of the audio data
*/
classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
classify(audioData: Float32Array, sampleRate?: number):
AudioClassifierResult[] {
sampleRate = sampleRate ?? this.defaultSampleRate;
// Configures the number of samples in the WASM layer. We re-configure the
@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner {
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
this.addAudioToStream(audioData, timestamp);
this.classifications = [];
this.classificationResults = [];
this.finishProcessing();
return [...this.classifications];
return [...this.classificationResults];
}
/**
* Internal function for converting raw data into a classification, and
* adding it to our classfications list.
* Internal function for converting raw data into classification results, and
* adding them to our classification results list.
**/
private addJsAudioClassification(binaryProto: Uint8Array): void {
const classificationResult =
ClassificationResult.deserializeBinary(binaryProto);
this.classifications.push(
...convertFromClassificationResultProto(classificationResult));
private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void {
  // Decode every serialized ClassificationResult proto, in the order received,
  // and append its converted form to the accumulated results list.
  binaryProtos.forEach(serialized => {
    const converted = convertFromClassificationResultProto(
        ClassificationResult.deserializeBinary(serialized));
    this.classificationResults.push(converted);
  });
}
/** Updates the MediaPipe graph configuration. */
@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
@ -198,14 +201,15 @@ export class AudioClassifier extends TaskRunner {
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
classifierNode.addOutputStream(
'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM);
'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM);
classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
this.addJsAudioClassification(binaryProto);
});
this.attachProtoVectorListener(
TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
this.addJsAudioClassificationResults(binaryProtos);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -15,4 +15,4 @@
*/
export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

View File

@ -10,8 +10,8 @@ mediapipe_ts_library(
)
mediapipe_ts_library(
name = "classifications",
srcs = ["classifications.d.ts"],
name = "classification_result",
srcs = ["classification_result.d.ts"],
deps = [":category"],
)

View File

@ -16,27 +16,14 @@
import {Category} from '../../../../tasks/web/components/containers/category';
/** List of predicted categories with an optional timestamp. */
export interface ClassificationEntry {
/** Classification results for a given classifier head. */
export interface Classifications {
/**
* The array of predicted categories, usually sorted by descending scores,
* e.g., from high to low probability.
*/
categories: Category[];
/**
* The optional timestamp (in milliseconds) associated to the classification
* entry. This is useful for time series use cases, e.g., audio
* classification.
*/
timestampMs?: number;
}
/** Classifications for a given classifier head. */
export interface Classifications {
/** A list of classification entries. */
entries: ClassificationEntry[];
/**
* The index of the classifier head these categories refer to. This is
* useful for multi-head models.
@ -45,7 +32,24 @@ export interface Classifications {
/**
* The name of the classifier head, which is the corresponding tensor
* metadata name.
* metadata name. Defaults to an empty string if there is no such metadata.
*/
headName: string;
}
/**
 * Classification results of a model: the per-head classifications plus an
 * optional timestamp for time-series inputs.
 */
export interface ClassificationResult {
/**
 * The classification results for each head of the model, one
 * `Classifications` entry per head.
 */
classifications: Classifications[];
/**
* The optional timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results. Omitted when no timestamp applies.
*
* This is only used for classification on time series (e.g. audio
* classification). In these use cases, the amount of data to process might
* exceed the maximum size that the model can process: to solve this, the
* input data is split into multiple chunks starting at different timestamps.
*/
timestampMs?: number;
}

View File

@ -17,8 +17,9 @@ mediapipe_ts_library(
name = "classifier_result",
srcs = ["classifier_result.ts"],
deps = [
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classifications",
"//mediapipe/tasks/web/components/containers:classification_result",
],
)

View File

@ -14,48 +14,46 @@
* limitations under the License.
*/
import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0;
/**
* Converts a ClassificationEntry proto to the ClassificationEntry result
* type.
* Converts a Classifications proto to a Classifications object.
*/
function convertFromClassificationEntryProto(source: ClassificationEntryProto):
ClassificationEntry {
const categories = source.getCategoriesList().map(category => {
return {
index: category.getIndex() ?? DEFAULT_INDEX,
score: category.getScore() ?? DEFAULT_SCORE,
displayName: category.getDisplayName() ?? '',
categoryName: category.getCategoryName() ?? '',
};
});
function convertFromClassificationsProto(source: ClassificationsProto):
Classifications {
const categories =
source.getClassificationList()?.getClassificationList().map(
classification => {
return {
index: classification.getIndex() ?? DEFAULT_INDEX,
score: classification.getScore() ?? DEFAULT_SCORE,
categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '',
};
}) ??
[];
return {
categories,
timestampMs: source.getTimestampMs(),
headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
headName: source.getHeadName() ?? '',
};
}
/**
* Converts a ClassificationResult proto to a list of classifications.
* Converts a ClassificationResult proto to a ClassificationResult object.
*/
export function convertFromClassificationResultProto(
classificationResult: ClassificationResult) : Classifications[] {
const result: Classifications[] = [];
for (const classificationsProto of
classificationResult.getClassificationsList()) {
const classifications: Classifications = {
entries: classificationsProto.getEntriesList().map(
entry => convertFromClassificationEntryProto(entry)),
headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
headName: classificationsProto.getHeadName() ?? '',
};
result.push(classifications);
source: ClassificationResultProto): ClassificationResult {
const result: ClassificationResult = {
classifications: source.getClassificationsList().map(
classififications => convertFromClassificationsProto(classififications))
};
if (source.hasTimestampMs()) {
result.timestampMs = source.getTimestampMs();
}
return result;
}

View File

@ -22,7 +22,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications",
"//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url
import {TextClassifierOptions} from './text_classifier_options';
import {Classifications} from './text_classifier_result';
import {TextClassifierResult} from './text_classifier_result';
const INPUT_STREAM = 'text_in';
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
const CLASSIFICATIONS_STREAM = 'classifications_out';
const TEXT_CLASSIFIER_GRAPH =
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH =
/** Performs Natural Language classification. */
export class TextClassifier extends TaskRunner {
private classifications: Classifications[] = [];
private classificationResult: TextClassifierResult = {classifications: []};
private readonly options = new TextClassifierGraphOptions();
/**
@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner {
* @param text The text to process.
* @return The classification result of the text
*/
classify(text: string): Classifications[] {
// Get classification classes by running our MediaPipe graph.
this.classifications = [];
classify(text: string): TextClassifierResult {
// Get classification result by running our MediaPipe graph.
this.classificationResult = {classifications: []};
this.addStringToStream(
text, INPUT_STREAM, /* timestamp= */ performance.now());
this.finishProcessing();
return [...this.classifications];
}
// Internal function for converting raw data into a classification, and
// adding it to our classifications list.
private addJsTextClassification(binaryProto: Uint8Array): void {
const classificationResult =
ClassificationResult.deserializeBinary(binaryProto);
console.log(classificationResult.toObject());
this.classifications.push(
...convertFromClassificationResultProto(classificationResult));
return this.classificationResult;
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner {
const classifierNode = new CalculatorGraphConfig.Node();
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
classifierNode.addOutputStream(
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
this.addJsTextClassification(binaryProto);
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
});
const binaryGraph = graphConfig.serializeBinary();

View File

@ -15,4 +15,4 @@
*/
export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

View File

@ -21,7 +21,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications",
"//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap
// Placeholder for internal dependency on trusted resource url
import {ImageClassifierOptions} from './image_classifier_options';
import {Classifications} from './image_classifier_result';
import {ImageClassifierResult} from './image_classifier_result';
const IMAGE_CLASSIFIER_GRAPH =
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
const INPUT_STREAM = 'input_image';
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
const CLASSIFICATIONS_STREAM = 'classifications';
export {ImageSource}; // Used in the public API
@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API
/** Performs classification on images. */
export class ImageClassifier extends TaskRunner {
private classifications: Classifications[] = [];
private classificationResult: ImageClassifierResult = {classifications: []};
private readonly options = new ImageClassifierGraphOptions();
/**
@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner {
* provided, defaults to `performance.now()`.
* @return The classification result of the image
*/
classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
// Get classification classes by running our MediaPipe graph.
this.classifications = [];
classify(imageSource: ImageSource, timestamp?: number):
ImageClassifierResult {
// Get classification result by running our MediaPipe graph.
this.classificationResult = {classifications: []};
this.addGpuBufferAsImageToStream(
imageSource, INPUT_STREAM, timestamp ?? performance.now());
this.finishProcessing();
return [...this.classifications];
}
/**
* Internal function for converting raw data into a classification, and
* adding it to our classfications list.
**/
private addJsImageClassification(binaryProto: Uint8Array): void {
const classificationResult =
ClassificationResult.deserializeBinary(binaryProto);
this.classifications.push(
...convertFromClassificationResultProto(classificationResult));
return this.classificationResult;
}
/** Updates the MediaPipe graph configuration. */
private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner {
const classifierNode = new CalculatorGraphConfig.Node();
classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
classifierNode.addOutputStream(
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
this.addJsImageClassification(binaryProto);
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
});
const binaryGraph = graphConfig.serializeBinary();

View File

@ -15,4 +15,4 @@
*/
export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';