Add C Headers for Text Classifier

PiperOrigin-RevId: 551618765
This commit is contained in:
Sebastian Schmidt 2023-07-27 13:07:40 -07:00 committed by Copybara-Service
parent 5b31f1e3e9
commit fdea10d230
9 changed files with 328 additions and 0 deletions

View File

@ -0,0 +1,29 @@
# TODO: describe this package.
# Copyright 2022 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "category",
hdrs = ["category.h"],
)
cc_library(
name = "classification_result",
hdrs = ["classification_result.h"],
)

View File

@ -0,0 +1,42 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_
// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
// The index of the category in the classification model output.
int index;
// The score for this category, e.g. (but not necessarily) a probability in
// [0,1].
float score;
// The optional ID for the category, read from the label map packed in the
// TFLite Model Metadata if present. Not necessarily human-readable.
char* category_name;
// The optional human-readable name for the category, read from the label map
// packed in the TFLite Model Metadata if present.
char* display_name;
};
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_

View File

@ -0,0 +1,60 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
#include <stdbool.h>
#include <stdint.h>
// Defines classification results for a given classifier head.
struct Classifications {
// The array of predicted categories, usually sorted by descending scores,
// e.g. from high to low probability.
struct Category* categories;
// The number of elements in the categories array.
uint32_t categories_count;
// The index of the classifier head (i.e. output tensor) these categories
// refer to. This is useful for multi-head models.
int head_index;
// The optional name of the classifier head, as provided in the TFLite Model
// Metadata [1] if present. This is useful for multi-head models.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
char* head_name;
};
// Defines classification results of a model.
struct ClassificationResult {
// The classification results for each head of the model.
struct Classifications* classifications;
// The number of classifications in the classifications array.
uint32_t classifications_count;
// The optional timestamp (in milliseconds) of the start of the chunk of data
// corresponding to these results.
//
// This is only used for classification on time series (e.g. audio
// classification). In these use cases, the amount of data to process might
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
int64_t timestamp_ms;
// Specifies whether the timestamp contains a valid value.
bool has_timestamp_ms;
};
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_

View File

@ -0,0 +1,22 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "classifier_options",
hdrs = ["classifier_options.h"],
)

View File

@ -0,0 +1,51 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#define MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#include <stdint.h>
// Classifier options for MediaPipe C classification Tasks.
struct ClassifierOptions {
// The locale to use for display names specified through the TFLite Model
// Metadata, if any. Defaults to English.
char* display_names_locale;
// The maximum number of top-scored classification results to return. If < 0,
// all available results will be returned. If 0, an invalid argument error is
// returned.
int max_results;
// Score threshold to override the one provided in the model metadata (if
// any). Results below this value are rejected.
float score_threshold;
// The allowlist of category names. If non-empty, detection results whose
// category name is not in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_denylist.
char** category_allowlist;
// The number of elements in the category allowlist.
uint32_t category_allowlist_count;
// The denylist of category names. If non-empty, detection results whose
// category name is in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_allowlist.
char** category_denylist = {};
// The number of elements in the category denylist.
uint32_t category_denylist_count;
};
#endif // MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_

View File

@ -0,0 +1,22 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "base_options",
hdrs = ["base_options.h"],
)

View File

@ -0,0 +1,28 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
#define MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_
// Base options for MediaPipe C Tasks.
struct BaseOptions {
// The model asset file contents as a string.
char* model_asset_buffer;
// The path to the model asset to open and mmap in memory.
char* model_asset_path;
};
#endif // MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_

View File

@ -0,0 +1,28 @@
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "text_classifier",
hdrs = ["text_classifier.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/c/components/containers:classification_result",
"//mediapipe/tasks/c/components/processors:classifier_options",
"//mediapipe/tasks/c/core:base_options",
],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#include "mediapipe/tasks/c/components/containers/classification_result.h"
#include "mediapipe/tasks/c/components/processors/classifier_options.h"
#include "mediapipe/tasks/c/core/base_options.h"
typedef ClassificationResult TextClassifierResult;
// The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
struct BaseOptions base_options;
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
struct ClassifierOptions classifier_options;
};
// Creates a TextClassifier from the provided `options`.
void* text_classsifier_create(struct TextClassifierOptions options);
// Performs classification on the input `text`.
TextClassifierResult text_classifier_classify(void* classifier,
char* utf8_text);
// Shuts down the TextClassifier when all the work is done. Frees all memory.
void text_classsifier_close(void* classifier);
#endif // MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_