Add convertFromClassifications() helper

PiperOrigin-RevId: 519181016
This commit is contained in:
Sebastian Schmidt 2023-03-24 10:36:38 -07:00 committed by Copybara-Service
parent cec878df2b
commit 9f6b2cd577
2 changed files with 26 additions and 16 deletions

View File

@ -34,6 +34,7 @@ mediapipe_ts_library(
name = "classifier_result",
srcs = ["classifier_result.ts"],
deps = [
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classification_result",
],

View File

@ -14,33 +14,42 @@
* limitations under the License.
*/
import {Classification as ClassificationProto} from '../../../../framework/formats/classification_pb';
import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0;
/**
* Converts a list of Classification protos to a Classifications object.
*/
export function convertFromClassifications(
classifications: ClassificationProto[], headIndex = DEFAULT_INDEX,
headName = ''): Classifications {
const categories = classifications.map(classification => {
return {
index: classification.getIndex() ?? DEFAULT_INDEX,
score: classification.getScore() ?? DEFAULT_SCORE,
categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '',
};
});
return {
categories,
headIndex,
headName,
};
}
/**
* Converts a Classifications proto to a Classifications object.
*/
function convertFromClassificationsProto(source: ClassificationsProto):
Classifications {
const categories =
source.getClassificationList()?.getClassificationList().map(
classification => {
return {
index: classification.getIndex() ?? DEFAULT_INDEX,
score: classification.getScore() ?? DEFAULT_SCORE,
categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '',
};
}) ??
[];
return {
categories,
headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
headName: source.getHeadName() ?? '',
};
return convertFromClassifications(
source.getClassificationList()?.getClassificationList() ?? [],
source.getHeadIndex(), source.getHeadName());
}
/**