Internal change

PiperOrigin-RevId: 487466061
This commit is contained in:
MediaPipe Team 2022-11-10 01:30:55 -08:00 committed by Copybara-Service
parent 0ac604d507
commit 0b12aa9435
13 changed files with 103 additions and 116 deletions

View File

@ -21,7 +21,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {AudioClassifierOptions} from './audio_classifier_options'; import {AudioClassifierOptions} from './audio_classifier_options';
import {Classifications} from './audio_classifier_result'; import {AudioClassifierResult} from './audio_classifier_result';
const MEDIAPIPE_GRAPH = const MEDIAPIPE_GRAPH =
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH =
// implementation // implementation
const AUDIO_STREAM = 'input_audio'; const AUDIO_STREAM = 'input_audio';
const SAMPLE_RATE_STREAM = 'sample_rate'; const SAMPLE_RATE_STREAM = 'sample_rate';
const CLASSIFICATION_RESULT_STREAM = 'classification_result'; const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
// The OSS JS API does not support the builder pattern. // The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
/** Performs audio classification. */ /** Performs audio classification. */
export class AudioClassifier extends TaskRunner { export class AudioClassifier extends TaskRunner {
private classifications: Classifications[] = []; private classificationResults: AudioClassifierResult[] = [];
private defaultSampleRate = 48000; private defaultSampleRate = 48000;
private readonly options = new AudioClassifierGraphOptions(); private readonly options = new AudioClassifierGraphOptions();
@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner {
* `48000` if no custom default was set. * `48000` if no custom default was set.
* @return The classification result of the audio data * @return The classification result of the audio data
*/ */
classify(audioData: Float32Array, sampleRate?: number): Classifications[] { classify(audioData: Float32Array, sampleRate?: number):
AudioClassifierResult[] {
sampleRate = sampleRate ?? this.defaultSampleRate; sampleRate = sampleRate ?? this.defaultSampleRate;
// Configures the number of samples in the WASM layer. We re-configure the // Configures the number of samples in the WASM layer. We re-configure the
@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner {
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
this.addAudioToStream(audioData, timestamp); this.addAudioToStream(audioData, timestamp);
this.classifications = []; this.classificationResults = [];
this.finishProcessing(); this.finishProcessing();
return [...this.classifications]; return [...this.classificationResults];
} }
/** /**
* Internal function for converting raw data into a classification, and * Internal function for converting raw data into classification results, and
* adding it to our classifications list. * adding them to our classification results list.
**/ **/
private addJsAudioClassification(binaryProto: Uint8Array): void { private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void {
binaryProtos.forEach(binaryProto => {
const classificationResult = const classificationResult =
ClassificationResult.deserializeBinary(binaryProto); ClassificationResult.deserializeBinary(binaryProto);
this.classifications.push( this.classificationResults.push(
...convertFromClassificationResultProto(classificationResult)); convertFromClassificationResultProto(classificationResult));
});
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -198,13 +201,14 @@ export class AudioClassifier extends TaskRunner {
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM); classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
classifierNode.addOutputStream( classifierNode.addOutputStream(
'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM); 'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM);
classifierNode.setOptions(calculatorOptions); classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode); graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { this.attachProtoVectorListener(
this.addJsAudioClassification(binaryProto); TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
this.addJsAudioClassificationResults(binaryProtos);
}); });
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();

View File

@ -15,4 +15,4 @@
*/ */
export {Category} from '../../../../tasks/web/components/containers/category'; export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

View File

@ -10,8 +10,8 @@ mediapipe_ts_library(
) )
mediapipe_ts_library( mediapipe_ts_library(
name = "classifications", name = "classification_result",
srcs = ["classifications.d.ts"], srcs = ["classification_result.d.ts"],
deps = [":category"], deps = [":category"],
) )

View File

@ -16,27 +16,14 @@
import {Category} from '../../../../tasks/web/components/containers/category'; import {Category} from '../../../../tasks/web/components/containers/category';
/** List of predicted categories with an optional timestamp. */ /** Classification results for a given classifier head. */
export interface ClassificationEntry { export interface Classifications {
/** /**
* The array of predicted categories, usually sorted by descending scores, * The array of predicted categories, usually sorted by descending scores,
* e.g., from high to low probability. * e.g., from high to low probability.
*/ */
categories: Category[]; categories: Category[];
/**
* The optional timestamp (in milliseconds) associated to the classification
* entry. This is useful for time series use cases, e.g., audio
* classification.
*/
timestampMs?: number;
}
/** Classifications for a given classifier head. */
export interface Classifications {
/** A list of classification entries. */
entries: ClassificationEntry[];
/** /**
* The index of the classifier head these categories refer to. This is * The index of the classifier head these categories refer to. This is
* useful for multi-head models. * useful for multi-head models.
@ -45,7 +32,24 @@ export interface Classifications {
/** /**
* The name of the classifier head, which is the corresponding tensor * The name of the classifier head, which is the corresponding tensor
* metadata name. * metadata name. Defaults to an empty string if there is no such metadata.
*/ */
headName: string; headName: string;
} }
/** Classification results of a model. */
export interface ClassificationResult {
/** The classification results for each head of the model. */
classifications: Classifications[];
/**
* The optional timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*
* This is only used for classification on time series (e.g. audio
* classification). In these use cases, the amount of data to process might
* exceed the maximum size that the model can process: to solve this, the
* input data is split into multiple chunks starting at different timestamps.
*/
timestampMs?: number;
}

View File

@ -17,8 +17,9 @@ mediapipe_ts_library(
name = "classifier_result", name = "classifier_result",
srcs = ["classifier_result.ts"], srcs = ["classifier_result.ts"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
], ],
) )

View File

@ -14,48 +14,46 @@
* limitations under the License. * limitations under the License.
*/ */
import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
const DEFAULT_INDEX = -1; const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0; const DEFAULT_SCORE = 0.0;
/** /**
* Converts a ClassificationEntry proto to the ClassificationEntry result * Converts a Classifications proto to a Classifications object.
* type.
*/ */
function convertFromClassificationEntryProto(source: ClassificationEntryProto): function convertFromClassificationsProto(source: ClassificationsProto):
ClassificationEntry { Classifications {
const categories = source.getCategoriesList().map(category => { const categories =
source.getClassificationList()?.getClassificationList().map(
classification => {
return { return {
index: category.getIndex() ?? DEFAULT_INDEX, index: classification.getIndex() ?? DEFAULT_INDEX,
score: category.getScore() ?? DEFAULT_SCORE, score: classification.getScore() ?? DEFAULT_SCORE,
displayName: category.getDisplayName() ?? '', categoryName: classification.getLabel() ?? '',
categoryName: category.getCategoryName() ?? '', displayName: classification.getDisplayName() ?? '',
}; };
}); }) ??
[];
return { return {
categories, categories,
timestampMs: source.getTimestampMs(), headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
headName: source.getHeadName() ?? '',
}; };
} }
/** /**
* Converts a ClassificationResult proto to a list of classifications. * Converts a ClassificationResult proto to a ClassificationResult object.
*/ */
export function convertFromClassificationResultProto( export function convertFromClassificationResultProto(
classificationResult: ClassificationResult) : Classifications[] { source: ClassificationResultProto): ClassificationResult {
const result: Classifications[] = []; const result: ClassificationResult = {
for (const classificationsProto of classifications: source.getClassificationsList().map(
classificationResult.getClassificationsList()) { classififications => convertFromClassificationsProto(classififications))
const classifications: Classifications = {
entries: classificationsProto.getEntriesList().map(
entry => convertFromClassificationEntryProto(entry)),
headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
headName: classificationsProto.getHeadName() ?? '',
}; };
result.push(classifications); if (source.hasTimestampMs()) {
result.timestampMs = source.getTimestampMs();
} }
return result; return result;
} }

View File

@ -22,7 +22,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {TextClassifierOptions} from './text_classifier_options'; import {TextClassifierOptions} from './text_classifier_options';
import {Classifications} from './text_classifier_result'; import {TextClassifierResult} from './text_classifier_result';
const INPUT_STREAM = 'text_in'; const INPUT_STREAM = 'text_in';
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out'; const CLASSIFICATIONS_STREAM = 'classifications_out';
const TEXT_CLASSIFIER_GRAPH = const TEXT_CLASSIFIER_GRAPH =
'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; 'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH =
/** Performs Natural Language classification. */ /** Performs Natural Language classification. */
export class TextClassifier extends TaskRunner { export class TextClassifier extends TaskRunner {
private classifications: Classifications[] = []; private classificationResult: TextClassifierResult = {classifications: []};
private readonly options = new TextClassifierGraphOptions(); private readonly options = new TextClassifierGraphOptions();
/** /**
@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner {
* @param text The text to process. * @param text The text to process.
* @return The classification result of the text * @return The classification result of the text
*/ */
classify(text: string): Classifications[] { classify(text: string): TextClassifierResult {
// Get classification classes by running our MediaPipe graph. // Get classification result by running our MediaPipe graph.
this.classifications = []; this.classificationResult = {classifications: []};
this.addStringToStream( this.addStringToStream(
text, INPUT_STREAM, /* timestamp= */ performance.now()); text, INPUT_STREAM, /* timestamp= */ performance.now());
this.finishProcessing(); this.finishProcessing();
return [...this.classifications]; return this.classificationResult;
}
// Internal function for converting raw data into a classification, and
// adding it to our classifications list.
private addJsTextClassification(binaryProto: Uint8Array): void {
const classificationResult =
ClassificationResult.deserializeBinary(binaryProto);
console.log(classificationResult.toObject());
this.classifications.push(
...convertFromClassificationResultProto(classificationResult));
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner {
const classifierNode = new CalculatorGraphConfig.Node(); const classifierNode = new CalculatorGraphConfig.Node();
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH); classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
classifierNode.addInputStream('TEXT:' + INPUT_STREAM); classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
classifierNode.addOutputStream( classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
classifierNode.setOptions(calculatorOptions); classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode); graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.addJsTextClassification(binaryProto); this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
}); });
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();

View File

@ -15,4 +15,4 @@
*/ */
export {Category} from '../../../../tasks/web/components/containers/category'; export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

View File

@ -21,7 +21,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {ImageClassifierOptions} from './image_classifier_options'; import {ImageClassifierOptions} from './image_classifier_options';
import {Classifications} from './image_classifier_result'; import {ImageClassifierResult} from './image_classifier_result';
const IMAGE_CLASSIFIER_GRAPH = const IMAGE_CLASSIFIER_GRAPH =
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
const INPUT_STREAM = 'input_image'; const INPUT_STREAM = 'input_image';
const CLASSIFICATION_RESULT_STREAM = 'classification_result'; const CLASSIFICATIONS_STREAM = 'classifications';
export {ImageSource}; // Used in the public API export {ImageSource}; // Used in the public API
@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API
/** Performs classification on images. */ /** Performs classification on images. */
export class ImageClassifier extends TaskRunner { export class ImageClassifier extends TaskRunner {
private classifications: Classifications[] = []; private classificationResult: ImageClassifierResult = {classifications: []};
private readonly options = new ImageClassifierGraphOptions(); private readonly options = new ImageClassifierGraphOptions();
/** /**
@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner {
* provided, defaults to `performance.now()`. * provided, defaults to `performance.now()`.
* @return The classification result of the image * @return The classification result of the image
*/ */
classify(imageSource: ImageSource, timestamp?: number): Classifications[] { classify(imageSource: ImageSource, timestamp?: number):
// Get classification classes by running our MediaPipe graph. ImageClassifierResult {
this.classifications = []; // Get classification result by running our MediaPipe graph.
this.classificationResult = {classifications: []};
this.addGpuBufferAsImageToStream( this.addGpuBufferAsImageToStream(
imageSource, INPUT_STREAM, timestamp ?? performance.now()); imageSource, INPUT_STREAM, timestamp ?? performance.now());
this.finishProcessing(); this.finishProcessing();
return [...this.classifications]; return this.classificationResult;
}
/**
* Internal function for converting raw data into a classification, and
* adding it to our classifications list.
**/
private addJsImageClassification(binaryProto: Uint8Array): void {
const classificationResult =
ClassificationResult.deserializeBinary(binaryProto);
this.classifications.push(
...convertFromClassificationResultProto(classificationResult));
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner {
const classifierNode = new CalculatorGraphConfig.Node(); const classifierNode = new CalculatorGraphConfig.Node();
classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH); classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
classifierNode.addInputStream('IMAGE:' + INPUT_STREAM); classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
classifierNode.addOutputStream( classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
classifierNode.setOptions(calculatorOptions); classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode); graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.addJsImageClassification(binaryProto); this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
}); });
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();

View File

@ -15,4 +15,4 @@
*/ */
export {Category} from '../../../../tasks/web/components/containers/category'; export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';