From 9f6b2cd577c7876d8772251ba7a14bfd70ed772d Mon Sep 17 00:00:00 2001 From: Sebastian Schmidt Date: Fri, 24 Mar 2023 10:36:38 -0700 Subject: [PATCH] Add convertFromClassifications() helper PiperOrigin-RevId: 519181016 --- .../tasks/web/components/processors/BUILD | 1 + .../processors/classifier_result.ts | 41 +++++++++++-------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index cab24293d..b83f73eb2 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -34,6 +34,7 @@ mediapipe_ts_library( name = "classifier_result", srcs = ["classifier_result.ts"], deps = [ + "//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/web/components/containers:classification_result", ], diff --git a/mediapipe/tasks/web/components/processors/classifier_result.ts b/mediapipe/tasks/web/components/processors/classifier_result.ts index 90d10b84d..ae58252e8 100644 --- a/mediapipe/tasks/web/components/processors/classifier_result.ts +++ b/mediapipe/tasks/web/components/processors/classifier_result.ts @@ -14,33 +14,42 @@ * limitations under the License. */ +import {Classification as ClassificationProto} from '../../../../framework/formats/classification_pb'; import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result'; const DEFAULT_INDEX = -1; const DEFAULT_SCORE = 0.0; +/** + * Converts a list of Classification protos to a Classifications object. + */ +export function convertFromClassifications( + classifications: ClassificationProto[], headIndex = DEFAULT_INDEX, + headName = ''): Classifications { + const categories = classifications.map(classification => { + return { + index: classification.getIndex() ?? DEFAULT_INDEX, + score: classification.getScore() ?? DEFAULT_SCORE, + categoryName: classification.getLabel() ?? '', + displayName: classification.getDisplayName() ?? '', + }; + }); + return { + categories, + headIndex, + headName, + }; +} + /** * Converts a Classifications proto to a Classifications object. */ function convertFromClassificationsProto(source: ClassificationsProto): Classifications { - const categories = - source.getClassificationList()?.getClassificationList().map( - classification => { - return { - index: classification.getIndex() ?? DEFAULT_INDEX, - score: classification.getScore() ?? DEFAULT_SCORE, - categoryName: classification.getLabel() ?? '', - displayName: classification.getDisplayName() ?? '', - }; - }) ?? - []; - return { - categories, - headIndex: source.getHeadIndex() ?? DEFAULT_INDEX, - headName: source.getHeadName() ?? '', - }; + return convertFromClassifications( + source.getClassificationList()?.getClassificationList() ?? [], + source.getHeadIndex(), source.getHeadName()); } /**