diff --git a/mediapipe/tasks/web/audio/audio_classifier/BUILD b/mediapipe/tasks/web/audio/audio_classifier/BUILD index bc3048df1..1bc4af309 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/web/audio/audio_classifier/BUILD @@ -21,7 +21,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts index 3d3ca5ae7..e3700cd7a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier.ts @@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm // Placeholder for internal dependency on trusted resource url import {AudioClassifierOptions} from './audio_classifier_options'; -import {Classifications} from './audio_classifier_result'; +import {AudioClassifierResult} from './audio_classifier_result'; const MEDIAPIPE_GRAPH = 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; @@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH = // implementation const AUDIO_STREAM = 'input_audio'; const SAMPLE_RATE_STREAM = 'sample_rate'; -const CLASSIFICATION_RESULT_STREAM = 'classification_result'; +const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern /** Performs audio classification. */ export class AudioClassifier extends TaskRunner { - private classifications: Classifications[] = []; + private classificationResults: AudioClassifierResult[] = []; private defaultSampleRate = 48000; private readonly options = new AudioClassifierGraphOptions(); @@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner { * `48000` if no custom default was set. * @return The classification result of the audio datas */ - classify(audioData: Float32Array, sampleRate?: number): Classifications[] { + classify(audioData: Float32Array, sampleRate?: number): + AudioClassifierResult[] { sampleRate = sampleRate ?? this.defaultSampleRate; // Configures the number of samples in the WASM layer. We re-configure the @@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner { this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); this.addAudioToStream(audioData, timestamp); - this.classifications = []; + this.classificationResults = []; this.finishProcessing(); - return [...this.classifications]; + return [...this.classificationResults]; } /** - * Internal function for converting raw data into a classification, and - * adding it to our classfications list. + * Internal function for converting raw data into classification results, and + * adding them to our classfication results list. **/ - private addJsAudioClassification(binaryProto: Uint8Array): void { - const classificationResult = - ClassificationResult.deserializeBinary(binaryProto); - this.classifications.push( - ...convertFromClassificationResultProto(classificationResult)); + private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void { + binaryProtos.forEach(binaryProto => { + const classificationResult = + ClassificationResult.deserializeBinary(binaryProto); + this.classificationResults.push( + convertFromClassificationResultProto(classificationResult)); + }); } /** Updates the MediaPipe graph configuration. */ @@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM); - graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -198,14 +201,15 @@ export class AudioClassifier extends TaskRunner { classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM); classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); classifierNode.addOutputStream( - 'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM); + 'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM); classifierNode.setOptions(calculatorOptions); graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { - this.addJsAudioClassification(binaryProto); - }); + this.attachProtoVectorListener( + TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => { + this.addJsAudioClassificationResults(binaryProtos); + }); const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts index 0a51dee04..0b616126a 100644 --- a/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts +++ b/mediapipe/tasks/web/audio/audio_classifier/audio_classifier_result.ts @@ -15,4 +15,4 @@ */ export {Category} from '../../../../tasks/web/components/containers/category'; -export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; +export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result'; diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index 7d13fadcb..1b0e403ff 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -10,8 +10,8 @@ mediapipe_ts_library( ) mediapipe_ts_library( - name = "classifications", - srcs = ["classifications.d.ts"], + name = "classification_result", + srcs = ["classification_result.d.ts"], deps = [":category"], ) diff --git a/mediapipe/tasks/web/components/containers/classifications.d.ts b/mediapipe/tasks/web/components/containers/classification_result.d.ts similarity index 61% rename from mediapipe/tasks/web/components/containers/classifications.d.ts rename to mediapipe/tasks/web/components/containers/classification_result.d.ts index 67a259bbe..33632f925 100644 --- a/mediapipe/tasks/web/components/containers/classifications.d.ts +++ b/mediapipe/tasks/web/components/containers/classification_result.d.ts @@ -16,27 +16,14 @@ import {Category} from '../../../../tasks/web/components/containers/category'; -/** List of predicted categories with an optional timestamp. */ -export interface ClassificationEntry { +/** Classification results for a given classifier head. */ +export interface Classifications { /** * The array of predicted categories, usually sorted by descending scores, * e.g., from high to low probability. */ categories: Category[]; - /** - * The optional timestamp (in milliseconds) associated to the classification - * entry. This is useful for time series use cases, e.g., audio - * classification. - */ - timestampMs?: number; -} - -/** Classifications for a given classifier head. */ -export interface Classifications { - /** A list of classification entries. */ - entries: ClassificationEntry[]; - /** * The index of the classifier head these categories refer to. This is * useful for multi-head models. @@ -45,7 +32,24 @@ export interface Classifications { /** * The name of the classifier head, which is the corresponding tensor - * metadata name. + * metadata name. Defaults to an empty string if there is no such metadata. */ headName: string; } + +/** Classification results of a model. */ +export interface ClassificationResult { + /** The classification results for each head of the model. */ + classifications: Classifications[]; + + /** + * The optional timestamp (in milliseconds) of the start of the chunk of data + * corresponding to these results. + * + * This is only used for classification on time series (e.g. audio + * classification). In these use cases, the amount of data to process might + * exceed the maximum size that the model can process: to solve this, the + * input data is split into multiple chunks starting at different timestamps. + */ + timestampMs?: number; +} diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index cd7190dd9..e0d84b632 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -17,8 +17,9 @@ mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], deps = [ + "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", - "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/containers:classification_result", ], ) diff --git a/mediapipe/tasks/web/components/processors/classifier_result.ts b/mediapipe/tasks/web/components/processors/classifier_result.ts index ade967932..90d10b84d 100644 --- a/mediapipe/tasks/web/components/processors/classifier_result.ts +++ b/mediapipe/tasks/web/components/processors/classifier_result.ts @@ -14,48 +14,46 @@ * limitations under the License. */ -import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; -import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; +import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; +import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result'; const DEFAULT_INDEX = -1; const DEFAULT_SCORE = 0.0; /** - * Converts a ClassificationEntry proto to the ClassificationEntry result - * type. + * Converts a Classifications proto to a Classifications object. */ -function convertFromClassificationEntryProto(source: ClassificationEntryProto): - ClassificationEntry { - const categories = source.getCategoriesList().map(category => { - return { - index: category.getIndex() ?? DEFAULT_INDEX, - score: category.getScore() ?? DEFAULT_SCORE, - displayName: category.getDisplayName() ?? '', - categoryName: category.getCategoryName() ?? '', - }; - }); - +function convertFromClassificationsProto(source: ClassificationsProto): + Classifications { + const categories = + source.getClassificationList()?.getClassificationList().map( + classification => { + return { + index: classification.getIndex() ?? DEFAULT_INDEX, + score: classification.getScore() ?? DEFAULT_SCORE, + categoryName: classification.getLabel() ?? '', + displayName: classification.getDisplayName() ?? '', + }; + }) ?? + []; return { categories, - timestampMs: source.getTimestampMs(), + headIndex: source.getHeadIndex() ?? DEFAULT_INDEX, + headName: source.getHeadName() ?? '', }; } /** - * Converts a ClassificationResult proto to a list of classifications. + * Converts a ClassificationResult proto to a ClassificationResult object. */ export function convertFromClassificationResultProto( - classificationResult: ClassificationResult) : Classifications[] { - const result: Classifications[] = []; - for (const classificationsProto of - classificationResult.getClassificationsList()) { - const classifications: Classifications = { - entries: classificationsProto.getEntriesList().map( - entry => convertFromClassificationEntryProto(entry)), - headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX, - headName: classificationsProto.getHeadName() ?? '', - }; - result.push(classifications); + source: ClassificationResultProto): ClassificationResult { + const result: ClassificationResult = { + classifications: source.getClassificationsList().map( + classififications => convertFromClassificationsProto(classififications)) + }; + if (source.hasTimestampMs()) { + result.timestampMs = source.getTimestampMs(); } return result; } diff --git a/mediapipe/tasks/web/text/text_classifier/BUILD b/mediapipe/tasks/web/text/text_classifier/BUILD index 25a8817d4..4ebdce18a 100644 --- a/mediapipe/tasks/web/text/text_classifier/BUILD +++ b/mediapipe/tasks/web/text/text_classifier/BUILD @@ -22,7 +22,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index d92248b80..e1d0c9601 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm // Placeholder for internal dependency on trusted resource url import {TextClassifierOptions} from './text_classifier_options'; -import {Classifications} from './text_classifier_result'; +import {TextClassifierResult} from './text_classifier_result'; const INPUT_STREAM = 'text_in'; -const CLASSIFICATION_RESULT_STREAM = 'classification_result_out'; +const CLASSIFICATIONS_STREAM = 'classifications_out'; const TEXT_CLASSIFIER_GRAPH = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; @@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH = /** Performs Natural Language classification. */ export class TextClassifier extends TaskRunner { - private classifications: Classifications[] = []; + private classificationResult: TextClassifierResult = {classifications: []}; private readonly options = new TextClassifierGraphOptions(); /** @@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner { * @param text The text to process. * @return The classification result of the text */ - classify(text: string): Classifications[] { - // Get classification classes by running our MediaPipe graph. - this.classifications = []; + classify(text: string): TextClassifierResult { + // Get classification result by running our MediaPipe graph. + this.classificationResult = {classifications: []}; this.addStringToStream( text, INPUT_STREAM, /* timestamp= */ performance.now()); this.finishProcessing(); - return [...this.classifications]; - } - - // Internal function for converting raw data into a classification, and - // adding it to our classifications list. - private addJsTextClassification(binaryProto: Uint8Array): void { - const classificationResult = - ClassificationResult.deserializeBinary(binaryProto); - console.log(classificationResult.toObject()); - this.classifications.push( - ...convertFromClassificationResultProto(classificationResult)); + return this.classificationResult; } /** Updates the MediaPipe graph configuration. */ private refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); - graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner { const classifierNode = new CalculatorGraphConfig.Node(); classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH); classifierNode.addInputStream('TEXT:' + INPUT_STREAM); - classifierNode.addOutputStream( - 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM); classifierNode.setOptions(calculatorOptions); graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { - this.addJsTextClassification(binaryProto); + this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); }); const binaryGraph = graphConfig.serializeBinary(); diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts index 0a51dee04..707ba5da2 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier_result.ts @@ -15,4 +15,4 @@ */ export {Category} from '../../../../tasks/web/components/containers/category'; -export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; +export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result'; diff --git a/mediapipe/tasks/web/vision/image_classifier/BUILD b/mediapipe/tasks/web/vision/image_classifier/BUILD index 6937dc4f3..e96d6a8e3 100644 --- a/mediapipe/tasks/web/vision/image_classifier/BUILD +++ b/mediapipe/tasks/web/vision/image_classifier/BUILD @@ -21,7 +21,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", - "//mediapipe/tasks/web/components/containers:classifications", + "//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_result", diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts index cb63874c4..ba4b6c907 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier.ts @@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap // Placeholder for internal dependency on trusted resource url import {ImageClassifierOptions} from './image_classifier_options'; -import {Classifications} from './image_classifier_result'; +import {ImageClassifierResult} from './image_classifier_result'; const IMAGE_CLASSIFIER_GRAPH = 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; const INPUT_STREAM = 'input_image'; -const CLASSIFICATION_RESULT_STREAM = 'classification_result'; +const CLASSIFICATIONS_STREAM = 'classifications'; export {ImageSource}; // Used in the public API @@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API /** Performs classification on images. */ export class ImageClassifier extends TaskRunner { - private classifications: Classifications[] = []; + private classificationResult: ImageClassifierResult = {classifications: []}; private readonly options = new ImageClassifierGraphOptions(); /** @@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner { * provided, defaults to `performance.now()`. * @return The classification result of the image */ - classify(imageSource: ImageSource, timestamp?: number): Classifications[] { - // Get classification classes by running our MediaPipe graph. - this.classifications = []; + classify(imageSource: ImageSource, timestamp?: number): + ImageClassifierResult { + // Get classification result by running our MediaPipe graph. + this.classificationResult = {classifications: []}; this.addGpuBufferAsImageToStream( imageSource, INPUT_STREAM, timestamp ?? performance.now()); this.finishProcessing(); - return [...this.classifications]; - } - - /** - * Internal function for converting raw data into a classification, and - * adding it to our classfications list. - **/ - private addJsImageClassification(binaryProto: Uint8Array): void { - const classificationResult = - ClassificationResult.deserializeBinary(binaryProto); - this.classifications.push( - ...convertFromClassificationResultProto(classificationResult)); + return this.classificationResult; } /** Updates the MediaPipe graph configuration. */ private refreshGraph(): void { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(INPUT_STREAM); - graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); + graphConfig.addOutputStream(CLASSIFICATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner { const classifierNode = new CalculatorGraphConfig.Node(); classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH); classifierNode.addInputStream('IMAGE:' + INPUT_STREAM); - classifierNode.addOutputStream( - 'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM); + classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM); classifierNode.setOptions(calculatorOptions); graphConfig.addNode(classifierNode); - this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { - this.addJsImageClassification(binaryProto); + this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => { + this.classificationResult = convertFromClassificationResultProto( + ClassificationResult.deserializeBinary(binaryProto)); }); const binaryGraph = graphConfig.serializeBinary(); diff --git a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts index 0a51dee04..44032234c 100644 --- a/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts +++ b/mediapipe/tasks/web/vision/image_classifier/image_classifier_result.ts @@ -15,4 +15,4 @@ */ export {Category} from '../../../../tasks/web/components/containers/category'; -export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; +export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';