Internal change
PiperOrigin-RevId: 479544054
This commit is contained in:
parent
db524adf0d
commit
d90daa859f
|
@ -17,6 +17,9 @@ syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.containers.proto;
|
package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
||||||
|
option java_outer_classname = "CategoryProto";
|
||||||
|
|
||||||
// A single classification result.
|
// A single classification result.
|
||||||
message Category {
|
message Category {
|
||||||
// The index of the category in the corresponding label map, usually packed in
|
// The index of the category in the corresponding label map, usually packed in
|
||||||
|
|
|
@ -19,6 +19,9 @@ package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
||||||
|
option java_outer_classname = "ClassificationsProto";
|
||||||
|
|
||||||
// List of predicted categories with an optional timestamp.
|
// List of predicted categories with an optional timestamp.
|
||||||
message ClassificationEntry {
|
message ClassificationEntry {
|
||||||
// The array of predicted categories, usually sorted by descending scores,
|
// The array of predicted categories, usually sorted by descending scores,
|
||||||
|
|
|
@ -17,6 +17,9 @@ syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.processors.proto;
|
package mediapipe.tasks.components.processors.proto;
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
|
||||||
|
option java_outer_classname = "ClassifierOptionsProto";
|
||||||
|
|
||||||
// Shared options used by all classification tasks.
|
// Shared options used by all classification tasks.
|
||||||
message ClassifierOptions {
|
message ClassifierOptions {
|
||||||
// The locale to use for display names specified through the TFLite Model
|
// The locale to use for display names specified through the TFLite Model
|
||||||
|
|
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
|
option java_package = "com.google.mediapipe.tasks.vision.imageclassifier.proto";
|
||||||
|
option java_outer_classname = "ImageClassifierGraphOptionsProto";
|
||||||
|
|
||||||
message ImageClassifierGraphOptions {
|
message ImageClassifierGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ImageClassifierGraphOptions ext = 456383383;
|
optional ImageClassifierGraphOptions ext = 456383383;
|
||||||
|
|
|
@ -34,3 +34,23 @@ android_library(
|
||||||
"@maven//:com_google_guava_guava",
|
"@maven//:com_google_guava_guava",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "classification_entry",
|
||||||
|
srcs = ["ClassificationEntry.java"],
|
||||||
|
deps = [
|
||||||
|
":category",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "classifications",
|
||||||
|
srcs = ["Classifications.java"],
|
||||||
|
deps = [
|
||||||
|
":classification_entry",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a list of predicted categories with an optional timestamp. Typically used as result
|
||||||
|
* for classification tasks.
|
||||||
|
*/
|
||||||
|
@AutoValue
|
||||||
|
public abstract class ClassificationEntry {
|
||||||
|
/**
|
||||||
|
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
|
||||||
|
* timestamp.
|
||||||
|
*
|
||||||
|
* @param categories the list of {@link Category} objects that contain category name, display
|
||||||
|
* name, score and label index.
|
||||||
|
* @param timestampMs the {@link long} representing the timestamp for which these categories were
|
||||||
|
* obtained.
|
||||||
|
*/
|
||||||
|
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
|
||||||
|
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** The list of predicted {@link Category} objects, sorted by descending score. */
|
||||||
|
public abstract List<Category> categories();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
|
||||||
|
* series use cases, e.g. audio classification.
|
||||||
|
*/
|
||||||
|
public abstract long timestampMs();
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.containers;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents the list of classification for a given classifier head. Typically used as a result for
|
||||||
|
* classification tasks.
|
||||||
|
*/
|
||||||
|
@AutoValue
|
||||||
|
public abstract class Classifications {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a {@link Classifications} instance.
|
||||||
|
*
|
||||||
|
* @param entries the list of {@link ClassificationEntry} objects containing the predicted
|
||||||
|
* categories.
|
||||||
|
* @param headIndex the index of the classifier head.
|
||||||
|
* @param headName the name of the classifier head.
|
||||||
|
*/
|
||||||
|
public static Classifications create(
|
||||||
|
List<ClassificationEntry> entries, int headIndex, String headName) {
|
||||||
|
return new AutoValue_Classifications(
|
||||||
|
Collections.unmodifiableList(entries), headIndex, headName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** A list of {@link ClassificationEntry} objects. */
|
||||||
|
public abstract List<ClassificationEntry> entries();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The index of the classifier head these entries refer to. This is useful for multi-head models.
|
||||||
|
*/
|
||||||
|
public abstract int headIndex();
|
||||||
|
|
||||||
|
/** The name of the classifier head, which is the corresponding tensor metadata name. */
|
||||||
|
public abstract String headName();
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
android_library(
|
||||||
|
name = "classifieroptions",
|
||||||
|
srcs = ["ClassifierOptions.java"],
|
||||||
|
javacopts = [
|
||||||
|
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||||
|
"//third_party:autovalue",
|
||||||
|
"@maven//:com_google_guava_guava",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,118 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package com.google.mediapipe.tasks.components.processors;
|
||||||
|
|
||||||
|
import com.google.auto.value.AutoValue;
|
||||||
|
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/** Classifier options shared across MediaPipe Java classification tasks. */
|
||||||
|
@AutoValue
|
||||||
|
public abstract class ClassifierOptions {
|
||||||
|
|
||||||
|
/** Builder for {@link ClassifierOptions}. */
|
||||||
|
@AutoValue.Builder
|
||||||
|
public abstract static class Builder {
|
||||||
|
/**
|
||||||
|
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||||
|
* Metadata, if any.
|
||||||
|
*/
|
||||||
|
public abstract Builder setDisplayNamesLocale(String locale);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional maximum number of top-scored classification results to return.
|
||||||
|
*
|
||||||
|
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||||
|
*/
|
||||||
|
public abstract Builder setMaxResults(Integer maxResults);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||||
|
*
|
||||||
|
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||||
|
*/
|
||||||
|
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional allowlist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||||
|
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryDenylist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the optional denylist of category names.
|
||||||
|
*
|
||||||
|
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||||
|
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||||
|
* categoryAllowlist}.
|
||||||
|
*/
|
||||||
|
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||||
|
|
||||||
|
abstract ClassifierOptions autoBuild();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates and builds the {@link ClassifierOptions} instance.
|
||||||
|
*
|
||||||
|
* @throws IllegalArgumentException if {@link maxResults} is set to a value <= 0.
|
||||||
|
*/
|
||||||
|
public final ClassifierOptions build() {
|
||||||
|
ClassifierOptions options = autoBuild();
|
||||||
|
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||||
|
throw new IllegalArgumentException("If specified, maxResults must be > 0");
|
||||||
|
}
|
||||||
|
return options;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract Optional<String> displayNamesLocale();
|
||||||
|
|
||||||
|
public abstract Optional<Integer> maxResults();
|
||||||
|
|
||||||
|
public abstract Optional<Float> scoreThreshold();
|
||||||
|
|
||||||
|
public abstract List<String> categoryAllowlist();
|
||||||
|
|
||||||
|
public abstract List<String> categoryDenylist();
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new AutoValue_ClassifierOptions.Builder()
|
||||||
|
.setCategoryAllowlist(Collections.emptyList())
|
||||||
|
.setCategoryDenylist(Collections.emptyList());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a {@link ClassifierOptions} object to a {@link
|
||||||
|
* ClassifierOptionsProto.ClassifierOptions} protobuf message.
|
||||||
|
*/
|
||||||
|
public ClassifierOptionsProto.ClassifierOptions convertToProto() {
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.Builder builder =
|
||||||
|
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||||
|
displayNamesLocale().ifPresent(builder::setDisplayNamesLocale);
|
||||||
|
maxResults().ifPresent(builder::setMaxResults);
|
||||||
|
scoreThreshold().ifPresent(builder::setScoreThreshold);
|
||||||
|
if (!categoryAllowlist().isEmpty()) {
|
||||||
|
builder.addAllCategoryAllowlist(categoryAllowlist());
|
||||||
|
}
|
||||||
|
if (!categoryDenylist().isEmpty()) {
|
||||||
|
builder.addAllCategoryDenylist(categoryDenylist());
|
||||||
|
}
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user