Internal change
PiperOrigin-RevId: 487466061
This commit is contained in:
parent
0ac604d507
commit
0b12aa9435
|
@ -21,7 +21,7 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/containers:category",
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
"//mediapipe/tasks/web/components/containers:classifications",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
"//mediapipe/tasks/web/components/processors:base_options",
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
|
|
|
@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {AudioClassifierOptions} from './audio_classifier_options';
|
import {AudioClassifierOptions} from './audio_classifier_options';
|
||||||
import {Classifications} from './audio_classifier_result';
|
import {AudioClassifierResult} from './audio_classifier_result';
|
||||||
|
|
||||||
const MEDIAPIPE_GRAPH =
|
const MEDIAPIPE_GRAPH =
|
||||||
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
|
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
|
||||||
|
@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH =
|
||||||
// implementation
|
// implementation
|
||||||
const AUDIO_STREAM = 'input_audio';
|
const AUDIO_STREAM = 'input_audio';
|
||||||
const SAMPLE_RATE_STREAM = 'sample_rate';
|
const SAMPLE_RATE_STREAM = 'sample_rate';
|
||||||
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
|
const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
|
||||||
|
|
||||||
// The OSS JS API does not support the builder pattern.
|
// The OSS JS API does not support the builder pattern.
|
||||||
// tslint:disable:jspb-use-builder-pattern
|
// tslint:disable:jspb-use-builder-pattern
|
||||||
|
|
||||||
/** Performs audio classification. */
|
/** Performs audio classification. */
|
||||||
export class AudioClassifier extends TaskRunner {
|
export class AudioClassifier extends TaskRunner {
|
||||||
private classifications: Classifications[] = [];
|
private classificationResults: AudioClassifierResult[] = [];
|
||||||
private defaultSampleRate = 48000;
|
private defaultSampleRate = 48000;
|
||||||
private readonly options = new AudioClassifierGraphOptions();
|
private readonly options = new AudioClassifierGraphOptions();
|
||||||
|
|
||||||
|
@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner {
|
||||||
* `48000` if no custom default was set.
|
* `48000` if no custom default was set.
|
||||||
* @return The classification result of the audio datas
|
* @return The classification result of the audio datas
|
||||||
*/
|
*/
|
||||||
classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
|
classify(audioData: Float32Array, sampleRate?: number):
|
||||||
|
AudioClassifierResult[] {
|
||||||
sampleRate = sampleRate ?? this.defaultSampleRate;
|
sampleRate = sampleRate ?? this.defaultSampleRate;
|
||||||
|
|
||||||
// Configures the number of samples in the WASM layer. We re-configure the
|
// Configures the number of samples in the WASM layer. We re-configure the
|
||||||
|
@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner {
|
||||||
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
|
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
|
||||||
this.addAudioToStream(audioData, timestamp);
|
this.addAudioToStream(audioData, timestamp);
|
||||||
|
|
||||||
this.classifications = [];
|
this.classificationResults = [];
|
||||||
this.finishProcessing();
|
this.finishProcessing();
|
||||||
return [...this.classifications];
|
return [...this.classificationResults];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Internal function for converting raw data into a classification, and
|
* Internal function for converting raw data into classification results, and
|
||||||
* adding it to our classfications list.
|
* adding them to our classfication results list.
|
||||||
**/
|
**/
|
||||||
private addJsAudioClassification(binaryProto: Uint8Array): void {
|
private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void {
|
||||||
const classificationResult =
|
binaryProtos.forEach(binaryProto => {
|
||||||
ClassificationResult.deserializeBinary(binaryProto);
|
const classificationResult =
|
||||||
this.classifications.push(
|
ClassificationResult.deserializeBinary(binaryProto);
|
||||||
...convertFromClassificationResultProto(classificationResult));
|
this.classificationResults.push(
|
||||||
|
convertFromClassificationResultProto(classificationResult));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
|
@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(AUDIO_STREAM);
|
graphConfig.addInputStream(AUDIO_STREAM);
|
||||||
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
graphConfig.addInputStream(SAMPLE_RATE_STREAM);
|
||||||
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM);
|
||||||
|
|
||||||
const calculatorOptions = new CalculatorOptions();
|
const calculatorOptions = new CalculatorOptions();
|
||||||
calculatorOptions.setExtension(
|
calculatorOptions.setExtension(
|
||||||
|
@ -198,14 +201,15 @@ export class AudioClassifier extends TaskRunner {
|
||||||
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
|
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
|
||||||
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
|
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
|
||||||
classifierNode.addOutputStream(
|
classifierNode.addOutputStream(
|
||||||
'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM);
|
'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM);
|
||||||
classifierNode.setOptions(calculatorOptions);
|
classifierNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
graphConfig.addNode(classifierNode);
|
graphConfig.addNode(classifierNode);
|
||||||
|
|
||||||
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
this.attachProtoVectorListener(
|
||||||
this.addJsAudioClassification(binaryProto);
|
TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
|
||||||
});
|
this.addJsAudioClassificationResults(binaryProtos);
|
||||||
|
});
|
||||||
|
|
||||||
const binaryGraph = graphConfig.serializeBinary();
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||||
|
|
|
@ -15,4 +15,4 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export {Category} from '../../../../tasks/web/components/containers/category';
|
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||||
|
|
|
@ -10,8 +10,8 @@ mediapipe_ts_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "classifications",
|
name = "classification_result",
|
||||||
srcs = ["classifications.d.ts"],
|
srcs = ["classification_result.d.ts"],
|
||||||
deps = [":category"],
|
deps = [":category"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,27 +16,14 @@
|
||||||
|
|
||||||
import {Category} from '../../../../tasks/web/components/containers/category';
|
import {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
|
|
||||||
/** List of predicted categories with an optional timestamp. */
|
/** Classification results for a given classifier head. */
|
||||||
export interface ClassificationEntry {
|
export interface Classifications {
|
||||||
/**
|
/**
|
||||||
* The array of predicted categories, usually sorted by descending scores,
|
* The array of predicted categories, usually sorted by descending scores,
|
||||||
* e.g., from high to low probability.
|
* e.g., from high to low probability.
|
||||||
*/
|
*/
|
||||||
categories: Category[];
|
categories: Category[];
|
||||||
|
|
||||||
/**
|
|
||||||
* The optional timestamp (in milliseconds) associated to the classification
|
|
||||||
* entry. This is useful for time series use cases, e.g., audio
|
|
||||||
* classification.
|
|
||||||
*/
|
|
||||||
timestampMs?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Classifications for a given classifier head. */
|
|
||||||
export interface Classifications {
|
|
||||||
/** A list of classification entries. */
|
|
||||||
entries: ClassificationEntry[];
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The index of the classifier head these categories refer to. This is
|
* The index of the classifier head these categories refer to. This is
|
||||||
* useful for multi-head models.
|
* useful for multi-head models.
|
||||||
|
@ -45,7 +32,24 @@ export interface Classifications {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The name of the classifier head, which is the corresponding tensor
|
* The name of the classifier head, which is the corresponding tensor
|
||||||
* metadata name.
|
* metadata name. Defaults to an empty string if there is no such metadata.
|
||||||
*/
|
*/
|
||||||
headName: string;
|
headName: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Classification results of a model. */
|
||||||
|
export interface ClassificationResult {
|
||||||
|
/** The classification results for each head of the model. */
|
||||||
|
classifications: Classifications[];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||||
|
* corresponding to these results.
|
||||||
|
*
|
||||||
|
* This is only used for classification on time series (e.g. audio
|
||||||
|
* classification). In these use cases, the amount of data to process might
|
||||||
|
* exceed the maximum size that the model can process: to solve this, the
|
||||||
|
* input data is split into multiple chunks starting at different timestamps.
|
||||||
|
*/
|
||||||
|
timestampMs?: number;
|
||||||
|
}
|
|
@ -17,8 +17,9 @@ mediapipe_ts_library(
|
||||||
name = "classifier_result",
|
name = "classifier_result",
|
||||||
srcs = ["classifier_result.ts"],
|
srcs = ["classifier_result.ts"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//mediapipe/framework/formats:classification_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/containers:classifications",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -14,48 +14,46 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
|
||||||
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||||
|
|
||||||
const DEFAULT_INDEX = -1;
|
const DEFAULT_INDEX = -1;
|
||||||
const DEFAULT_SCORE = 0.0;
|
const DEFAULT_SCORE = 0.0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Converts a ClassificationEntry proto to the ClassificationEntry result
|
* Converts a Classifications proto to a Classifications object.
|
||||||
* type.
|
|
||||||
*/
|
*/
|
||||||
function convertFromClassificationEntryProto(source: ClassificationEntryProto):
|
function convertFromClassificationsProto(source: ClassificationsProto):
|
||||||
ClassificationEntry {
|
Classifications {
|
||||||
const categories = source.getCategoriesList().map(category => {
|
const categories =
|
||||||
return {
|
source.getClassificationList()?.getClassificationList().map(
|
||||||
index: category.getIndex() ?? DEFAULT_INDEX,
|
classification => {
|
||||||
score: category.getScore() ?? DEFAULT_SCORE,
|
return {
|
||||||
displayName: category.getDisplayName() ?? '',
|
index: classification.getIndex() ?? DEFAULT_INDEX,
|
||||||
categoryName: category.getCategoryName() ?? '',
|
score: classification.getScore() ?? DEFAULT_SCORE,
|
||||||
};
|
categoryName: classification.getLabel() ?? '',
|
||||||
});
|
displayName: classification.getDisplayName() ?? '',
|
||||||
|
};
|
||||||
|
}) ??
|
||||||
|
[];
|
||||||
return {
|
return {
|
||||||
categories,
|
categories,
|
||||||
timestampMs: source.getTimestampMs(),
|
headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
|
||||||
|
headName: source.getHeadName() ?? '',
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Converts a ClassificationResult proto to a list of classifications.
|
* Converts a ClassificationResult proto to a ClassificationResult object.
|
||||||
*/
|
*/
|
||||||
export function convertFromClassificationResultProto(
|
export function convertFromClassificationResultProto(
|
||||||
classificationResult: ClassificationResult) : Classifications[] {
|
source: ClassificationResultProto): ClassificationResult {
|
||||||
const result: Classifications[] = [];
|
const result: ClassificationResult = {
|
||||||
for (const classificationsProto of
|
classifications: source.getClassificationsList().map(
|
||||||
classificationResult.getClassificationsList()) {
|
classififications => convertFromClassificationsProto(classififications))
|
||||||
const classifications: Classifications = {
|
};
|
||||||
entries: classificationsProto.getEntriesList().map(
|
if (source.hasTimestampMs()) {
|
||||||
entry => convertFromClassificationEntryProto(entry)),
|
result.timestampMs = source.getTimestampMs();
|
||||||
headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
|
|
||||||
headName: classificationsProto.getHeadName() ?? '',
|
|
||||||
};
|
|
||||||
result.push(classifications);
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
|
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/containers:category",
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
"//mediapipe/tasks/web/components/containers:classifications",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
"//mediapipe/tasks/web/components/processors:base_options",
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
|
|
|
@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {TextClassifierOptions} from './text_classifier_options';
|
import {TextClassifierOptions} from './text_classifier_options';
|
||||||
import {Classifications} from './text_classifier_result';
|
import {TextClassifierResult} from './text_classifier_result';
|
||||||
|
|
||||||
const INPUT_STREAM = 'text_in';
|
const INPUT_STREAM = 'text_in';
|
||||||
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
|
const CLASSIFICATIONS_STREAM = 'classifications_out';
|
||||||
const TEXT_CLASSIFIER_GRAPH =
|
const TEXT_CLASSIFIER_GRAPH =
|
||||||
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
|
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH =
|
||||||
|
|
||||||
/** Performs Natural Language classification. */
|
/** Performs Natural Language classification. */
|
||||||
export class TextClassifier extends TaskRunner {
|
export class TextClassifier extends TaskRunner {
|
||||||
private classifications: Classifications[] = [];
|
private classificationResult: TextClassifierResult = {classifications: []};
|
||||||
private readonly options = new TextClassifierGraphOptions();
|
private readonly options = new TextClassifierGraphOptions();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner {
|
||||||
* @param text The text to process.
|
* @param text The text to process.
|
||||||
* @return The classification result of the text
|
* @return The classification result of the text
|
||||||
*/
|
*/
|
||||||
classify(text: string): Classifications[] {
|
classify(text: string): TextClassifierResult {
|
||||||
// Get classification classes by running our MediaPipe graph.
|
// Get classification result by running our MediaPipe graph.
|
||||||
this.classifications = [];
|
this.classificationResult = {classifications: []};
|
||||||
this.addStringToStream(
|
this.addStringToStream(
|
||||||
text, INPUT_STREAM, /* timestamp= */ performance.now());
|
text, INPUT_STREAM, /* timestamp= */ performance.now());
|
||||||
this.finishProcessing();
|
this.finishProcessing();
|
||||||
return [...this.classifications];
|
return this.classificationResult;
|
||||||
}
|
|
||||||
|
|
||||||
// Internal function for converting raw data into a classification, and
|
|
||||||
// adding it to our classifications list.
|
|
||||||
private addJsTextClassification(binaryProto: Uint8Array): void {
|
|
||||||
const classificationResult =
|
|
||||||
ClassificationResult.deserializeBinary(binaryProto);
|
|
||||||
console.log(classificationResult.toObject());
|
|
||||||
this.classifications.push(
|
|
||||||
...convertFromClassificationResultProto(classificationResult));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
private refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(INPUT_STREAM);
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||||
|
|
||||||
const calculatorOptions = new CalculatorOptions();
|
const calculatorOptions = new CalculatorOptions();
|
||||||
calculatorOptions.setExtension(
|
calculatorOptions.setExtension(
|
||||||
|
@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner {
|
||||||
const classifierNode = new CalculatorGraphConfig.Node();
|
const classifierNode = new CalculatorGraphConfig.Node();
|
||||||
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
|
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
|
||||||
classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
|
classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
|
||||||
classifierNode.addOutputStream(
|
classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
|
||||||
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
|
|
||||||
classifierNode.setOptions(calculatorOptions);
|
classifierNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
graphConfig.addNode(classifierNode);
|
graphConfig.addNode(classifierNode);
|
||||||
|
|
||||||
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
|
||||||
this.addJsTextClassification(binaryProto);
|
this.classificationResult = convertFromClassificationResultProto(
|
||||||
|
ClassificationResult.deserializeBinary(binaryProto));
|
||||||
});
|
});
|
||||||
|
|
||||||
const binaryGraph = graphConfig.serializeBinary();
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
|
|
|
@ -15,4 +15,4 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export {Category} from '../../../../tasks/web/components/containers/category';
|
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||||
|
|
|
@ -21,7 +21,7 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/components/containers:category",
|
"//mediapipe/tasks/web/components/containers:category",
|
||||||
"//mediapipe/tasks/web/components/containers:classifications",
|
"//mediapipe/tasks/web/components/containers:classification_result",
|
||||||
"//mediapipe/tasks/web/components/processors:base_options",
|
"//mediapipe/tasks/web/components/processors:base_options",
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_options",
|
"//mediapipe/tasks/web/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/web/components/processors:classifier_result",
|
"//mediapipe/tasks/web/components/processors:classifier_result",
|
||||||
|
|
|
@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
|
||||||
import {ImageClassifierOptions} from './image_classifier_options';
|
import {ImageClassifierOptions} from './image_classifier_options';
|
||||||
import {Classifications} from './image_classifier_result';
|
import {ImageClassifierResult} from './image_classifier_result';
|
||||||
|
|
||||||
const IMAGE_CLASSIFIER_GRAPH =
|
const IMAGE_CLASSIFIER_GRAPH =
|
||||||
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
|
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
|
||||||
const INPUT_STREAM = 'input_image';
|
const INPUT_STREAM = 'input_image';
|
||||||
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
|
const CLASSIFICATIONS_STREAM = 'classifications';
|
||||||
|
|
||||||
export {ImageSource}; // Used in the public API
|
export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
/** Performs classification on images. */
|
/** Performs classification on images. */
|
||||||
export class ImageClassifier extends TaskRunner {
|
export class ImageClassifier extends TaskRunner {
|
||||||
private classifications: Classifications[] = [];
|
private classificationResult: ImageClassifierResult = {classifications: []};
|
||||||
private readonly options = new ImageClassifierGraphOptions();
|
private readonly options = new ImageClassifierGraphOptions();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner {
|
||||||
* provided, defaults to `performance.now()`.
|
* provided, defaults to `performance.now()`.
|
||||||
* @return The classification result of the image
|
* @return The classification result of the image
|
||||||
*/
|
*/
|
||||||
classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
|
classify(imageSource: ImageSource, timestamp?: number):
|
||||||
// Get classification classes by running our MediaPipe graph.
|
ImageClassifierResult {
|
||||||
this.classifications = [];
|
// Get classification result by running our MediaPipe graph.
|
||||||
|
this.classificationResult = {classifications: []};
|
||||||
this.addGpuBufferAsImageToStream(
|
this.addGpuBufferAsImageToStream(
|
||||||
imageSource, INPUT_STREAM, timestamp ?? performance.now());
|
imageSource, INPUT_STREAM, timestamp ?? performance.now());
|
||||||
this.finishProcessing();
|
this.finishProcessing();
|
||||||
return [...this.classifications];
|
return this.classificationResult;
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Internal function for converting raw data into a classification, and
|
|
||||||
* adding it to our classfications list.
|
|
||||||
**/
|
|
||||||
private addJsImageClassification(binaryProto: Uint8Array): void {
|
|
||||||
const classificationResult =
|
|
||||||
ClassificationResult.deserializeBinary(binaryProto);
|
|
||||||
this.classifications.push(
|
|
||||||
...convertFromClassificationResultProto(classificationResult));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Updates the MediaPipe graph configuration. */
|
/** Updates the MediaPipe graph configuration. */
|
||||||
private refreshGraph(): void {
|
private refreshGraph(): void {
|
||||||
const graphConfig = new CalculatorGraphConfig();
|
const graphConfig = new CalculatorGraphConfig();
|
||||||
graphConfig.addInputStream(INPUT_STREAM);
|
graphConfig.addInputStream(INPUT_STREAM);
|
||||||
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
|
graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
|
||||||
|
|
||||||
const calculatorOptions = new CalculatorOptions();
|
const calculatorOptions = new CalculatorOptions();
|
||||||
calculatorOptions.setExtension(
|
calculatorOptions.setExtension(
|
||||||
|
@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner {
|
||||||
const classifierNode = new CalculatorGraphConfig.Node();
|
const classifierNode = new CalculatorGraphConfig.Node();
|
||||||
classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
|
classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
|
||||||
classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
|
classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
|
||||||
classifierNode.addOutputStream(
|
classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
|
||||||
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
|
|
||||||
classifierNode.setOptions(calculatorOptions);
|
classifierNode.setOptions(calculatorOptions);
|
||||||
|
|
||||||
graphConfig.addNode(classifierNode);
|
graphConfig.addNode(classifierNode);
|
||||||
|
|
||||||
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
|
this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
|
||||||
this.addJsImageClassification(binaryProto);
|
this.classificationResult = convertFromClassificationResultProto(
|
||||||
|
ClassificationResult.deserializeBinary(binaryProto));
|
||||||
});
|
});
|
||||||
|
|
||||||
const binaryGraph = graphConfig.serializeBinary();
|
const binaryGraph = graphConfig.serializeBinary();
|
||||||
|
|
|
@ -15,4 +15,4 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export {Category} from '../../../../tasks/web/components/containers/category';
|
export {Category} from '../../../../tasks/web/components/containers/category';
|
||||||
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
|
export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
|
||||||
|
|
Loading…
Reference in New Issue
Block a user