Internal change

PiperOrigin-RevId: 479544054
This commit is contained in:
MediaPipe Team 2022-10-07 04:05:09 -07:00 committed by Copybara-Service
parent db524adf0d
commit d90daa859f
9 changed files with 280 additions and 0 deletions

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_outer_classname = "CategoryProto";
// A single classification result.
message Category {
// The index of the category in the corresponding label map, usually packed in

View File

@ -19,6 +19,9 @@ package mediapipe.tasks.components.containers.proto;
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_outer_classname = "ClassificationsProto";
// List of predicted categories with an optional timestamp.
message ClassificationEntry {
// The array of predicted categories, usually sorted by descending scores,

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "ClassifierOptionsProto";
// Shared options used by all classification tasks.
message ClassifierOptions {
// The locale to use for display names specified through the TFLite Model

View File

@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.imageclassifier.proto";
option java_outer_classname = "ImageClassifierGraphOptionsProto";
message ImageClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional ImageClassifierGraphOptions ext = 456383383;

View File

@ -34,3 +34,23 @@ android_library(
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "classification_entry",
srcs = ["ClassificationEntry.java"],
deps = [
":category",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "classifications",
srcs = ["Classifications.java"],
deps = [
":classification_entry",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,48 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents a list of predicted categories with an optional timestamp. Typically used as result
* for classification tasks.
*/
@AutoValue
public abstract class ClassificationEntry {
/**
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
* timestamp.
*
* @param categories the list of {@link Category} objects that contain category name, display
* name, score and label index.
* @param timestampMs the {@link long} representing the timestamp for which these categories were
* obtained.
*/
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
}
/** The list of predicted {@link Category} objects, sorted by descending score. */
public abstract List<Category> categories();
/**
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
* series use cases, e.g. audio classification.
*/
public abstract long timestampMs();
}

View File

@ -0,0 +1,52 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents the list of classification for a given classifier head. Typically used as a result for
* classification tasks.
*/
@AutoValue
public abstract class Classifications {
/**
* Creates a {@link Classifications} instance.
*
* @param entries the list of {@link ClassificationEntry} objects containing the predicted
* categories.
* @param headIndex the index of the classifier head.
* @param headName the name of the classifier head.
*/
public static Classifications create(
List<ClassificationEntry> entries, int headIndex, String headName) {
return new AutoValue_Classifications(
Collections.unmodifiableList(entries), headIndex, headName);
}
/** A list of {@link ClassificationEntry} objects. */
public abstract List<ClassificationEntry> entries();
/**
* The index of the classifier head these entries refer to. This is useful for multi-head models.
*/
public abstract int headIndex();
/** The name of the classifier head, which is the corresponding tensor metadata name. */
public abstract String headName();
}

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "classifieroptions",
srcs = ["ClassifierOptions.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,118 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.processors;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/** Classifier options shared across MediaPipe Java classification tasks. */
@AutoValue
public abstract class ClassifierOptions {
/** Builder for {@link ClassifierOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets the optional locale to use for display names specified through the TFLite Model
* Metadata, if any.
*/
public abstract Builder setDisplayNamesLocale(String locale);
/**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
abstract ClassifierOptions autoBuild();
/**
* Validates and builds the {@link ClassifierOptions} instance.
*
* @throws IllegalArgumentException if {@link maxResults} is set to a value <= 0.
*/
public final ClassifierOptions build() {
ClassifierOptions options = autoBuild();
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0");
}
return options;
}
}
public abstract Optional<String> displayNamesLocale();
public abstract Optional<Integer> maxResults();
public abstract Optional<Float> scoreThreshold();
public abstract List<String> categoryAllowlist();
public abstract List<String> categoryDenylist();
public static Builder builder() {
return new AutoValue_ClassifierOptions.Builder()
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
}
/**
* Converts a {@link ClassifierOptions} object to a {@link
* ClassifierOptionsProto.ClassifierOptions} protobuf message.
*/
public ClassifierOptionsProto.ClassifierOptions convertToProto() {
ClassifierOptionsProto.ClassifierOptions.Builder builder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(builder::setDisplayNamesLocale);
maxResults().ifPresent(builder::setMaxResults);
scoreThreshold().ifPresent(builder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
builder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
builder.addAllCategoryDenylist(categoryDenylist());
}
return builder.build();
}
}