Internal change
PiperOrigin-RevId: 479544054
This commit is contained in:
parent
db524adf0d
commit
d90daa859f
|
@ -17,6 +17,9 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
||||
option java_outer_classname = "CategoryProto";
|
||||
|
||||
// A single classification result.
|
||||
message Category {
|
||||
// The index of the category in the corresponding label map, usually packed in
|
||||
|
|
|
@ -19,6 +19,9 @@ package mediapipe.tasks.components.containers.proto;
|
|||
|
||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.container.proto";
|
||||
option java_outer_classname = "ClassificationsProto";
|
||||
|
||||
// List of predicted categories with an optional timestamp.
|
||||
message ClassificationEntry {
|
||||
// The array of predicted categories, usually sorted by descending scores,
|
||||
|
|
|
@ -17,6 +17,9 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe.tasks.components.processors.proto;
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
|
||||
option java_outer_classname = "ClassifierOptionsProto";
|
||||
|
||||
// Shared options used by all classification tasks.
|
||||
message ClassifierOptions {
|
||||
// The locale to use for display names specified through the TFLite Model
|
||||
|
|
|
@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.tasks.vision.imageclassifier.proto";
|
||||
option java_outer_classname = "ImageClassifierGraphOptionsProto";
|
||||
|
||||
message ImageClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ImageClassifierGraphOptions ext = 456383383;
|
||||
|
|
|
@ -34,3 +34,23 @@ android_library(
|
|||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "classification_entry",
|
||||
srcs = ["ClassificationEntry.java"],
|
||||
deps = [
|
||||
":category",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "classifications",
|
||||
srcs = ["Classifications.java"],
|
||||
deps = [
|
||||
":classification_entry",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Represents a list of predicted categories with an optional timestamp. Typically used as result
|
||||
* for classification tasks.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class ClassificationEntry {
|
||||
/**
|
||||
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
|
||||
* timestamp.
|
||||
*
|
||||
* @param categories the list of {@link Category} objects that contain category name, display
|
||||
* name, score and label index.
|
||||
* @param timestampMs the {@link long} representing the timestamp for which these categories were
|
||||
* obtained.
|
||||
*/
|
||||
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
|
||||
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
|
||||
}
|
||||
|
||||
/** The list of predicted {@link Category} objects, sorted by descending score. */
|
||||
public abstract List<Category> categories();
|
||||
|
||||
/**
|
||||
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
|
||||
* series use cases, e.g. audio classification.
|
||||
*/
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.containers;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Represents the list of classification for a given classifier head. Typically used as a result for
|
||||
* classification tasks.
|
||||
*/
|
||||
@AutoValue
|
||||
public abstract class Classifications {
|
||||
|
||||
/**
|
||||
* Creates a {@link Classifications} instance.
|
||||
*
|
||||
* @param entries the list of {@link ClassificationEntry} objects containing the predicted
|
||||
* categories.
|
||||
* @param headIndex the index of the classifier head.
|
||||
* @param headName the name of the classifier head.
|
||||
*/
|
||||
public static Classifications create(
|
||||
List<ClassificationEntry> entries, int headIndex, String headName) {
|
||||
return new AutoValue_Classifications(
|
||||
Collections.unmodifiableList(entries), headIndex, headName);
|
||||
}
|
||||
|
||||
/** A list of {@link ClassificationEntry} objects. */
|
||||
public abstract List<ClassificationEntry> entries();
|
||||
|
||||
/**
|
||||
* The index of the classifier head these entries refer to. This is useful for multi-head models.
|
||||
*/
|
||||
public abstract int headIndex();
|
||||
|
||||
/** The name of the classifier head, which is the corresponding tensor metadata name. */
|
||||
public abstract String headName();
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
android_library(
|
||||
name = "classifieroptions",
|
||||
srcs = ["ClassifierOptions.java"],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
|
@ -0,0 +1,118 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.components.processors;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/** Classifier options shared across MediaPipe Java classification tasks. */
|
||||
@AutoValue
|
||||
public abstract class ClassifierOptions {
|
||||
|
||||
/** Builder for {@link ClassifierOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/**
|
||||
* Sets the optional locale to use for display names specified through the TFLite Model
|
||||
* Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setDisplayNamesLocale(String locale);
|
||||
|
||||
/**
|
||||
* Sets the optional maximum number of top-scored classification results to return.
|
||||
*
|
||||
* <p>If not set, all available results are returned. If set, must be > 0.
|
||||
*/
|
||||
public abstract Builder setMaxResults(Integer maxResults);
|
||||
|
||||
/**
|
||||
* Sets the optional score threshold. Results with score below this value are rejected.
|
||||
*
|
||||
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
|
||||
*/
|
||||
public abstract Builder setScoreThreshold(Float scoreThreshold);
|
||||
|
||||
/**
|
||||
* Sets the optional allowlist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is not in this set will be filtered
|
||||
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryDenylist}.
|
||||
*/
|
||||
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
|
||||
|
||||
/**
|
||||
* Sets the optional denylist of category names.
|
||||
*
|
||||
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
|
||||
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
|
||||
* categoryAllowlist}.
|
||||
*/
|
||||
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
|
||||
|
||||
abstract ClassifierOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link ClassifierOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if {@link maxResults} is set to a value <= 0.
|
||||
*/
|
||||
public final ClassifierOptions build() {
|
||||
ClassifierOptions options = autoBuild();
|
||||
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
|
||||
throw new IllegalArgumentException("If specified, maxResults must be > 0");
|
||||
}
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
public abstract Optional<String> displayNamesLocale();
|
||||
|
||||
public abstract Optional<Integer> maxResults();
|
||||
|
||||
public abstract Optional<Float> scoreThreshold();
|
||||
|
||||
public abstract List<String> categoryAllowlist();
|
||||
|
||||
public abstract List<String> categoryDenylist();
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ClassifierOptions.Builder()
|
||||
.setCategoryAllowlist(Collections.emptyList())
|
||||
.setCategoryDenylist(Collections.emptyList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a {@link ClassifierOptions} object to a {@link
|
||||
* ClassifierOptionsProto.ClassifierOptions} protobuf message.
|
||||
*/
|
||||
public ClassifierOptionsProto.ClassifierOptions convertToProto() {
|
||||
ClassifierOptionsProto.ClassifierOptions.Builder builder =
|
||||
ClassifierOptionsProto.ClassifierOptions.newBuilder();
|
||||
displayNamesLocale().ifPresent(builder::setDisplayNamesLocale);
|
||||
maxResults().ifPresent(builder::setMaxResults);
|
||||
scoreThreshold().ifPresent(builder::setScoreThreshold);
|
||||
if (!categoryAllowlist().isEmpty()) {
|
||||
builder.addAllCategoryAllowlist(categoryAllowlist());
|
||||
}
|
||||
if (!categoryDenylist().isEmpty()) {
|
||||
builder.addAllCategoryDenylist(categoryDenylist());
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user