Internal change
PiperOrigin-RevId: 487466061
This commit is contained in:
parent
0ac604d507
commit
0b12aa9435
|
@ -21,7 +21,7 @@ mediapipe_ts_library(
|
|||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/containers:category",
|
||||
"//mediapipe/tasks/web/components/containers:classifications",
|
||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||
"//mediapipe/tasks/web/components/processors:base_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||
|
|
|
@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
|
|||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {AudioClassifierOptions} from './audio_classifier_options';
|
||||
import {Classifications} from './audio_classifier_result';
|
||||
import {AudioClassifierResult} from './audio_classifier_result';
|
||||
|
||||
const MEDIAPIPE_GRAPH =
|
||||
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
|
||||
|
@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH =
|
|||
// implementation
|
||||
const AUDIO_STREAM = 'input_audio';
|
||||
const SAMPLE_RATE_STREAM = 'sample_rate';
|
||||
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
|
||||
const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/** Performs audio classification. */
|
||||
export class AudioClassifier extends TaskRunner {
|
||||
private classifications: Classifications[] = [];
|
||||
private classificationResults: AudioClassifierResult[] = [];
|
||||
private defaultSampleRate = 48000;
|
||||
private readonly options = new AudioClassifierGraphOptions();
|
||||
|
||||
|
@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner {
|
|||
* `48000` if no custom default was set.
|
||||
* @return The classification result of the audio datas
|
||||
*/
|
||||
classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
|
||||
classify(audioData: Float32Array, sampleRate?: number):
|
||||
AudioClassifierResult[] {
|
||||
sampleRate = sampleRate ?? this.defaultSampleRate;
|
||||
|
||||
// Configures the number of samples in the WASM layer. We re-configure the
|
||||
|
@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner {
|
|||
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
|
||||
this.addAudioToStream(audioData, timestamp);
|
||||
|
||||
this.classifications = [];
|
||||
this.classificationResults = [];
|
||||
this.finishProcessing();
|
||||
return [...this.classifications];
|
||||
return [...this.classificationResults];
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal function for converting raw data into a classification, and
|
||||
* adding it to our classfications list.
|
||||
* Internal function for converting raw data into classification results, and
|
||||
* adding them to our classfication results list.
|
||||
**/
|
||||
private addJsAudioClassification(binaryProto: Uint8Array): void {
|
||||
private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void {
|
||||
binaryProtos.forEach(binaryProto => {
|
||||
const classificationResult =
|
||||
ClassificationResult.deserializeBinary(binaryProto);
|
||||
this.classifications.push(
|
||||
...convertFromClassificationResultProto(classificationResult));
|
||||
this.classificationResults.push(
|
||||
convertFromClassificationResultProto(classificationResult));
|
||||
});
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
|
@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner {
|
|||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(AUDIO_STREAM);
|
||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
||||
graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
|
@ -198,13 +201,14 @@ export class AudioClassifier extends TaskRunner {
|
|||
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
|
||||
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
|
||||
classifierNode.addOutputStream(
|
||||
'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM);
|
||||
'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM);
|
||||
classifierNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(classifierNode);
|
||||
|
||||
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
||||
this.addJsAudioClassification(binaryProto);
|
||||
this.attachProtoVectorListener(
|
||||
TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
|
||||
this.addJsAudioClassificationResults(binaryProtos);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
|
|
|
@ -15,4 +15,4 @@
|
|||
*/
|
||||
|
||||
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
||||
export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||
|
|
|
@ -10,8 +10,8 @@ mediapipe_ts_library(
|
|||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "classifications",
|
||||
srcs = ["classifications.d.ts"],
|
||||
name = "classification_result",
|
||||
srcs = ["classification_result.d.ts"],
|
||||
deps = [":category"],
|
||||
)
|
||||
|
||||
|
|
|
@ -16,27 +16,14 @@
|
|||
|
||||
import {Category} from '../../../../tasks/web/components/containers/category';
|
||||
|
||||
/** List of predicted categories with an optional timestamp. */
|
||||
export interface ClassificationEntry {
|
||||
/** Classification results for a given classifier head. */
|
||||
export interface Classifications {
|
||||
/**
|
||||
* The array of predicted categories, usually sorted by descending scores,
|
||||
* e.g., from high to low probability.
|
||||
*/
|
||||
categories: Category[];
|
||||
|
||||
/**
|
||||
* The optional timestamp (in milliseconds) associated to the classification
|
||||
* entry. This is useful for time series use cases, e.g., audio
|
||||
* classification.
|
||||
*/
|
||||
timestampMs?: number;
|
||||
}
|
||||
|
||||
/** Classifications for a given classifier head. */
|
||||
export interface Classifications {
|
||||
/** A list of classification entries. */
|
||||
entries: ClassificationEntry[];
|
||||
|
||||
/**
|
||||
* The index of the classifier head these categories refer to. This is
|
||||
* useful for multi-head models.
|
||||
|
@ -45,7 +32,24 @@ export interface Classifications {
|
|||
|
||||
/**
|
||||
* The name of the classifier head, which is the corresponding tensor
|
||||
* metadata name.
|
||||
* metadata name. Defaults to an empty string if there is no such metadata.
|
||||
*/
|
||||
headName: string;
|
||||
}
|
||||
|
||||
/** Classification results of a model. */
|
||||
export interface ClassificationResult {
|
||||
/** The classification results for each head of the model. */
|
||||
classifications: Classifications[];
|
||||
|
||||
/**
|
||||
* The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
* corresponding to these results.
|
||||
*
|
||||
* This is only used for classification on time series (e.g. audio
|
||||
* classification). In these use cases, the amount of data to process might
|
||||
* exceed the maximum size that the model can process: to solve this, the
|
||||
* input data is split into multiple chunks starting at different timestamps.
|
||||
*/
|
||||
timestampMs?: number;
|
||||
}
|
|
@ -17,8 +17,9 @@ mediapipe_ts_library(
|
|||
name = "classifier_result",
|
||||
srcs = ["classifier_result.ts"],
|
||||
deps = [
|
||||
"//mediapipe/framework/formats:classification_jspb_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/containers:classifications",
|
||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -14,48 +14,46 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
||||
import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||
import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||
|
||||
const DEFAULT_INDEX = -1;
|
||||
const DEFAULT_SCORE = 0.0;
|
||||
|
||||
/**
|
||||
* Converts a ClassificationEntry proto to the ClassificationEntry result
|
||||
* type.
|
||||
* Converts a Classifications proto to a Classifications object.
|
||||
*/
|
||||
function convertFromClassificationEntryProto(source: ClassificationEntryProto):
|
||||
ClassificationEntry {
|
||||
const categories = source.getCategoriesList().map(category => {
|
||||
function convertFromClassificationsProto(source: ClassificationsProto):
|
||||
Classifications {
|
||||
const categories =
|
||||
source.getClassificationList()?.getClassificationList().map(
|
||||
classification => {
|
||||
return {
|
||||
index: category.getIndex() ?? DEFAULT_INDEX,
|
||||
score: category.getScore() ?? DEFAULT_SCORE,
|
||||
displayName: category.getDisplayName() ?? '',
|
||||
categoryName: category.getCategoryName() ?? '',
|
||||
index: classification.getIndex() ?? DEFAULT_INDEX,
|
||||
score: classification.getScore() ?? DEFAULT_SCORE,
|
||||
categoryName: classification.getLabel() ?? '',
|
||||
displayName: classification.getDisplayName() ?? '',
|
||||
};
|
||||
});
|
||||
|
||||
}) ??
|
||||
[];
|
||||
return {
|
||||
categories,
|
||||
timestampMs: source.getTimestampMs(),
|
||||
headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
|
||||
headName: source.getHeadName() ?? '',
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a ClassificationResult proto to a list of classifications.
|
||||
* Converts a ClassificationResult proto to a ClassificationResult object.
|
||||
*/
|
||||
export function convertFromClassificationResultProto(
|
||||
classificationResult: ClassificationResult) : Classifications[] {
|
||||
const result: Classifications[] = [];
|
||||
for (const classificationsProto of
|
||||
classificationResult.getClassificationsList()) {
|
||||
const classifications: Classifications = {
|
||||
entries: classificationsProto.getEntriesList().map(
|
||||
entry => convertFromClassificationEntryProto(entry)),
|
||||
headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
|
||||
headName: classificationsProto.getHeadName() ?? '',
|
||||
source: ClassificationResultProto): ClassificationResult {
|
||||
const result: ClassificationResult = {
|
||||
classifications: source.getClassificationsList().map(
|
||||
classififications => convertFromClassificationsProto(classififications))
|
||||
};
|
||||
result.push(classifications);
|
||||
if (source.hasTimestampMs()) {
|
||||
result.timestampMs = source.getTimestampMs();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ mediapipe_ts_library(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/containers:category",
|
||||
"//mediapipe/tasks/web/components/containers:classifications",
|
||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||
"//mediapipe/tasks/web/components/processors:base_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||
|
|
|
@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
|
|||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {TextClassifierOptions} from './text_classifier_options';
|
||||
import {Classifications} from './text_classifier_result';
|
||||
import {TextClassifierResult} from './text_classifier_result';
|
||||
|
||||
const INPUT_STREAM = 'text_in';
|
||||
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
|
||||
const CLASSIFICATIONS_STREAM = 'classifications_out';
|
||||
const TEXT_CLASSIFIER_GRAPH =
|
||||
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
|
||||
|
||||
|
@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH =
|
|||
|
||||
/** Performs Natural Language classification. */
|
||||
export class TextClassifier extends TaskRunner {
|
||||
private classifications: Classifications[] = [];
|
||||
private classificationResult: TextClassifierResult = {classifications: []};
|
||||
private readonly options = new TextClassifierGraphOptions();
|
||||
|
||||
/**
|
||||
|
@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner {
|
|||
* @param text The text to process.
|
||||
* @return The classification result of the text
|
||||
*/
|
||||
classify(text: string): Classifications[] {
|
||||
// Get classification classes by running our MediaPipe graph.
|
||||
this.classifications = [];
|
||||
classify(text: string): TextClassifierResult {
|
||||
// Get classification result by running our MediaPipe graph.
|
||||
this.classificationResult = {classifications: []};
|
||||
this.addStringToStream(
|
||||
text, INPUT_STREAM, /* timestamp= */ performance.now());
|
||||
this.finishProcessing();
|
||||
return [...this.classifications];
|
||||
}
|
||||
|
||||
// Internal function for converting raw data into a classification, and
|
||||
// adding it to our classifications list.
|
||||
private addJsTextClassification(binaryProto: Uint8Array): void {
|
||||
const classificationResult =
|
||||
ClassificationResult.deserializeBinary(binaryProto);
|
||||
console.log(classificationResult.toObject());
|
||||
this.classifications.push(
|
||||
...convertFromClassificationResultProto(classificationResult));
|
||||
return this.classificationResult;
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
|
@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner {
|
|||
const classifierNode = new CalculatorGraphConfig.Node();
|
||||
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
|
||||
classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
|
||||
classifierNode.addOutputStream(
|
||||
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
|
||||
classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
|
||||
classifierNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(classifierNode);
|
||||
|
||||
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
||||
this.addJsTextClassification(binaryProto);
|
||||
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
|
||||
this.classificationResult = convertFromClassificationResultProto(
|
||||
ClassificationResult.deserializeBinary(binaryProto));
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
|
|
|
@ -15,4 +15,4 @@
|
|||
*/
|
||||
|
||||
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
||||
export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||
|
|
|
@ -21,7 +21,7 @@ mediapipe_ts_library(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/components/containers:category",
|
||||
"//mediapipe/tasks/web/components/containers:classifications",
|
||||
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||
"//mediapipe/tasks/web/components/processors:base_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||
|
|
|
@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap
|
|||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {ImageClassifierOptions} from './image_classifier_options';
|
||||
import {Classifications} from './image_classifier_result';
|
||||
import {ImageClassifierResult} from './image_classifier_result';
|
||||
|
||||
const IMAGE_CLASSIFIER_GRAPH =
|
||||
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
|
||||
const INPUT_STREAM = 'input_image';
|
||||
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
|
||||
const CLASSIFICATIONS_STREAM = 'classifications';
|
||||
|
||||
export {ImageSource}; // Used in the public API
|
||||
|
||||
|
@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API
|
|||
|
||||
/** Performs classification on images. */
|
||||
export class ImageClassifier extends TaskRunner {
|
||||
private classifications: Classifications[] = [];
|
||||
private classificationResult: ImageClassifierResult = {classifications: []};
|
||||
private readonly options = new ImageClassifierGraphOptions();
|
||||
|
||||
/**
|
||||
|
@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner {
|
|||
* provided, defaults to `performance.now()`.
|
||||
* @return The classification result of the image
|
||||
*/
|
||||
classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
|
||||
// Get classification classes by running our MediaPipe graph.
|
||||
this.classifications = [];
|
||||
classify(imageSource: ImageSource, timestamp?: number):
|
||||
ImageClassifierResult {
|
||||
// Get classification result by running our MediaPipe graph.
|
||||
this.classificationResult = {classifications: []};
|
||||
this.addGpuBufferAsImageToStream(
|
||||
imageSource, INPUT_STREAM, timestamp ?? performance.now());
|
||||
this.finishProcessing();
|
||||
return [...this.classifications];
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal function for converting raw data into a classification, and
|
||||
* adding it to our classfications list.
|
||||
**/
|
||||
private addJsImageClassification(binaryProto: Uint8Array): void {
|
||||
const classificationResult =
|
||||
ClassificationResult.deserializeBinary(binaryProto);
|
||||
this.classifications.push(
|
||||
...convertFromClassificationResultProto(classificationResult));
|
||||
return this.classificationResult;
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
private refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(INPUT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
||||
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
|
@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner {
|
|||
const classifierNode = new CalculatorGraphConfig.Node();
|
||||
classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
|
||||
classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
|
||||
classifierNode.addOutputStream(
|
||||
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
|
||||
classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
|
||||
classifierNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(classifierNode);
|
||||
|
||||
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
||||
this.addJsImageClassification(binaryProto);
|
||||
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
|
||||
this.classificationResult = convertFromClassificationResultProto(
|
||||
ClassificationResult.deserializeBinary(binaryProto));
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
|
|
|
@ -15,4 +15,4 @@
|
|||
*/
|
||||
|
||||
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
||||
export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||
|
|
Loading…
Reference in New Issue
Block a user