diff --git a/mediapipe/dart/dart_builder/.gitignore b/mediapipe/dart/dart_builder/.gitignore index 3a8579040..f5ac5b0ae 100644 --- a/mediapipe/dart/dart_builder/.gitignore +++ b/mediapipe/dart/dart_builder/.gitignore @@ -1,3 +1,4 @@ # https://dart.dev/guides/libraries/private-files # Created by `dart pub` .dart_tool/ +cc/* diff --git a/mediapipe/dart/dart_builder/bin/cpp_test.dart b/mediapipe/dart/dart_builder/bin/cpp_test.dart new file mode 100644 index 000000000..52ac2a718 --- /dev/null +++ b/mediapipe/dart/dart_builder/bin/cpp_test.dart @@ -0,0 +1,36 @@ +// Copyright (c) 2019, the Dart project authors. Please see the AUTHORS file +// for details. All rights reserved. Use of this source code is governed by a +// BSD-style license that can be found in the LICENSE file. + +import 'dart:ffi' as ffi; +import 'dart:io' show Platform, Directory; + +import 'package:path/path.dart' as path; + +// FFI signature of the hello_world C function +typedef IncrementFunc = ffi.Int32 Function(ffi.Int32); +typedef Increment = int Function(int); + +void main() { + // Open the dynamic library + var libraryPath = + path.join(Directory.current.absolute.path, 'cpp', 'main.dylib'); + + // if (Platform.isMacOS) { + // libraryPath = + // path.join(Directory.current.path, 'hello_library', 'libhello.dylib'); + // } + + // if (Platform.isWindows) { + // libraryPath = path.join( + // Directory.current.path, 'hello_library', 'Debug', 'hello.dll'); + // } + + final dylib = ffi.DynamicLibrary.open(libraryPath); + + // Look up the C function 'hello_world' + final Increment increment = + dylib.lookup>('increment').asFunction(); + // Call the function + print(increment(99)); +} diff --git a/mediapipe/dart/dart_builder/lib/dart_builder.dart b/mediapipe/dart/dart_builder/lib/dart_builder.dart index bd6f219da..328b2c832 100644 --- a/mediapipe/dart/dart_builder/lib/dart_builder.dart +++ b/mediapipe/dart/dart_builder/lib/dart_builder.dart @@ -37,13 +37,6 @@ class DartProtoBuilder { repositoryRoot.absolute.path, 'mediapipe', 'dart', 'dart_builder'), ); - _buildDirectory = io.Directory( - path.join( - io.Directory.current.parent.parent.parent.parent.absolute.path, - 'build', - ), - ); - _outputDirectory = options.outputPath != null ? io.Directory(options.outputPath!) : io.Directory( @@ -67,12 +60,11 @@ class DartProtoBuilder { io.Directory get mediapipeDir => _mediapipeDir!; io.Directory? _mediapipeDir; - io.Directory get buildDirectory => _buildDirectory!; - io.Directory? _buildDirectory; - + /// Directory to place compiled protobufs. io.Directory get outputDirectory => _outputDirectory!; io.Directory? _outputDirectory; + /// Location of this command. io.Directory get dartBuilderDirectory => _dartBuilderDirectory!; io.Directory? _dartBuilderDirectory; @@ -100,7 +92,7 @@ class DartProtoBuilder { [ansi.green], ), ); - // await _buildProtos(); + await _buildProtos(); await _buildBarrelFiles(); } @@ -228,16 +220,13 @@ class DartProtoBuilder { ).absolute.path; Future _confirmOutputDirectories() async { - if (!await _buildDirectory!.exists()) { - _buildDirectory!.create(); - } if (!await _outputDirectory!.exists()) { _outputDirectory!.create(); } } Future _prepareProtos() async { - if (!(await outputDirectory.exists())) { + if (!await outputDirectory.exists()) { io.stdout.writeln('Creating output directory'); outputDirectory.create(); } diff --git a/mediapipe/dart/mediapipe/.gitignore b/mediapipe/dart/mediapipe/.gitignore index 3cceda557..871a15102 100644 --- a/mediapipe/dart/mediapipe/.gitignore +++ b/mediapipe/dart/mediapipe/.gitignore @@ -1,6 +1,8 @@ # https://dart.dev/guides/libraries/private-files # Created by `dart pub` .dart_tool/ +lib/generated/google/* +lib/generated/mediapipe/* # Avoid committing pubspec.lock for library packages; see # https://dart.dev/guides/libraries/private-files#pubspeclock. diff --git a/mediapipe/dart/mediapipe/Makefile b/mediapipe/dart/mediapipe/Makefile new file mode 100644 index 000000000..02bc1af0f --- /dev/null +++ b/mediapipe/dart/mediapipe/Makefile @@ -0,0 +1,10 @@ +ffigen: + dart run ffigen --config ffigen.yaml + +compile: + gcc c/text_classifier.c -o c/text_classifier + cd c && gcc -static -c -fPIC *.c -o text_classifier.o + cd c && gcc -shared -o text_classifier.dylib text_classifier.o + +run: + cd c && dart text_classifier_c.dart \ No newline at end of file diff --git a/mediapipe/dart/mediapipe/ffigen.yaml b/mediapipe/dart/mediapipe/ffigen.yaml new file mode 100644 index 000000000..96f1a21cf --- /dev/null +++ b/mediapipe/dart/mediapipe/ffigen.yaml @@ -0,0 +1,8 @@ +name: flutter_mediapipe +description: MediaPipe bindings. + +output: "lib/src/third_party/generated/mediapipe_bindings.dart" +headers: + entry-points: + - "third_party/mediapipe/classification_result.h" + - "third_party/mediapipe/text_classifier.h" diff --git a/mediapipe/dart/mediapipe/lib/src/tasks/core/task_runner.dart b/mediapipe/dart/mediapipe/lib/src/tasks/core/task_runner.dart index 4809a63f7..ca890bbd5 100644 --- a/mediapipe/dart/mediapipe/lib/src/tasks/core/task_runner.dart +++ b/mediapipe/dart/mediapipe/lib/src/tasks/core/task_runner.dart @@ -1,15 +1,35 @@ +import 'dart:ffi' as ffi; +// TODO: This will require a web-specific solution. +import 'dart:io'; +import 'package:path/path.dart' as path; + import '../../../generated/mediapipe/framework/calculator.pb.dart'; +// TODO: Figure out ffi type for Maps +typedef ProcessCC = Map Function(Map data); +typedef Process = Map Function(Map data); + // TODO: Wrap C++ TaskRunner with this, similarly to this Python wrapper: // https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/python/framework_bindings.cc?q=python%20framework_bindings.cc class TaskRunner { - TaskRunner(this.graphConfig); + TaskRunner(this.graphConfig) { + var libraryPath = + path.join(Directory.current.absolute.path, 'cc', 'main.dylib'); + mediaPipe = ffi.DynamicLibrary.open(libraryPath); + } final CalculatorGraphConfig graphConfig; + late ffi.DynamicLibrary mediaPipe; + // TODO: Actually decode this line for correct parameter type: // https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/tasks/python/text/text_classifier.py;l=181 - Map process(Map data) => {}; + Map process(Map data) { + throw UnimplementedError(); + // final Process ccProcess = + // mediaPipe.lookup>('process').asFunction(); + // return ccProcess(data); + } } // TODO: Wrap C++ Packet with this, similarly to this Python wrapper: diff --git a/mediapipe/dart/mediapipe/lib/src/tasks/text/text_classifier.dart b/mediapipe/dart/mediapipe/lib/src/tasks/text/text_classifier.dart index 277919c12..a3821100a 100644 --- a/mediapipe/dart/mediapipe/lib/src/tasks/text/text_classifier.dart +++ b/mediapipe/dart/mediapipe/lib/src/tasks/text/text_classifier.dart @@ -23,6 +23,8 @@ class TextClassifier { /// Configuration options for this [TextClassifier]. final TextClassifierOptions options; + + /// Configuration object passed to the [TaskRunner]. final TaskInfo _taskInfo; TaskRunner get taskRunner => _taskRunner!; @@ -35,11 +37,12 @@ class TextClassifier { static const taskGraphName = 'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; + // TODO: Don't return protobuf objects. Instead, convert to plain Dart objects. /// Performs classification on the input `text`. Future classify(String text) async { // TODO: Actually decode this line to correctly fill up this map parameter // https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/tasks/python/text/text_classifier.py;l=181 - final outputPackets = taskRunner.process({}); + final outputPackets = taskRunner.process({textInStreamName: Object()}); // TODO: Obviously this is not real return tasks_pb.ClassificationResult.create(); diff --git a/mediapipe/dart/mediapipe/lib/src/third_party/generated/mediapipe_bindings.dart b/mediapipe/dart/mediapipe/lib/src/third_party/generated/mediapipe_bindings.dart new file mode 100644 index 000000000..72a6d55ca --- /dev/null +++ b/mediapipe/dart/mediapipe/lib/src/third_party/generated/mediapipe_bindings.dart @@ -0,0 +1,420 @@ +// AUTO GENERATED FILE, DO NOT EDIT. +// +// Generated by `package:ffigen`. +// ignore_for_file: type=lint +import 'dart:ffi' as ffi; + +/// MediaPipe bindings. +class flutter_mediapipe { + /// Holds the symbol lookup function. + final ffi.Pointer Function(String symbolName) + _lookup; + + /// The symbols are looked up in [dynamicLibrary]. + flutter_mediapipe(ffi.DynamicLibrary dynamicLibrary) + : _lookup = dynamicLibrary.lookup; + + /// The symbols are looked up with [lookup]. + flutter_mediapipe.fromLookup( + ffi.Pointer Function(String symbolName) + lookup) + : _lookup = lookup; + + void text_classifier_create( + TextClassifierOptions options, + ) { + return _text_classifier_create( + options, + ); + } + + late final _text_classifier_createPtr = + _lookup>( + 'text_classifier_create'); + late final _text_classifier_create = _text_classifier_createPtr + .asFunction(); + + ffi.Pointer text_classifier_classify( + ffi.Pointer classifier, + ffi.Pointer utf8_text, + ) { + return _text_classifier_classify( + classifier, + utf8_text, + ); + } + + late final _text_classifier_classifyPtr = _lookup< + ffi.NativeFunction< + ffi.Pointer Function(ffi.Pointer, + ffi.Pointer)>>('text_classifier_classify'); + late final _text_classifier_classify = + _text_classifier_classifyPtr.asFunction< + ffi.Pointer Function( + ffi.Pointer, ffi.Pointer)>(); + + ffi.Pointer text_classifier_classify_simple( + ffi.Pointer utf8_text, + ) { + return _text_classifier_classify_simple( + utf8_text, + ); + } + + late final _text_classifier_classify_simplePtr = _lookup< + ffi.NativeFunction< + ffi.Pointer Function( + ffi.Pointer)>>('text_classifier_classify_simple'); + late final _text_classifier_classify_simple = + _text_classifier_classify_simplePtr.asFunction< + ffi.Pointer Function(ffi.Pointer)>(); + + void text_classifier_close( + ffi.Pointer classifier, + ) { + return _text_classifier_close( + classifier, + ); + } + + late final _text_classifier_closePtr = + _lookup)>>( + 'text_classifier_close'); + late final _text_classifier_close = _text_classifier_closePtr + .asFunction)>(); +} + +final class __mbstate_t extends ffi.Union { + @ffi.Array.multi([128]) + external ffi.Array __mbstate8; + + @ffi.LongLong() + external int _mbstateL; +} + +final class __darwin_pthread_handler_rec extends ffi.Struct { + external ffi + .Pointer)>> + __routine; + + external ffi.Pointer __arg; + + external ffi.Pointer<__darwin_pthread_handler_rec> __next; +} + +final class _opaque_pthread_attr_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([56]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_cond_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([40]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_condattr_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([8]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_mutex_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([56]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_mutexattr_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([8]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_once_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([8]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_rwlock_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([192]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_rwlockattr_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + @ffi.Array.multi([16]) + external ffi.Array __opaque; +} + +final class _opaque_pthread_t extends ffi.Struct { + @ffi.Long() + external int __sig; + + external ffi.Pointer<__darwin_pthread_handler_rec> __cleanup_stack; + + @ffi.Array.multi([8176]) + external ffi.Array __opaque; +} + +final class Category extends ffi.Struct { + @ffi.Int() + external int index; + + @ffi.Float() + external double score; + + external ffi.Pointer category_name; + + external ffi.Pointer display_name; +} + +final class Classifications extends ffi.Struct { + external ffi.Pointer categories; + + @ffi.Uint32() + external int categories_count; + + @ffi.Int() + external int head_index; + + external ffi.Pointer head_name; +} + +final class ClassificationResult extends ffi.Struct { + external ffi.Pointer classifications; + + @ffi.Uint32() + external int classifications_count; + + @ffi.Int64() + external int timestamp_ms; + + @ffi.Bool() + external bool has_timestamp_ms; +} + +final class BaseOptions extends ffi.Struct { + external ffi.Pointer model_asset_buffer; + + external ffi.Pointer model_asset_path; +} + +final class ClassifierOptions extends ffi.Struct { + external ffi.Pointer display_names_locale; + + @ffi.Int() + external int max_results; + + @ffi.Float() + external double score_threshold; + + external ffi.Pointer> category_allowlist; + + @ffi.Uint32() + external int category_allowlist_count; + + external ffi.Pointer> category_denylist; + + @ffi.Uint32() + external int category_denylist_count; +} + +final class TextClassifierOptions extends ffi.Struct { + external BaseOptions base_options; + + external ClassifierOptions classifier_options; +} + +typedef TextClassifierResult = ClassificationResult; + +const int true1 = 1; + +const int false1 = 0; + +const int __bool_true_false_are_defined = 1; + +const int __WORDSIZE = 64; + +const int __DARWIN_ONLY_64_BIT_INO_T = 1; + +const int __DARWIN_ONLY_UNIX_CONFORMANCE = 1; + +const int __DARWIN_ONLY_VERS_1050 = 1; + +const int __DARWIN_UNIX03 = 1; + +const int __DARWIN_64_BIT_INO_T = 1; + +const int __DARWIN_VERS_1050 = 1; + +const int __DARWIN_NON_CANCELABLE = 0; + +const String __DARWIN_SUF_EXTSN = '\$DARWIN_EXTSN'; + +const int __DARWIN_C_ANSI = 4096; + +const int __DARWIN_C_FULL = 900000; + +const int __DARWIN_C_LEVEL = 900000; + +const int __STDC_WANT_LIB_EXT1__ = 1; + +const int __DARWIN_NO_LONG_LONG = 0; + +const int _DARWIN_FEATURE_64_BIT_INODE = 1; + +const int _DARWIN_FEATURE_ONLY_64_BIT_INODE = 1; + +const int _DARWIN_FEATURE_ONLY_VERS_1050 = 1; + +const int _DARWIN_FEATURE_ONLY_UNIX_CONFORMANCE = 1; + +const int _DARWIN_FEATURE_UNIX_CONFORMANCE = 3; + +const int __has_ptrcheck = 0; + +const int __DARWIN_NULL = 0; + +const int __PTHREAD_SIZE__ = 8176; + +const int __PTHREAD_ATTR_SIZE__ = 56; + +const int __PTHREAD_MUTEXATTR_SIZE__ = 8; + +const int __PTHREAD_MUTEX_SIZE__ = 56; + +const int __PTHREAD_CONDATTR_SIZE__ = 8; + +const int __PTHREAD_COND_SIZE__ = 40; + +const int __PTHREAD_ONCE_SIZE__ = 8; + +const int __PTHREAD_RWLOCK_SIZE__ = 192; + +const int __PTHREAD_RWLOCKATTR_SIZE__ = 16; + +const int USER_ADDR_NULL = 0; + +const int INT8_MAX = 127; + +const int INT16_MAX = 32767; + +const int INT32_MAX = 2147483647; + +const int INT64_MAX = 9223372036854775807; + +const int INT8_MIN = -128; + +const int INT16_MIN = -32768; + +const int INT32_MIN = -2147483648; + +const int INT64_MIN = -9223372036854775808; + +const int UINT8_MAX = 255; + +const int UINT16_MAX = 65535; + +const int UINT32_MAX = 4294967295; + +const int UINT64_MAX = -1; + +const int INT_LEAST8_MIN = -128; + +const int INT_LEAST16_MIN = -32768; + +const int INT_LEAST32_MIN = -2147483648; + +const int INT_LEAST64_MIN = -9223372036854775808; + +const int INT_LEAST8_MAX = 127; + +const int INT_LEAST16_MAX = 32767; + +const int INT_LEAST32_MAX = 2147483647; + +const int INT_LEAST64_MAX = 9223372036854775807; + +const int UINT_LEAST8_MAX = 255; + +const int UINT_LEAST16_MAX = 65535; + +const int UINT_LEAST32_MAX = 4294967295; + +const int UINT_LEAST64_MAX = -1; + +const int INT_FAST8_MIN = -128; + +const int INT_FAST16_MIN = -32768; + +const int INT_FAST32_MIN = -2147483648; + +const int INT_FAST64_MIN = -9223372036854775808; + +const int INT_FAST8_MAX = 127; + +const int INT_FAST16_MAX = 32767; + +const int INT_FAST32_MAX = 2147483647; + +const int INT_FAST64_MAX = 9223372036854775807; + +const int UINT_FAST8_MAX = 255; + +const int UINT_FAST16_MAX = 65535; + +const int UINT_FAST32_MAX = 4294967295; + +const int UINT_FAST64_MAX = -1; + +const int INTPTR_MAX = 9223372036854775807; + +const int INTPTR_MIN = -9223372036854775808; + +const int UINTPTR_MAX = -1; + +const int INTMAX_MAX = 9223372036854775807; + +const int UINTMAX_MAX = -1; + +const int INTMAX_MIN = -9223372036854775808; + +const int PTRDIFF_MIN = -9223372036854775808; + +const int PTRDIFF_MAX = 9223372036854775807; + +const int SIZE_MAX = -1; + +const int RSIZE_MAX = 9223372036854775807; + +const int WCHAR_MAX = 2147483647; + +const int WCHAR_MIN = -2147483648; + +const int WINT_MIN = -2147483648; + +const int WINT_MAX = 2147483647; + +const int SIG_ATOMIC_MIN = -2147483648; + +const int SIG_ATOMIC_MAX = 2147483647; diff --git a/mediapipe/dart/mediapipe/pubspec.yaml b/mediapipe/dart/mediapipe/pubspec.yaml index 8c96f711f..938d7a11f 100644 --- a/mediapipe/dart/mediapipe/pubspec.yaml +++ b/mediapipe/dart/mediapipe/pubspec.yaml @@ -7,10 +7,12 @@ environment: # Add regular dependencies here. dependencies: + ffi: ^2.0.2 fixnum: ^1.1.0 path: ^1.8.3 protobuf: ^3.0.0 dev_dependencies: + ffigen: ^9.0.1 lints: ^2.0.0 test: ^1.21.0 diff --git a/mediapipe/dart/mediapipe/third_party/mediapipe/base_options.h b/mediapipe/dart/mediapipe/third_party/mediapipe/base_options.h new file mode 100644 index 000000000..7adf04503 --- /dev/null +++ b/mediapipe/dart/mediapipe/third_party/mediapipe/base_options.h @@ -0,0 +1,25 @@ +/* Copyright 2023 The MediaPipe Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ +#define THIRD_PARTY_MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ + +// Base options for MediaPipe C Tasks. +struct BaseOptions +{ + // The model asset file contents as a string. + char *model_asset_buffer; + + // The path to the model asset to open and mmap in memory. + char *model_asset_path; +}; +#endif // THIRD_PARTY_MEDIAPIPE_TASKS_C_CORE_BASE_OPTIONS_H_ diff --git a/mediapipe/dart/mediapipe/third_party/mediapipe/category.h b/mediapipe/dart/mediapipe/third_party/mediapipe/category.h new file mode 100644 index 000000000..550b3c5c0 --- /dev/null +++ b/mediapipe/dart/mediapipe/third_party/mediapipe/category.h @@ -0,0 +1,35 @@ + +/* Copyright 2023 The MediaPipe Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ +#define THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ +// Defines a single classification result. +// +// The label maps packed into the TFLite Model Metadata [1] are used to populate +// the 'category_name' and 'display_name' fields. +// +// [1]: https://www.tensorflow.org/lite/convert/metadata +struct Category { + // The index of the category in the classification model output. + int index; + // The score for this category, e.g. (but not necessarily) a probability in + // [0,1]. + float score; + // The optional ID for the category, read from the label map packed in the + // TFLite Model Metadata if present. Not necessarily human-readable. + char *category_name; + // The optional human-readable name for the category, read from the label map + // packed in the TFLite Model Metadata if present. + char *display_name; +}; + +#endif // THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CATEGORY_H_ diff --git a/mediapipe/dart/mediapipe/third_party/mediapipe/classification_result.h b/mediapipe/dart/mediapipe/third_party/mediapipe/classification_result.h new file mode 100644 index 000000000..472125f7e --- /dev/null +++ b/mediapipe/dart/mediapipe/third_party/mediapipe/classification_result.h @@ -0,0 +1,62 @@ +/* Copyright 2023 The MediaPipe Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ +#define THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ +#include +#include +#include "category.h" + +// Defines classification results for a given classifier head. +struct Classifications +{ + // The array of predicted categories, usually sorted by descending scores, + // e.g. from high to low probability. + struct Category *categories; + + // The number of elements in the categories array. + uint32_t categories_count; + + // The index of the classifier head (i.e. output tensor) these categories + // refer to. This is useful for multi-head models. + int head_index; + + // The optional name of the classifier head, as provided in the TFLite Model + // Metadata [1] if present. This is useful for multi-head models. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + char *head_name; +}; + +// Defines classification results of a model. +struct ClassificationResult +{ + // The classification results for each head of the model. + struct Classifications *classifications; + + // The number of classifications in the classifications array. + uint32_t classifications_count; + + // The optional timestamp (in milliseconds) of the start of the chunk of data + // corresponding to these results. + // + // This is only used for classification on time series (e.g. audio + // classification). In these use cases, the amount of data to process might + // exceed the maximum size that the model can process: to solve this, the + // input data is split into multiple chunks starting at different timestamps. + int64_t timestamp_ms; + + // Specifies whether the timestamp contains a valid value. + bool has_timestamp_ms; +}; + +#endif // THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ diff --git a/mediapipe/dart/mediapipe/third_party/mediapipe/classifier_options.h b/mediapipe/dart/mediapipe/third_party/mediapipe/classifier_options.h new file mode 100644 index 000000000..ca1748246 --- /dev/null +++ b/mediapipe/dart/mediapipe/third_party/mediapipe/classifier_options.h @@ -0,0 +1,48 @@ +/* Copyright 2023 The MediaPipe Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ +#define THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ +#include + +// Classifier options for MediaPipe C classification Tasks. +struct ClassifierOptions { + // The locale to use for display names specified through the TFLite Model + // Metadata, if any. Defaults to English. + char *display_names_locale; + + // The maximum number of top-scored classification results to return. If < 0, + // all available results will be returned. If 0, an invalid argument error is + // returned. + int max_results; + + // Score threshold to override the one provided in the model metadata (if + // any). Results below this value are rejected. + float score_threshold; + + // The allowlist of category names. If non-empty, detection results whose + // category name is not in this set will be filtered out. Duplicate or unknown + // category names are ignored. Mutually exclusive with category_denylist. + char **category_allowlist; + + // The number of elements in the category allowlist. + uint32_t category_allowlist_count; + + // The denylist of category names. If non-empty, detection results whose + // category name is in this set will be filtered out. Duplicate or unknown + // category names are ignored. Mutually exclusive with category_allowlist. + char **category_denylist; + + // The number of elements in the category denylist. + uint32_t category_denylist_count; +}; +#endif // THIRD_PARTY_MEDIAPIPE_TASKS_C_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ diff --git a/mediapipe/dart/mediapipe/third_party/mediapipe/text_classifier.h b/mediapipe/dart/mediapipe/third_party/mediapipe/text_classifier.h new file mode 100644 index 000000000..a85f5b611 --- /dev/null +++ b/mediapipe/dart/mediapipe/third_party/mediapipe/text_classifier.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The MediaPipe Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ +#define THIRD_PARTY_MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ +#include "base_options.h" +#include "classification_result.h" +#include "classifier_options.h" + +typedef struct ClassificationResult TextClassifierResult; + +// The options for configuring a MediaPipe text classifier task. +struct TextClassifierOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + struct BaseOptions base_options; + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + struct ClassifierOptions classifier_options; +}; + +// void *text_classifier_options_create(); + +// Creates a TextClassifier from the provided `options`. +void *text_classifier_create(struct TextClassifierOptions *options); + +// Performs classification on the input `text`. +TextClassifierResult *text_classifier_classify(void *classifier, + char *utf8_text); + +// Shuts down the TextClassifier when all the work is done. Frees all memory. +void text_classifier_close(void *classifier); +void text_classifier_result_close(TextClassifierResult *result); + +#endif // THIRD_PARTY_MEDIAPIPE_TASKS_C_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_