add beginning of Dart-specific MediaPipe implementation

This commit is contained in:
Craig Labenz 2023-07-19 17:08:54 -07:00
parent d1b7e960ee
commit 8502ed9806
14 changed files with 408 additions and 0 deletions

7
mediapipe/dart/mediapipe/.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
# https://dart.dev/guides/libraries/private-files
# Created by `dart pub`
.dart_tool/
# Avoid committing pubspec.lock for library packages; see
# https://dart.dev/guides/libraries/private-files#pubspeclock.
pubspec.lock

View File

@ -0,0 +1,3 @@
## 1.0.0
- Initial version.

View File

@ -0,0 +1,39 @@
<!--
This README describes the package. If you publish this package to pub.dev,
this README's contents appear on the landing page for your package.
For information about how to write a good package README, see the guide for
[writing package pages](https://dart.dev/guides/libraries/writing-package-pages).
For general information about developing packages, see the Dart guide for
[creating packages](https://dart.dev/guides/libraries/create-library-packages)
and the Flutter guide for
[developing packages and plugins](https://flutter.dev/developing-packages).
-->
TODO: Put a short description of the package here that helps potential users
know whether this package might be useful for them.
## Features
TODO: List what your package can do. Maybe include images, gifs, or videos.
## Getting started
TODO: List prerequisites and provide or point to information on how to
start using the package.
## Usage
TODO: Include short and useful examples for package users. Add longer examples
to `/example` folder.
```dart
const like = 'sample';
```
## Additional information
TODO: Tell users more about the package: where to find more information, how to
contribute to the package, how to file issues, what response they can expect
from the package authors, and more.

View File

@ -0,0 +1,30 @@
# This file configures the static analysis results for your project (errors,
# warnings, and lints).
#
# This enables the 'recommended' set of lints from `package:lints`.
# This set helps identify many issues that may lead to problems when running
# or consuming Dart code, and enforces writing Dart using a single, idiomatic
# style and format.
#
# If you want a smaller set of lints you can change this to specify
# 'package:lints/core.yaml'. These are just the most critical lints
# (the recommended set includes the core lints).
# The core lints are also what is used by pub.dev for scoring packages.
include: package:lints/recommended.yaml
# Uncomment the following section to specify additional rules.
# linter:
# rules:
# - camel_case_types
# analyzer:
# exclude:
# - path/to/excluded/files/**
# For more information about the core and recommended set of lints, see
# https://dart.dev/go/core-lints
# For additional information about configuring this file, see
# https://dart.dev/guides/language/analysis-options

View File

@ -0,0 +1,5 @@
To generate protocol buffers for Dart (in this folder), navigate to the `dart_builder` directory and run:
```
$ dart bin/dart_builder.dart
```

View File

@ -0,0 +1,3 @@
library mediapipe;
export 'src/tasks/tasks.dart';

View File

@ -0,0 +1,83 @@
import 'dart:io' as io;
import 'dart:typed_data';
import 'package:path/path.dart' as path;
import 'package:protobuf/protobuf.dart' as $pb;
import '../../../generated/mediapipe/calculators/calculators.dart';
import '../../../generated/mediapipe/tasks/tasks.dart' as tasks_pb;
/// Class to extend in task-specific *Options classes. Funnels the three
/// [BaseOptions] attributes into their own object.
abstract class TaskOptions {
TaskOptions({this.modelAssetBuffer, this.modelAssetPath, this.delegate})
: baseOptions = BaseOptions(
delegate: delegate,
modelAssetBuffer: modelAssetBuffer,
modelAssetPath: modelAssetPath,
);
final Uint8List? modelAssetBuffer;
final String? modelAssetPath;
final Delegate? delegate;
final BaseOptions baseOptions;
$pb.GeneratedMessage toProto();
/// In proto2 syntax, extensions are unique IDs, suitable for keys in a hash
/// map, which power the Extensions pattern for protos to house arbitrary
/// extended data.
///
/// In proto3, this pattern is replaced with the [Any] protobuf, as the
/// convention for setting the unique identifiers surpassed the maximum upper
/// bound of 29 bits as allocated in the protobuf spec.
$pb.Extension get ext;
}
final class BaseOptions {
BaseOptions({this.modelAssetBuffer, this.modelAssetPath, this.delegate});
/// The model asset file contents as bytes;
Uint8List? modelAssetBuffer;
/// Path to the model asset file.
String? modelAssetPath;
/// Acceleration strategy to use. GPU support is currently limited to
/// Ubuntu platform.
Delegate? delegate;
/// See also: https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/tasks/python/core/base_options.py;l=63-89;rcl=548857458
tasks_pb.BaseOptions toProto() {
String? absModelPath =
modelAssetPath != null ? path.absolute(modelAssetPath!) : null;
if (!io.Platform.isLinux && delegate == Delegate.gpu) {
throw Exception(
'GPU Delegate is not yet supported for ${io.Platform.operatingSystem}',
);
}
tasks_pb.Acceleration? acceleration;
if (delegate == Delegate.cpu) {
acceleration = tasks_pb.Acceleration.create()
..tflite = InferenceCalculatorOptions_Delegate_TfLite.create();
}
final modelAsset = tasks_pb.ExternalFile.create();
if (absModelPath != null) {
modelAsset.fileName = absModelPath;
}
if (modelAssetBuffer != null) {
modelAsset.fileContent = modelAssetBuffer!;
}
final options = tasks_pb.BaseOptions.create()..modelAsset = modelAsset;
if (acceleration != null) {
options.acceleration = acceleration;
}
return options;
}
}
/// Hardware location to perform the given task.
enum Delegate { cpu, gpu }

View File

@ -0,0 +1,3 @@
export 'base_options.dart';
export 'task_info.dart';
export 'task_runner.dart';

View File

@ -0,0 +1,91 @@
import 'package:mediapipe/generated/mediapipe/calculators/calculators.dart';
import '../../../generated/mediapipe/framework/framework.dart';
import '../tasks.dart';
class TaskInfo<T extends TaskOptions> {
TaskInfo({
required this.taskGraph,
required this.inputStreams,
required this.outputStreams,
required this.taskOptions,
});
String taskGraph;
List<String> inputStreams;
List<String> outputStreams;
T taskOptions;
CalculatorGraphConfig generateGraphConfig({
bool enableFlowLimiting = false,
}) {
assert(inputStreams.isNotEmpty, 'TaskInfo.inputStreams must be non-empty');
assert(
outputStreams.isNotEmpty,
'TaskInfo.outputStreams must be non-empty',
);
FlowLimiterCalculatorOptions.ext;
final taskSubgraphOptions = CalculatorOptions();
taskSubgraphOptions.addExtension(taskOptions.ext, taskOptions.toProto());
if (!enableFlowLimiting) {
return CalculatorGraphConfig.create()
..node.add(
CalculatorGraphConfig_Node.create()
..calculator = taskGraph
..inputStream.addAll(inputStreams)
..outputStream.addAll(outputStreams)
..options = taskSubgraphOptions,
);
}
// When a FlowLimiterCalculator is inserted to lower the overall graph
// latency, the task doesn't guarantee that each input must have the
// corresponding output.
final taskSubgraphInputs =
inputStreams.map<String>(_addStreamNamePrefix).toList();
String finishedStream = 'FINISHED: ${_stripTagIndex(outputStreams.first)}';
final flowLimiterOptions = CalculatorOptions.create();
flowLimiterOptions.setExtension(
FlowLimiterCalculatorOptions.ext,
FlowLimiterCalculatorOptions.create()
..maxInFlight = 1
..maxInQueue = 1,
);
final flowLimiter = CalculatorGraphConfig_Node.create()
..calculator = 'FlowLimiterCalculator'
..inputStreamInfo.add(
InputStreamInfo.create()
..tagIndex = 'FINISHED'
..backEdge = true,
)
..inputStream.addAll(inputStreams.map<String>(_stripTagIndex).toList())
..inputStream.add(finishedStream)
..outputStream.addAll(
taskSubgraphInputs.map<String>(_stripTagIndex).toList(),
)
..options = flowLimiterOptions;
final config = CalculatorGraphConfig.create()
..node.add(
CalculatorGraphConfig_Node.create()
..calculator = taskGraph
..inputStream.addAll(taskSubgraphInputs)
..outputStream.addAll(outputStreams)
..options = taskSubgraphOptions,
)
..node.add(flowLimiter)
..inputStream.addAll(inputStreams)
..outputStream.addAll(outputStreams);
return config;
}
}
String _stripTagIndex(String tagIndexName) => tagIndexName.split(':').last;
String _addStreamNamePrefix(String tagIndexName) {
final split = tagIndexName.split(':');
split.last = 'trottled_${split.last}';
return split.join(':');
}

View File

@ -0,0 +1,17 @@
import '../../../generated/mediapipe/framework/calculator.pb.dart';
// TODO: Wrap C++ TaskRunner with this, similarly to this Python wrapper:
// https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/python/framework_bindings.cc?q=python%20framework_bindings.cc
class TaskRunner {
TaskRunner(this.graphConfig);
final CalculatorGraphConfig graphConfig;
// TODO: Actually decode this line for correct parameter type:
// https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/tasks/python/text/text_classifier.py;l=181
Map<String, Packet> process(Map<String, Object> data) => {};
}
// TODO: Wrap C++ Packet with this, similarly to this Python wrapper:
// https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/python/pybind/packet.h
class Packet {}

View File

@ -0,0 +1,2 @@
export 'core/core.dart';
export 'text/text.dart';

View File

@ -0,0 +1 @@
export 'text_classifier.dart';

View File

@ -0,0 +1,108 @@
import 'package:protobuf/protobuf.dart' as $pb;
import '../../../generated/mediapipe/tasks/tasks.dart' as tasks_pb;
import '../tasks.dart';
class TextClassifier {
/// Primary constructor for [TextClassifier].
TextClassifier(this.options)
: _taskInfo = TaskInfo(
taskGraph: taskGraphName,
inputStreams: <String>['$textTag:$textInStreamName'],
outputStreams: <String>[
'$classificationsTag:$classificationsStreamName'
],
taskOptions: options,
) {
_taskRunner = TaskRunner(_taskInfo.generateGraphConfig());
}
/// Shortcut constructor which only accepts a local path to the model.
factory TextClassifier.fromAssetPath(String assetPath) => TextClassifier(
TextClassifierOptions(modelAssetPath: assetPath),
);
/// Configuration options for this [TextClassifier].
final TextClassifierOptions options;
final TaskInfo _taskInfo;
TaskRunner get taskRunner => _taskRunner!;
TaskRunner? _taskRunner;
static const classificationsStreamName = 'classifications_out';
static const classificationsTag = 'CLASSIFICATIONS';
static const textTag = 'TEXT';
static const textInStreamName = 'text_in';
static const taskGraphName =
'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
/// Performs classification on the input `text`.
Future<tasks_pb.ClassificationResult> classify(String text) async {
// TODO: Actually decode this line to correctly fill up this map parameter
// https://source.corp.google.com/piper///depot/google3/third_party/mediapipe/tasks/python/text/text_classifier.py;l=181
final outputPackets = taskRunner.process({});
// TODO: Obviously this is not real
return tasks_pb.ClassificationResult.create();
}
}
class TextClassifierOptions extends TaskOptions {
TextClassifierOptions({
this.displayNamesLocale,
this.maxResults,
this.scoreThreshold,
this.categoryAllowlist,
this.categoryDenylist,
super.modelAssetBuffer,
super.modelAssetPath,
super.delegate,
});
/// The locale to use for display names specified through the TFLite Model
/// Metadata.
String? displayNamesLocale;
/// The maximum number of top-scored classification results to return.
int? maxResults;
/// Overrides the ones provided in the model metadata. Results below this
/// value are rejected.
double? scoreThreshold;
/// Allowlist of category names. If non-empty, classification results whose
/// category name is not in this set will be discarded. Duplicate or unknown
/// category names are ignored. Mutually exclusive with `categoryDenylist`.
List<String>? categoryAllowlist;
/// Denylist of category names. If non-empty, classification results whose
/// category name is in this set will be discarded. Duplicate or unknown
/// category names are ignored. Mutually exclusive with `categoryAllowList`.
List<String>? categoryDenylist;
@override
$pb.Extension get ext => tasks_pb.TextClassifierGraphOptions.ext;
@override
tasks_pb.TextClassifierGraphOptions toProto() {
final classifierOptions = tasks_pb.ClassifierOptions.create();
if (displayNamesLocale != null) {
classifierOptions.displayNamesLocale = displayNamesLocale!;
}
if (maxResults != null) {
classifierOptions.maxResults = maxResults!;
}
if (scoreThreshold != null) {
classifierOptions.scoreThreshold = scoreThreshold!;
}
if (categoryAllowlist != null) {
classifierOptions.categoryAllowlist.addAll(categoryAllowlist!);
}
if (categoryDenylist != null) {
classifierOptions.categoryDenylist.addAll(categoryDenylist!);
}
return tasks_pb.TextClassifierGraphOptions.create()
..baseOptions = baseOptions.toProto()
..classifierOptions = classifierOptions;
}
}

View File

@ -0,0 +1,16 @@
name: mediapipe
description: Flutter plugin for Google's MediaPipe.
version: 1.0.0
# repository: https://github.com/my_org/my_repo
environment:
sdk: ^3.1.0-149.0.dev
# Add regular dependencies here.
dependencies:
fixnum: ^1.1.0
path: ^1.8.3
protobuf: ^3.0.0
dev_dependencies:
lints: ^2.0.0
test: ^1.21.0