Add convertFromClassifications() helper

PiperOrigin-RevId: 519181016
This commit is contained in:
Sebastian Schmidt 2023-03-24 10:36:38 -07:00 committed by Copybara-Service
parent cec878df2b
commit 9f6b2cd577
2 changed files with 26 additions and 16 deletions

View File

@ -34,6 +34,7 @@ mediapipe_ts_library(
name = "classifier_result", name = "classifier_result",
srcs = ["classifier_result.ts"], srcs = ["classifier_result.ts"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classification_result", "//mediapipe/tasks/web/components/containers:classification_result",
], ],

View File

@ -14,33 +14,42 @@
* limitations under the License. * limitations under the License.
*/ */
import {Classification as ClassificationProto} from '../../../../framework/formats/classification_pb';
import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result'; import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
const DEFAULT_INDEX = -1; const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0; const DEFAULT_SCORE = 0.0;
/**
* Converts a list of Classification protos to a Classifications object.
*/
export function convertFromClassifications(
classifications: ClassificationProto[], headIndex = DEFAULT_INDEX,
headName = ''): Classifications {
const categories = classifications.map(classification => {
return {
index: classification.getIndex() ?? DEFAULT_INDEX,
score: classification.getScore() ?? DEFAULT_SCORE,
categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '',
};
});
return {
categories,
headIndex,
headName,
};
}
/** /**
* Converts a Classifications proto to a Classifications object. * Converts a Classifications proto to a Classifications object.
*/ */
function convertFromClassificationsProto(source: ClassificationsProto): function convertFromClassificationsProto(source: ClassificationsProto):
Classifications { Classifications {
const categories = return convertFromClassifications(
source.getClassificationList()?.getClassificationList().map( source.getClassificationList()?.getClassificationList() ?? [],
classification => { source.getHeadIndex(), source.getHeadName());
return {
index: classification.getIndex() ?? DEFAULT_INDEX,
score: classification.getScore() ?? DEFAULT_SCORE,
categoryName: classification.getLabel() ?? '',
displayName: classification.getDisplayName() ?? '',
};
}) ??
[];
return {
categories,
headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
headName: source.getHeadName() ?? '',
};
} }
/** /**