Add convertFromClassifications() helper

PiperOrigin-RevId: 519181016
This commit is contained in:
Sebastian Schmidt 2023-03-24 10:36:38 -07:00 committed by Copybara-Service
parent cec878df2b
commit 9f6b2cd577
2 changed files with 26 additions and 16 deletions

View File

@ -34,6 +34,7 @@ mediapipe_ts_library(
name = "classifier_result", name = "classifier_result",
srcs = ["classifier_result.ts"], srcs = ["classifier_result.ts"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/containers:classification_result",
], ],

View File

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
import {Classification as ClassificationProto} from '../../../../framework/formats/classification_pb';
import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result'; import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
@ -21,28 +22,36 @@ const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0; const DEFAULT_SCORE = 0.0;
/** /**
* Converts a Classifications proto to a Classifications object. * Converts a list of Classification protos to a Classifications object.
*/ */
function convertFromClassificationsProto(source: ClassificationsProto): export function convertFromClassifications(
Classifications { classifications: ClassificationProto[], headIndex = DEFAULT_INDEX,
const categories = headName = ''): Classifications {
source.getClassificationList()?.getClassificationList().map( const categories = classifications.map(classification => {
classification => {
return { return {
index: classification.getIndex() ?? DEFAULT_INDEX, index: classification.getIndex() ?? DEFAULT_INDEX,
score: classification.getScore() ?? DEFAULT_SCORE, score: classification.getScore() ?? DEFAULT_SCORE,
categoryName: classification.getLabel() ?? '', categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '', displayName: classification.getDisplayName() ?? '',
}; };
}) ?? });
[];
return { return {
categories, categories,
headIndex: source.getHeadIndex() ?? DEFAULT_INDEX, headIndex,
headName: source.getHeadName() ?? '', headName,
}; };
} }
/**
* Converts a Classifications proto to a Classifications object.
*/
function convertFromClassificationsProto(source: ClassificationsProto):
Classifications {
return convertFromClassifications(
source.getClassificationList()?.getClassificationList() ?? [],
source.getHeadIndex(), source.getHeadName());
}
/** /**
* Converts a ClassificationResult proto to a ClassificationResult object. * Converts a ClassificationResult proto to a ClassificationResult object.
*/ */