Merge branch 'google:master' into hand-landmarker-python
This commit is contained in:
		
						commit
						6dd6d8921f
					
				| 
						 | 
					@ -20,6 +20,7 @@ android_library(
 | 
				
			||||||
    name = "category",
 | 
					    name = "category",
 | 
				
			||||||
    srcs = ["Category.java"],
 | 
					    srcs = ["Category.java"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:classification_java_proto_lite",
 | 
				
			||||||
        "//third_party:autovalue",
 | 
					        "//third_party:autovalue",
 | 
				
			||||||
        "@maven//:com_google_guava_guava",
 | 
					        "@maven//:com_google_guava_guava",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
| 
						 | 
					@ -36,20 +37,29 @@ android_library(
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
android_library(
 | 
					android_library(
 | 
				
			||||||
    name = "classification_entry",
 | 
					    name = "classifications",
 | 
				
			||||||
    srcs = ["ClassificationEntry.java"],
 | 
					    srcs = ["Classifications.java"],
 | 
				
			||||||
 | 
					    javacopts = [
 | 
				
			||||||
 | 
					        "-Xep:AndroidJdkLibsChecker:OFF",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":category",
 | 
					        ":category",
 | 
				
			||||||
 | 
					        "//mediapipe/framework/formats:classification_java_proto_lite",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
				
			||||||
        "//third_party:autovalue",
 | 
					        "//third_party:autovalue",
 | 
				
			||||||
        "@maven//:com_google_guava_guava",
 | 
					        "@maven//:com_google_guava_guava",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
android_library(
 | 
					android_library(
 | 
				
			||||||
    name = "classifications",
 | 
					    name = "classificationresult",
 | 
				
			||||||
    srcs = ["Classifications.java"],
 | 
					    srcs = ["ClassificationResult.java"],
 | 
				
			||||||
 | 
					    javacopts = [
 | 
				
			||||||
 | 
					        "-Xep:AndroidJdkLibsChecker:OFF",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":classification_entry",
 | 
					        ":classifications",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
				
			||||||
        "//third_party:autovalue",
 | 
					        "//third_party:autovalue",
 | 
				
			||||||
        "@maven//:com_google_guava_guava",
 | 
					        "@maven//:com_google_guava_guava",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,6 +15,7 @@
 | 
				
			||||||
package com.google.mediapipe.tasks.components.containers;
 | 
					package com.google.mediapipe.tasks.components.containers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import com.google.auto.value.AutoValue;
 | 
					import com.google.auto.value.AutoValue;
 | 
				
			||||||
 | 
					import com.google.mediapipe.formats.proto.ClassificationProto;
 | 
				
			||||||
import java.util.Objects;
 | 
					import java.util.Objects;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
| 
						 | 
					@ -38,6 +39,16 @@ public abstract class Category {
 | 
				
			||||||
    return new AutoValue_Category(score, index, categoryName, displayName);
 | 
					    return new AutoValue_Category(score, index, categoryName, displayName);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Creates a {@link Category} object from a {@link ClassificationProto.Classification} protobuf
 | 
				
			||||||
 | 
					   * message.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param proto the {@link ClassificationProto.Classification} protobuf message to convert.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  public static Category createFromProto(ClassificationProto.Classification proto) {
 | 
				
			||||||
 | 
					    return create(proto.getScore(), proto.getIndex(), proto.getLabel(), proto.getDisplayName());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** The probability score of this label category. */
 | 
					  /** The probability score of this label category. */
 | 
				
			||||||
  public abstract float score();
 | 
					  public abstract float score();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,48 +0,0 @@
 | 
				
			||||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
					 | 
				
			||||||
// you may not use this file except in compliance with the License.
 | 
					 | 
				
			||||||
// You may obtain a copy of the License at
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Unless required by applicable law or agreed to in writing, software
 | 
					 | 
				
			||||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
					 | 
				
			||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
					 | 
				
			||||||
// See the License for the specific language governing permissions and
 | 
					 | 
				
			||||||
// limitations under the License.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
package com.google.mediapipe.tasks.components.containers;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import com.google.auto.value.AutoValue;
 | 
					 | 
				
			||||||
import java.util.Collections;
 | 
					 | 
				
			||||||
import java.util.List;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/**
 | 
					 | 
				
			||||||
 * Represents a list of predicted categories with an optional timestamp. Typically used as result
 | 
					 | 
				
			||||||
 * for classification tasks.
 | 
					 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
@AutoValue
 | 
					 | 
				
			||||||
public abstract class ClassificationEntry {
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
 | 
					 | 
				
			||||||
   * timestamp.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param categories the list of {@link Category} objects that contain category name, display
 | 
					 | 
				
			||||||
   *     name, score and label index.
 | 
					 | 
				
			||||||
   * @param timestampMs the {@link long} representing the timestamp for which these categories were
 | 
					 | 
				
			||||||
   *     obtained.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  public static ClassificationEntry create(List<Category> categories, long timestampMs) {
 | 
					 | 
				
			||||||
    return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /** The list of predicted {@link Category} objects, sorted by descending score. */
 | 
					 | 
				
			||||||
  public abstract List<Category> categories();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * The timestamp (in milliseconds) associated to the classification entry. This is useful for time
 | 
					 | 
				
			||||||
   * series use cases, e.g. audio classification.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  public abstract long timestampMs();
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,76 @@
 | 
				
			||||||
 | 
					// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package com.google.mediapipe.tasks.components.containers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import com.google.auto.value.AutoValue;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
 | 
					import java.util.ArrayList;
 | 
				
			||||||
 | 
					import java.util.Collections;
 | 
				
			||||||
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					import java.util.Optional;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Represents the classification results of a model. Typically used as a result for classification
 | 
				
			||||||
 | 
					 * tasks.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					@AutoValue
 | 
				
			||||||
 | 
					public abstract class ClassificationResult {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Creates a {@link ClassificationResult} instance.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param classifications the list of {@link Classifications} objects containing the predicted
 | 
				
			||||||
 | 
					   *     categories for each head of the model.
 | 
				
			||||||
 | 
					   * @param timestampMs the optional timestamp (in milliseconds) of the start of the chunk of data
 | 
				
			||||||
 | 
					   *     corresponding to these results.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  public static ClassificationResult create(
 | 
				
			||||||
 | 
					      List<Classifications> classifications, Optional<Long> timestampMs) {
 | 
				
			||||||
 | 
					    return new AutoValue_ClassificationResult(
 | 
				
			||||||
 | 
					        Collections.unmodifiableList(classifications), timestampMs);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Creates a {@link ClassificationResult} object from a {@link
 | 
				
			||||||
 | 
					   * ClassificationsProto.ClassificationResult} protobuf message.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  public static ClassificationResult createFromProto(
 | 
				
			||||||
 | 
					      ClassificationsProto.ClassificationResult proto) {
 | 
				
			||||||
 | 
					    List<Classifications> classifications = new ArrayList<>();
 | 
				
			||||||
 | 
					    for (ClassificationsProto.Classifications classificationsProto :
 | 
				
			||||||
 | 
					        proto.getClassificationsList()) {
 | 
				
			||||||
 | 
					      classifications.add(Classifications.createFromProto(classificationsProto));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    Optional<Long> timestampMs =
 | 
				
			||||||
 | 
					        proto.hasTimestampMs() ? Optional.of(proto.getTimestampMs()) : Optional.empty();
 | 
				
			||||||
 | 
					    return create(classifications, timestampMs);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** The classification results for each head of the model. */
 | 
				
			||||||
 | 
					  public abstract List<Classifications> classifications();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * The optional timestamp (in milliseconds) of the start of the chunk of data corresponding to
 | 
				
			||||||
 | 
					   * these results.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * <p>This is only used for classification on time series (e.g. audio classification). In these
 | 
				
			||||||
 | 
					   * use cases, the amount of data to process might exceed the maximum size that the model can
 | 
				
			||||||
 | 
					   * process: to solve this, the input data is split into multiple chunks starting at different
 | 
				
			||||||
 | 
					   * timestamps.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  public abstract Optional<Long> timestampMs();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -15,8 +15,12 @@
 | 
				
			||||||
package com.google.mediapipe.tasks.components.containers;
 | 
					package com.google.mediapipe.tasks.components.containers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import com.google.auto.value.AutoValue;
 | 
					import com.google.auto.value.AutoValue;
 | 
				
			||||||
 | 
					import com.google.mediapipe.formats.proto.ClassificationProto;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
 | 
					import java.util.ArrayList;
 | 
				
			||||||
import java.util.Collections;
 | 
					import java.util.Collections;
 | 
				
			||||||
import java.util.List;
 | 
					import java.util.List;
 | 
				
			||||||
 | 
					import java.util.Optional;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/**
 | 
					/**
 | 
				
			||||||
 * Represents the list of classification for a given classifier head. Typically used as a result for
 | 
					 * Represents the list of classification for a given classifier head. Typically used as a result for
 | 
				
			||||||
| 
						 | 
					@ -28,25 +32,41 @@ public abstract class Classifications {
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Creates a {@link Classifications} instance.
 | 
					   * Creates a {@link Classifications} instance.
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
   * @param entries the list of {@link ClassificationEntry} objects containing the predicted
 | 
					   * @param categories the list of {@link Category} objects containing the predicted categories.
 | 
				
			||||||
   *     categories.
 | 
					 | 
				
			||||||
   * @param headIndex the index of the classifier head.
 | 
					   * @param headIndex the index of the classifier head.
 | 
				
			||||||
   * @param headName the name of the classifier head.
 | 
					   * @param headName the optional name of the classifier head.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public static Classifications create(
 | 
					  public static Classifications create(
 | 
				
			||||||
      List<ClassificationEntry> entries, int headIndex, String headName) {
 | 
					      List<Category> categories, int headIndex, Optional<String> headName) {
 | 
				
			||||||
    return new AutoValue_Classifications(
 | 
					    return new AutoValue_Classifications(
 | 
				
			||||||
        Collections.unmodifiableList(entries), headIndex, headName);
 | 
					        Collections.unmodifiableList(categories), headIndex, headName);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** A list of {@link ClassificationEntry} objects. */
 | 
					  /**
 | 
				
			||||||
  public abstract List<ClassificationEntry> entries();
 | 
					   * Creates a {@link Classifications} object from a {@link ClassificationsProto.Classifications}
 | 
				
			||||||
 | 
					   * protobuf message.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param proto the {@link ClassificationsProto.Classifications} protobuf message to convert.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  public static Classifications createFromProto(ClassificationsProto.Classifications proto) {
 | 
				
			||||||
 | 
					    List<Category> categories = new ArrayList<>();
 | 
				
			||||||
 | 
					    for (ClassificationProto.Classification classificationProto :
 | 
				
			||||||
 | 
					        proto.getClassificationList().getClassificationList()) {
 | 
				
			||||||
 | 
					      categories.add(Category.createFromProto(classificationProto));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    Optional<String> headName =
 | 
				
			||||||
 | 
					        proto.hasHeadName() ? Optional.of(proto.getHeadName()) : Optional.empty();
 | 
				
			||||||
 | 
					    return create(categories, proto.getHeadIndex(), headName);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** A list of {@link Category} objects. */
 | 
				
			||||||
 | 
					  public abstract List<Category> categories();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * The index of the classifier head these entries refer to. This is useful for multi-head models.
 | 
					   * The index of the classifier head these entries refer to. This is useful for multi-head models.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public abstract int headIndex();
 | 
					  public abstract int headIndex();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** The name of the classifier head, which is the corresponding tensor metadata name. */
 | 
					  /** The optional name of the classifier head, which is the corresponding tensor metadata name. */
 | 
				
			||||||
  public abstract String headName();
 | 
					  public abstract Optional<String> headName();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -234,7 +234,7 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l
 | 
				
			||||||
            "//mediapipe/framework/formats:rect_java_proto_lite",
 | 
					            "//mediapipe/framework/formats:rect_java_proto_lite",
 | 
				
			||||||
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
 | 
					            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
 | 
				
			||||||
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
 | 
					            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
 | 
				
			||||||
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
 | 
					            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
				
			||||||
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
 | 
					            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
 | 
				
			||||||
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
 | 
					            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
 | 
				
			||||||
            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
					            "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,8 +37,8 @@ cc_library(
 | 
				
			||||||
android_library(
 | 
					android_library(
 | 
				
			||||||
    name = "textclassifier",
 | 
					    name = "textclassifier",
 | 
				
			||||||
    srcs = [
 | 
					    srcs = [
 | 
				
			||||||
        "textclassifier/TextClassificationResult.java",
 | 
					 | 
				
			||||||
        "textclassifier/TextClassifier.java",
 | 
					        "textclassifier/TextClassifier.java",
 | 
				
			||||||
 | 
					        "textclassifier/TextClassifierResult.java",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    javacopts = [
 | 
					    javacopts = [
 | 
				
			||||||
        "-Xep:AndroidJdkLibsChecker:OFF",
 | 
					        "-Xep:AndroidJdkLibsChecker:OFF",
 | 
				
			||||||
| 
						 | 
					@ -51,9 +51,7 @@ android_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,103 +0,0 @@
 | 
				
			||||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
					 | 
				
			||||||
// you may not use this file except in compliance with the License.
 | 
					 | 
				
			||||||
// You may obtain a copy of the License at
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Unless required by applicable law or agreed to in writing, software
 | 
					 | 
				
			||||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
					 | 
				
			||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
					 | 
				
			||||||
// See the License for the specific language governing permissions and
 | 
					 | 
				
			||||||
// limitations under the License.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
package com.google.mediapipe.tasks.text.textclassifier;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import com.google.auto.value.AutoValue;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.Category;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.Classifications;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.core.TaskResult;
 | 
					 | 
				
			||||||
import java.util.ArrayList;
 | 
					 | 
				
			||||||
import java.util.Collections;
 | 
					 | 
				
			||||||
import java.util.List;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/** Represents the classification results generated by {@link TextClassifier}. */
 | 
					 | 
				
			||||||
@AutoValue
 | 
					 | 
				
			||||||
public abstract class TextClassificationResult implements TaskResult {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Creates an {@link TextClassificationResult} instance from a {@link
 | 
					 | 
				
			||||||
   * ClassificationsProto.ClassificationResult} protobuf message.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
 | 
					 | 
				
			||||||
   *     message.
 | 
					 | 
				
			||||||
   * @param timestampMs a timestamp for this result.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  // TODO: consolidate output formats across platforms.
 | 
					 | 
				
			||||||
  static TextClassificationResult create(
 | 
					 | 
				
			||||||
      ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
 | 
					 | 
				
			||||||
    List<Classifications> classifications = new ArrayList<>();
 | 
					 | 
				
			||||||
    for (ClassificationsProto.Classifications classificationsProto :
 | 
					 | 
				
			||||||
        classificationResult.getClassificationsList()) {
 | 
					 | 
				
			||||||
      classifications.add(classificationsFromProto(classificationsProto));
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return new AutoValue_TextClassificationResult(
 | 
					 | 
				
			||||||
        timestampMs, Collections.unmodifiableList(classifications));
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  @Override
 | 
					 | 
				
			||||||
  public abstract long timestampMs();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /** Contains one set of results per classifier head. */
 | 
					 | 
				
			||||||
  @SuppressWarnings("AutoValueImmutableFields")
 | 
					 | 
				
			||||||
  public abstract List<Classifications> classifications();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param category the {@link CategoryProto.Category} protobuf message to convert.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  static Category categoryFromProto(CategoryProto.Category category) {
 | 
					 | 
				
			||||||
    return Category.create(
 | 
					 | 
				
			||||||
        category.getScore(),
 | 
					 | 
				
			||||||
        category.getIndex(),
 | 
					 | 
				
			||||||
        category.getCategoryName(),
 | 
					 | 
				
			||||||
        category.getDisplayName());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
 | 
					 | 
				
			||||||
   * ClassificationEntry} object.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  static ClassificationEntry classificationEntryFromProto(
 | 
					 | 
				
			||||||
      ClassificationsProto.ClassificationEntry entry) {
 | 
					 | 
				
			||||||
    List<Category> categories = new ArrayList<>();
 | 
					 | 
				
			||||||
    for (CategoryProto.Category category : entry.getCategoriesList()) {
 | 
					 | 
				
			||||||
      categories.add(categoryFromProto(category));
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return ClassificationEntry.create(categories, entry.getTimestampMs());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
 | 
					 | 
				
			||||||
   * Classifications} object.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
 | 
					 | 
				
			||||||
   *     convert.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  static Classifications classificationsFromProto(
 | 
					 | 
				
			||||||
      ClassificationsProto.Classifications classifications) {
 | 
					 | 
				
			||||||
    List<ClassificationEntry> entries = new ArrayList<>();
 | 
					 | 
				
			||||||
    for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
 | 
					 | 
				
			||||||
      entries.add(classificationEntryFromProto(entry));
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return Classifications.create(
 | 
					 | 
				
			||||||
        entries, classifications.getHeadIndex(), classifications.getHeadName());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					@ -22,6 +22,7 @@ import com.google.mediapipe.framework.MediaPipeException;
 | 
				
			||||||
import com.google.mediapipe.framework.Packet;
 | 
					import com.google.mediapipe.framework.Packet;
 | 
				
			||||||
import com.google.mediapipe.framework.PacketGetter;
 | 
					import com.google.mediapipe.framework.PacketGetter;
 | 
				
			||||||
import com.google.mediapipe.framework.ProtoUtil;
 | 
					import com.google.mediapipe.framework.ProtoUtil;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
					import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.BaseOptions;
 | 
					import com.google.mediapipe.tasks.core.BaseOptions;
 | 
				
			||||||
| 
						 | 
					@ -86,10 +87,9 @@ public final class TextClassifier implements AutoCloseable {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @SuppressWarnings("ConstantCaseForConstants")
 | 
					  @SuppressWarnings("ConstantCaseForConstants")
 | 
				
			||||||
  private static final List<String> OUTPUT_STREAMS =
 | 
					  private static final List<String> OUTPUT_STREAMS =
 | 
				
			||||||
      Collections.unmodifiableList(
 | 
					      Collections.unmodifiableList(Arrays.asList("CLASSIFICATIONS:classifications_out"));
 | 
				
			||||||
          Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
 | 
					  private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
 | 
				
			||||||
  private static final String TASK_GRAPH_NAME =
 | 
					  private static final String TASK_GRAPH_NAME =
 | 
				
			||||||
      "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
 | 
					      "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
 | 
				
			||||||
  private final TaskRunner runner;
 | 
					  private final TaskRunner runner;
 | 
				
			||||||
| 
						 | 
					@ -142,17 +142,18 @@ public final class TextClassifier implements AutoCloseable {
 | 
				
			||||||
   * @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
 | 
					   * @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
 | 
					  public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
 | 
				
			||||||
    OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>();
 | 
					    OutputHandler<TextClassifierResult, Void> handler = new OutputHandler<>();
 | 
				
			||||||
    handler.setOutputPacketConverter(
 | 
					    handler.setOutputPacketConverter(
 | 
				
			||||||
        new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() {
 | 
					        new OutputHandler.OutputPacketConverter<TextClassifierResult, Void>() {
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
          public TextClassificationResult convertToTaskResult(List<Packet> packets) {
 | 
					          public TextClassifierResult convertToTaskResult(List<Packet> packets) {
 | 
				
			||||||
            try {
 | 
					            try {
 | 
				
			||||||
              return TextClassificationResult.create(
 | 
					              return TextClassifierResult.create(
 | 
				
			||||||
 | 
					                  ClassificationResult.createFromProto(
 | 
				
			||||||
                      PacketGetter.getProto(
 | 
					                      PacketGetter.getProto(
 | 
				
			||||||
                      packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
 | 
					                          packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
 | 
				
			||||||
                      ClassificationsProto.ClassificationResult.getDefaultInstance()),
 | 
					                          ClassificationsProto.ClassificationResult.getDefaultInstance())),
 | 
				
			||||||
                  packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
 | 
					                  packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
 | 
				
			||||||
            } catch (IOException e) {
 | 
					            } catch (IOException e) {
 | 
				
			||||||
              throw new MediaPipeException(
 | 
					              throw new MediaPipeException(
 | 
				
			||||||
                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
 | 
					                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
 | 
				
			||||||
| 
						 | 
					@ -192,10 +193,10 @@ public final class TextClassifier implements AutoCloseable {
 | 
				
			||||||
   *
 | 
					   *
 | 
				
			||||||
   * @param inputText a {@link String} for processing.
 | 
					   * @param inputText a {@link String} for processing.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public TextClassificationResult classify(String inputText) {
 | 
					  public TextClassifierResult classify(String inputText) {
 | 
				
			||||||
    Map<String, Packet> inputPackets = new HashMap<>();
 | 
					    Map<String, Packet> inputPackets = new HashMap<>();
 | 
				
			||||||
    inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
 | 
					    inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
 | 
				
			||||||
    return (TextClassificationResult) runner.process(inputPackets);
 | 
					    return (TextClassifierResult) runner.process(inputPackets);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Closes and cleans up the {@link TextClassifier}. */
 | 
					  /** Closes and cleans up the {@link TextClassifier}. */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,55 @@
 | 
				
			||||||
 | 
					// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package com.google.mediapipe.tasks.text.textclassifier;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import com.google.auto.value.AutoValue;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.core.TaskResult;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Represents the classification results generated by {@link TextClassifier}. */
 | 
				
			||||||
 | 
					@AutoValue
 | 
				
			||||||
 | 
					public abstract class TextClassifierResult implements TaskResult {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Creates an {@link TextClassifierResult} instance.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param classificationResult the {@link ClassificationResult} object containing one set of
 | 
				
			||||||
 | 
					   *     results per classifier head.
 | 
				
			||||||
 | 
					   * @param timestampMs a timestamp for this result.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  static TextClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
 | 
				
			||||||
 | 
					    return new AutoValue_TextClassifierResult(classificationResult, timestampMs);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Creates an {@link TextClassifierResult} instance from a {@link
 | 
				
			||||||
 | 
					   * ClassificationsProto.ClassificationResult} protobuf message.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
 | 
				
			||||||
 | 
					   * @param timestampMs a timestamp for this result.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  // TODO: consolidate output formats across platforms.
 | 
				
			||||||
 | 
					  static TextClassifierResult createFromProto(
 | 
				
			||||||
 | 
					      ClassificationsProto.ClassificationResult proto, long timestampMs) {
 | 
				
			||||||
 | 
					    return create(ClassificationResult.createFromProto(proto), timestampMs);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Contains one set of results per classifier head. */
 | 
				
			||||||
 | 
					  public abstract ClassificationResult classificationResult();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Override
 | 
				
			||||||
 | 
					  public abstract long timestampMs();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -84,8 +84,8 @@ android_library(
 | 
				
			||||||
android_library(
 | 
					android_library(
 | 
				
			||||||
    name = "imageclassifier",
 | 
					    name = "imageclassifier",
 | 
				
			||||||
    srcs = [
 | 
					    srcs = [
 | 
				
			||||||
        "imageclassifier/ImageClassificationResult.java",
 | 
					 | 
				
			||||||
        "imageclassifier/ImageClassifier.java",
 | 
					        "imageclassifier/ImageClassifier.java",
 | 
				
			||||||
 | 
					        "imageclassifier/ImageClassifierResult.java",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    javacopts = [
 | 
					    javacopts = [
 | 
				
			||||||
        "-Xep:AndroidJdkLibsChecker:OFF",
 | 
					        "-Xep:AndroidJdkLibsChecker:OFF",
 | 
				
			||||||
| 
						 | 
					@ -100,9 +100,7 @@ android_library(
 | 
				
			||||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
 | 
					        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classificationresult",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
 | 
					 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
 | 
				
			||||||
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
					        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
 | 
				
			||||||
        "//third_party:autovalue",
 | 
					        "//third_party:autovalue",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,102 +0,0 @@
 | 
				
			||||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
					 | 
				
			||||||
// you may not use this file except in compliance with the License.
 | 
					 | 
				
			||||||
// You may obtain a copy of the License at
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
					 | 
				
			||||||
//
 | 
					 | 
				
			||||||
// Unless required by applicable law or agreed to in writing, software
 | 
					 | 
				
			||||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
					 | 
				
			||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
					 | 
				
			||||||
// See the License for the specific language governing permissions and
 | 
					 | 
				
			||||||
// limitations under the License.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
package com.google.mediapipe.tasks.vision.imageclassifier;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import com.google.auto.value.AutoValue;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.Category;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.Classifications;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
					 | 
				
			||||||
import com.google.mediapipe.tasks.core.TaskResult;
 | 
					 | 
				
			||||||
import java.util.ArrayList;
 | 
					 | 
				
			||||||
import java.util.Collections;
 | 
					 | 
				
			||||||
import java.util.List;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/** Represents the classification results generated by {@link ImageClassifier}. */
 | 
					 | 
				
			||||||
@AutoValue
 | 
					 | 
				
			||||||
public abstract class ImageClassificationResult implements TaskResult {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Creates an {@link ImageClassificationResult} instance from a {@link
 | 
					 | 
				
			||||||
   * ClassificationsProto.ClassificationResult} protobuf message.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
 | 
					 | 
				
			||||||
   *     message.
 | 
					 | 
				
			||||||
   * @param timestampMs a timestamp for this result.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  // TODO: consolidate output formats across platforms.
 | 
					 | 
				
			||||||
  static ImageClassificationResult create(
 | 
					 | 
				
			||||||
      ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
 | 
					 | 
				
			||||||
    List<Classifications> classifications = new ArrayList<>();
 | 
					 | 
				
			||||||
    for (ClassificationsProto.Classifications classificationsProto :
 | 
					 | 
				
			||||||
        classificationResult.getClassificationsList()) {
 | 
					 | 
				
			||||||
      classifications.add(classificationsFromProto(classificationsProto));
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return new AutoValue_ImageClassificationResult(
 | 
					 | 
				
			||||||
        timestampMs, Collections.unmodifiableList(classifications));
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  @Override
 | 
					 | 
				
			||||||
  public abstract long timestampMs();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /** Contains one set of results per classifier head. */
 | 
					 | 
				
			||||||
  public abstract List<Classifications> classifications();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param category the {@link CategoryProto.Category} protobuf message to convert.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  static Category categoryFromProto(CategoryProto.Category category) {
 | 
					 | 
				
			||||||
    return Category.create(
 | 
					 | 
				
			||||||
        category.getScore(),
 | 
					 | 
				
			||||||
        category.getIndex(),
 | 
					 | 
				
			||||||
        category.getCategoryName(),
 | 
					 | 
				
			||||||
        category.getDisplayName());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
 | 
					 | 
				
			||||||
   * ClassificationEntry} object.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  static ClassificationEntry classificationEntryFromProto(
 | 
					 | 
				
			||||||
      ClassificationsProto.ClassificationEntry entry) {
 | 
					 | 
				
			||||||
    List<Category> categories = new ArrayList<>();
 | 
					 | 
				
			||||||
    for (CategoryProto.Category category : entry.getCategoriesList()) {
 | 
					 | 
				
			||||||
      categories.add(categoryFromProto(category));
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return ClassificationEntry.create(categories, entry.getTimestampMs());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /**
 | 
					 | 
				
			||||||
   * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
 | 
					 | 
				
			||||||
   * Classifications} object.
 | 
					 | 
				
			||||||
   *
 | 
					 | 
				
			||||||
   * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
 | 
					 | 
				
			||||||
   *     convert.
 | 
					 | 
				
			||||||
   */
 | 
					 | 
				
			||||||
  static Classifications classificationsFromProto(
 | 
					 | 
				
			||||||
      ClassificationsProto.Classifications classifications) {
 | 
					 | 
				
			||||||
    List<ClassificationEntry> entries = new ArrayList<>();
 | 
					 | 
				
			||||||
    for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
 | 
					 | 
				
			||||||
      entries.add(classificationEntryFromProto(entry));
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    return Classifications.create(
 | 
					 | 
				
			||||||
        entries, classifications.getHeadIndex(), classifications.getHeadName());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					@ -25,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter;
 | 
				
			||||||
import com.google.mediapipe.framework.ProtoUtil;
 | 
					import com.google.mediapipe.framework.ProtoUtil;
 | 
				
			||||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
 | 
					import com.google.mediapipe.framework.image.BitmapImageBuilder;
 | 
				
			||||||
import com.google.mediapipe.framework.image.MPImage;
 | 
					import com.google.mediapipe.framework.image.MPImage;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
					import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
 | 
				
			||||||
import com.google.mediapipe.tasks.core.BaseOptions;
 | 
					import com.google.mediapipe.tasks.core.BaseOptions;
 | 
				
			||||||
| 
						 | 
					@ -96,8 +97,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
					          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
 | 
				
			||||||
  private static final List<String> OUTPUT_STREAMS =
 | 
					  private static final List<String> OUTPUT_STREAMS =
 | 
				
			||||||
      Collections.unmodifiableList(
 | 
					      Collections.unmodifiableList(
 | 
				
			||||||
          Arrays.asList("CLASSIFICATION_RESULT:classification_result_out", "IMAGE:image_out"));
 | 
					          Arrays.asList("CLASSIFICATIONS:classifications_out", "IMAGE:image_out"));
 | 
				
			||||||
  private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
 | 
					  private static final int CLASSIFICATIONS_OUT_STREAM_INDEX = 0;
 | 
				
			||||||
  private static final int IMAGE_OUT_STREAM_INDEX = 1;
 | 
					  private static final int IMAGE_OUT_STREAM_INDEX = 1;
 | 
				
			||||||
  private static final String TASK_GRAPH_NAME =
 | 
					  private static final String TASK_GRAPH_NAME =
 | 
				
			||||||
      "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
 | 
					      "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
 | 
				
			||||||
| 
						 | 
					@ -164,17 +165,18 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
   * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
 | 
					   * @throws MediaPipeException if there is an error during {@link ImageClassifier} creation.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
 | 
					  public static ImageClassifier createFromOptions(Context context, ImageClassifierOptions options) {
 | 
				
			||||||
    OutputHandler<ImageClassificationResult, MPImage> handler = new OutputHandler<>();
 | 
					    OutputHandler<ImageClassifierResult, MPImage> handler = new OutputHandler<>();
 | 
				
			||||||
    handler.setOutputPacketConverter(
 | 
					    handler.setOutputPacketConverter(
 | 
				
			||||||
        new OutputHandler.OutputPacketConverter<ImageClassificationResult, MPImage>() {
 | 
					        new OutputHandler.OutputPacketConverter<ImageClassifierResult, MPImage>() {
 | 
				
			||||||
          @Override
 | 
					          @Override
 | 
				
			||||||
          public ImageClassificationResult convertToTaskResult(List<Packet> packets) {
 | 
					          public ImageClassifierResult convertToTaskResult(List<Packet> packets) {
 | 
				
			||||||
            try {
 | 
					            try {
 | 
				
			||||||
              return ImageClassificationResult.create(
 | 
					              return ImageClassifierResult.create(
 | 
				
			||||||
 | 
					                  ClassificationResult.createFromProto(
 | 
				
			||||||
                      PacketGetter.getProto(
 | 
					                      PacketGetter.getProto(
 | 
				
			||||||
                      packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
 | 
					                          packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX),
 | 
				
			||||||
                      ClassificationsProto.ClassificationResult.getDefaultInstance()),
 | 
					                          ClassificationsProto.ClassificationResult.getDefaultInstance())),
 | 
				
			||||||
                  packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
 | 
					                  packets.get(CLASSIFICATIONS_OUT_STREAM_INDEX).getTimestamp());
 | 
				
			||||||
            } catch (IOException e) {
 | 
					            } catch (IOException e) {
 | 
				
			||||||
              throw new MediaPipeException(
 | 
					              throw new MediaPipeException(
 | 
				
			||||||
                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
 | 
					                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
 | 
				
			||||||
| 
						 | 
					@ -229,7 +231,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
   * @param image a MediaPipe {@link MPImage} object for processing.
 | 
					   * @param image a MediaPipe {@link MPImage} object for processing.
 | 
				
			||||||
   * @throws MediaPipeException if there is an internal error.
 | 
					   * @throws MediaPipeException if there is an internal error.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public ImageClassificationResult classify(MPImage image) {
 | 
					  public ImageClassifierResult classify(MPImage image) {
 | 
				
			||||||
    return classify(image, ImageProcessingOptions.builder().build());
 | 
					    return classify(image, ImageProcessingOptions.builder().build());
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -248,9 +250,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
   *     input image before running inference.
 | 
					   *     input image before running inference.
 | 
				
			||||||
   * @throws MediaPipeException if there is an internal error.
 | 
					   * @throws MediaPipeException if there is an internal error.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public ImageClassificationResult classify(
 | 
					  public ImageClassifierResult classify(
 | 
				
			||||||
      MPImage image, ImageProcessingOptions imageProcessingOptions) {
 | 
					      MPImage image, ImageProcessingOptions imageProcessingOptions) {
 | 
				
			||||||
    return (ImageClassificationResult) processImageData(image, imageProcessingOptions);
 | 
					    return (ImageClassifierResult) processImageData(image, imageProcessingOptions);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
| 
						 | 
					@ -271,7 +273,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
   * @param timestampMs the input timestamp (in milliseconds).
 | 
					   * @param timestampMs the input timestamp (in milliseconds).
 | 
				
			||||||
   * @throws MediaPipeException if there is an internal error.
 | 
					   * @throws MediaPipeException if there is an internal error.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public ImageClassificationResult classifyForVideo(MPImage image, long timestampMs) {
 | 
					  public ImageClassifierResult classifyForVideo(MPImage image, long timestampMs) {
 | 
				
			||||||
    return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
 | 
					    return classifyForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -294,9 +296,9 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
   * @param timestampMs the input timestamp (in milliseconds).
 | 
					   * @param timestampMs the input timestamp (in milliseconds).
 | 
				
			||||||
   * @throws MediaPipeException if there is an internal error.
 | 
					   * @throws MediaPipeException if there is an internal error.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  public ImageClassificationResult classifyForVideo(
 | 
					  public ImageClassifierResult classifyForVideo(
 | 
				
			||||||
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
 | 
					      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
 | 
				
			||||||
    return (ImageClassificationResult) processVideoData(image, imageProcessingOptions, timestampMs);
 | 
					    return (ImageClassifierResult) processVideoData(image, imageProcessingOptions, timestampMs);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
| 
						 | 
					@ -383,7 +385,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
       * the image classifier is in the live stream mode.
 | 
					       * the image classifier is in the live stream mode.
 | 
				
			||||||
       */
 | 
					       */
 | 
				
			||||||
      public abstract Builder setResultListener(
 | 
					      public abstract Builder setResultListener(
 | 
				
			||||||
          ResultListener<ImageClassificationResult, MPImage> resultListener);
 | 
					          ResultListener<ImageClassifierResult, MPImage> resultListener);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      /** Sets an optional {@link ErrorListener}. */
 | 
					      /** Sets an optional {@link ErrorListener}. */
 | 
				
			||||||
      public abstract Builder setErrorListener(ErrorListener errorListener);
 | 
					      public abstract Builder setErrorListener(ErrorListener errorListener);
 | 
				
			||||||
| 
						 | 
					@ -420,7 +422,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ClassifierOptions> classifierOptions();
 | 
					    abstract Optional<ClassifierOptions> classifierOptions();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ResultListener<ImageClassificationResult, MPImage>> resultListener();
 | 
					    abstract Optional<ResultListener<ImageClassifierResult, MPImage>> resultListener();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    abstract Optional<ErrorListener> errorListener();
 | 
					    abstract Optional<ErrorListener> errorListener();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,55 @@
 | 
				
			||||||
 | 
					// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package com.google.mediapipe.tasks.vision.imageclassifier;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import com.google.auto.value.AutoValue;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.ClassificationResult;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
 | 
				
			||||||
 | 
					import com.google.mediapipe.tasks.core.TaskResult;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/** Represents the classification results generated by {@link ImageClassifier}. */
 | 
				
			||||||
 | 
					@AutoValue
 | 
				
			||||||
 | 
					public abstract class ImageClassifierResult implements TaskResult {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Creates an {@link ImageClassifierResult} instance.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param classificationResult the {@link ClassificationResult} object containing one set of
 | 
				
			||||||
 | 
					   *     results per classifier head.
 | 
				
			||||||
 | 
					   * @param timestampMs a timestamp for this result.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  static ImageClassifierResult create(ClassificationResult classificationResult, long timestampMs) {
 | 
				
			||||||
 | 
					    return new AutoValue_ImageClassifierResult(classificationResult, timestampMs);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Creates an {@link ImageClassifierResult} instance from a {@link
 | 
				
			||||||
 | 
					   * ClassificationsProto.ClassificationResult} protobuf message.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param proto the {@link ClassificationsProto.ClassificationResult} protobuf message to convert.
 | 
				
			||||||
 | 
					   * @param timestampMs a timestamp for this result.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  // TODO: consolidate output formats across platforms.
 | 
				
			||||||
 | 
					  static ImageClassifierResult createFromProto(
 | 
				
			||||||
 | 
					      ClassificationsProto.ClassificationResult proto, long timestampMs) {
 | 
				
			||||||
 | 
					    return create(ClassificationResult.createFromProto(proto), timestampMs);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Contains one set of results per classifier head. */
 | 
				
			||||||
 | 
					  public abstract ClassificationResult classificationResult();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  @Override
 | 
				
			||||||
 | 
					  public abstract long timestampMs();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -76,7 +76,7 @@ public class TextClassifierTest {
 | 
				
			||||||
  public void classify_succeedsWithBert() throws Exception {
 | 
					  public void classify_succeedsWithBert() throws Exception {
 | 
				
			||||||
    TextClassifier textClassifier =
 | 
					    TextClassifier textClassifier =
 | 
				
			||||||
        TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
 | 
					        TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
 | 
				
			||||||
    TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
 | 
					    TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
 | 
				
			||||||
    assertHasOneHead(negativeResults);
 | 
					    assertHasOneHead(negativeResults);
 | 
				
			||||||
    assertCategoriesAre(
 | 
					    assertCategoriesAre(
 | 
				
			||||||
        negativeResults,
 | 
					        negativeResults,
 | 
				
			||||||
| 
						 | 
					@ -84,7 +84,7 @@ public class TextClassifierTest {
 | 
				
			||||||
            Category.create(0.95630914f, 0, "negative", ""),
 | 
					            Category.create(0.95630914f, 0, "negative", ""),
 | 
				
			||||||
            Category.create(0.04369091f, 1, "positive", "")));
 | 
					            Category.create(0.04369091f, 1, "positive", "")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
 | 
					    TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
 | 
				
			||||||
    assertHasOneHead(positiveResults);
 | 
					    assertHasOneHead(positiveResults);
 | 
				
			||||||
    assertCategoriesAre(
 | 
					    assertCategoriesAre(
 | 
				
			||||||
        positiveResults,
 | 
					        positiveResults,
 | 
				
			||||||
| 
						 | 
					@ -99,7 +99,7 @@ public class TextClassifierTest {
 | 
				
			||||||
        TextClassifier.createFromFile(
 | 
					        TextClassifier.createFromFile(
 | 
				
			||||||
            ApplicationProvider.getApplicationContext(),
 | 
					            ApplicationProvider.getApplicationContext(),
 | 
				
			||||||
            TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
 | 
					            TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
 | 
				
			||||||
    TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
 | 
					    TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
 | 
				
			||||||
    assertHasOneHead(negativeResults);
 | 
					    assertHasOneHead(negativeResults);
 | 
				
			||||||
    assertCategoriesAre(
 | 
					    assertCategoriesAre(
 | 
				
			||||||
        negativeResults,
 | 
					        negativeResults,
 | 
				
			||||||
| 
						 | 
					@ -107,7 +107,7 @@ public class TextClassifierTest {
 | 
				
			||||||
            Category.create(0.95630914f, 0, "negative", ""),
 | 
					            Category.create(0.95630914f, 0, "negative", ""),
 | 
				
			||||||
            Category.create(0.04369091f, 1, "positive", "")));
 | 
					            Category.create(0.04369091f, 1, "positive", "")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
 | 
					    TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
 | 
				
			||||||
    assertHasOneHead(positiveResults);
 | 
					    assertHasOneHead(positiveResults);
 | 
				
			||||||
    assertHasOneHead(positiveResults);
 | 
					    assertHasOneHead(positiveResults);
 | 
				
			||||||
    assertCategoriesAre(
 | 
					    assertCategoriesAre(
 | 
				
			||||||
| 
						 | 
					@ -122,7 +122,7 @@ public class TextClassifierTest {
 | 
				
			||||||
    TextClassifier textClassifier =
 | 
					    TextClassifier textClassifier =
 | 
				
			||||||
        TextClassifier.createFromFile(
 | 
					        TextClassifier.createFromFile(
 | 
				
			||||||
            ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
 | 
					            ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
 | 
				
			||||||
    TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
 | 
					    TextClassifierResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
 | 
				
			||||||
    assertHasOneHead(negativeResults);
 | 
					    assertHasOneHead(negativeResults);
 | 
				
			||||||
    assertCategoriesAre(
 | 
					    assertCategoriesAre(
 | 
				
			||||||
        negativeResults,
 | 
					        negativeResults,
 | 
				
			||||||
| 
						 | 
					@ -130,7 +130,7 @@ public class TextClassifierTest {
 | 
				
			||||||
            Category.create(0.6647746f, 0, "Negative", ""),
 | 
					            Category.create(0.6647746f, 0, "Negative", ""),
 | 
				
			||||||
            Category.create(0.33522537f, 1, "Positive", "")));
 | 
					            Category.create(0.33522537f, 1, "Positive", "")));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
 | 
					    TextClassifierResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
 | 
				
			||||||
    assertHasOneHead(positiveResults);
 | 
					    assertHasOneHead(positiveResults);
 | 
				
			||||||
    assertCategoriesAre(
 | 
					    assertCategoriesAre(
 | 
				
			||||||
        positiveResults,
 | 
					        positiveResults,
 | 
				
			||||||
| 
						 | 
					@ -139,16 +139,15 @@ public class TextClassifierTest {
 | 
				
			||||||
            Category.create(0.48799595f, 1, "Positive", "")));
 | 
					            Category.create(0.48799595f, 1, "Positive", "")));
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static void assertHasOneHead(TextClassificationResult results) {
 | 
					  private static void assertHasOneHead(TextClassifierResult results) {
 | 
				
			||||||
    assertThat(results.classifications()).hasSize(1);
 | 
					    assertThat(results.classificationResult().classifications()).hasSize(1);
 | 
				
			||||||
    assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
 | 
					    assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
 | 
				
			||||||
    assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
 | 
					    assertThat(results.classificationResult().classifications().get(0).headName().get())
 | 
				
			||||||
    assertThat(results.classifications().get(0).entries()).hasSize(1);
 | 
					        .isEqualTo("probability");
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static void assertCategoriesAre(
 | 
					  private static void assertCategoriesAre(TextClassifierResult results, List<Category> categories) {
 | 
				
			||||||
      TextClassificationResult results, List<Category> categories) {
 | 
					    assertThat(results.classificationResult().classifications().get(0).categories())
 | 
				
			||||||
    assertThat(results.classifications().get(0).entries().get(0).categories())
 | 
					 | 
				
			||||||
        .isEqualTo(categories);
 | 
					        .isEqualTo(categories);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -91,11 +91,12 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromFile(
 | 
					          ImageClassifier.createFromFile(
 | 
				
			||||||
              ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
 | 
					              ApplicationProvider.getApplicationContext(), FLOAT_MODEL_FILE);
 | 
				
			||||||
      ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
					      ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertThat(results.classifications().get(0).entries().get(0).categories()).hasSize(1001);
 | 
					      assertThat(results.classificationResult().classifications().get(0).categories())
 | 
				
			||||||
      assertThat(results.classifications().get(0).entries().get(0).categories().get(0))
 | 
					          .hasSize(1001);
 | 
				
			||||||
 | 
					      assertThat(results.classificationResult().classifications().get(0).categories().get(0))
 | 
				
			||||||
          .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
 | 
					          .isEqualTo(Category.create(0.7952058f, 934, "cheeseburger", ""));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -108,9 +109,9 @@ public class ImageClassifierTest {
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
					      ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results,
 | 
					          results,
 | 
				
			||||||
          Arrays.asList(
 | 
					          Arrays.asList(
 | 
				
			||||||
| 
						 | 
					@ -128,9 +129,9 @@ public class ImageClassifierTest {
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
					      ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
 | 
					          results, Arrays.asList(Category.create(0.97265625f, 934, "cheeseburger", "")));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -144,9 +145,9 @@ public class ImageClassifierTest {
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
					      ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results,
 | 
					          results,
 | 
				
			||||||
          Arrays.asList(
 | 
					          Arrays.asList(
 | 
				
			||||||
| 
						 | 
					@ -166,9 +167,9 @@ public class ImageClassifierTest {
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
					      ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results,
 | 
					          results,
 | 
				
			||||||
          Arrays.asList(
 | 
					          Arrays.asList(
 | 
				
			||||||
| 
						 | 
					@ -190,9 +191,9 @@ public class ImageClassifierTest {
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
					      ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results,
 | 
					          results,
 | 
				
			||||||
          Arrays.asList(
 | 
					          Arrays.asList(
 | 
				
			||||||
| 
						 | 
					@ -214,10 +215,10 @@ public class ImageClassifierTest {
 | 
				
			||||||
      RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
 | 
					      RectF roi = new RectF(0.450f, 0.308f, 0.614f, 0.734f);
 | 
				
			||||||
      ImageProcessingOptions imageProcessingOptions =
 | 
					      ImageProcessingOptions imageProcessingOptions =
 | 
				
			||||||
          ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
 | 
					          ImageProcessingOptions.builder().setRegionOfInterest(roi).build();
 | 
				
			||||||
      ImageClassificationResult results =
 | 
					      ImageClassifierResult results =
 | 
				
			||||||
          imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
 | 
					          imageClassifier.classify(getImageFromAsset(MULTI_OBJECTS_IMAGE), imageProcessingOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
 | 
					          results, Arrays.asList(Category.create(0.9969325f, 806, "soccer ball", "")));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -233,10 +234,10 @@ public class ImageClassifierTest {
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageProcessingOptions imageProcessingOptions =
 | 
					      ImageProcessingOptions imageProcessingOptions =
 | 
				
			||||||
          ImageProcessingOptions.builder().setRotationDegrees(-90).build();
 | 
					          ImageProcessingOptions.builder().setRotationDegrees(-90).build();
 | 
				
			||||||
      ImageClassificationResult results =
 | 
					      ImageClassifierResult results =
 | 
				
			||||||
          imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
 | 
					          imageClassifier.classify(getImageFromAsset(BURGER_ROTATED_IMAGE), imageProcessingOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results,
 | 
					          results,
 | 
				
			||||||
          Arrays.asList(
 | 
					          Arrays.asList(
 | 
				
			||||||
| 
						 | 
					@ -258,11 +259,11 @@ public class ImageClassifierTest {
 | 
				
			||||||
      RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
 | 
					      RectF roi = new RectF(0.0f, 0.1763f, 0.5642f, 0.3049f);
 | 
				
			||||||
      ImageProcessingOptions imageProcessingOptions =
 | 
					      ImageProcessingOptions imageProcessingOptions =
 | 
				
			||||||
          ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
 | 
					          ImageProcessingOptions.builder().setRegionOfInterest(roi).setRotationDegrees(-90).build();
 | 
				
			||||||
      ImageClassificationResult results =
 | 
					      ImageClassifierResult results =
 | 
				
			||||||
          imageClassifier.classify(
 | 
					          imageClassifier.classify(
 | 
				
			||||||
              getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
 | 
					              getImageFromAsset(MULTI_OBJECTS_ROTATED_IMAGE), imageProcessingOptions);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
 | 
					          results, Arrays.asList(Category.create(0.686824f, 560, "folding chair", "")));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -391,9 +392,9 @@ public class ImageClassifierTest {
 | 
				
			||||||
              .build();
 | 
					              .build();
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      ImageClassificationResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
					      ImageClassifierResult results = imageClassifier.classify(getImageFromAsset(BURGER_IMAGE));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      assertHasOneHeadAndOneTimestamp(results, 0);
 | 
					      assertHasOneHead(results);
 | 
				
			||||||
      assertCategoriesAre(
 | 
					      assertCategoriesAre(
 | 
				
			||||||
          results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
 | 
					          results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
| 
						 | 
					@ -410,9 +411,8 @@ public class ImageClassifierTest {
 | 
				
			||||||
      ImageClassifier imageClassifier =
 | 
					      ImageClassifier imageClassifier =
 | 
				
			||||||
          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
					          ImageClassifier.createFromOptions(ApplicationProvider.getApplicationContext(), options);
 | 
				
			||||||
      for (int i = 0; i < 3; i++) {
 | 
					      for (int i = 0; i < 3; i++) {
 | 
				
			||||||
        ImageClassificationResult results =
 | 
					        ImageClassifierResult results = imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
 | 
				
			||||||
            imageClassifier.classifyForVideo(image, /*timestampMs=*/ i);
 | 
					        assertHasOneHead(results);
 | 
				
			||||||
        assertHasOneHeadAndOneTimestamp(results, i);
 | 
					 | 
				
			||||||
        assertCategoriesAre(
 | 
					        assertCategoriesAre(
 | 
				
			||||||
            results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
 | 
					            results, Arrays.asList(Category.create(0.7952058f, 934, "cheeseburger", "")));
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
| 
						 | 
					@ -478,24 +478,17 @@ public class ImageClassifierTest {
 | 
				
			||||||
    return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
 | 
					    return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static void assertHasOneHeadAndOneTimestamp(
 | 
					  private static void assertHasOneHead(ImageClassifierResult results) {
 | 
				
			||||||
      ImageClassificationResult results, long timestampMs) {
 | 
					    assertThat(results.classificationResult().classifications()).hasSize(1);
 | 
				
			||||||
    assertThat(results.classifications()).hasSize(1);
 | 
					    assertThat(results.classificationResult().classifications().get(0).headIndex()).isEqualTo(0);
 | 
				
			||||||
    assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
 | 
					    assertThat(results.classificationResult().classifications().get(0).headName().get())
 | 
				
			||||||
    assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
 | 
					        .isEqualTo("probability");
 | 
				
			||||||
    assertThat(results.classifications().get(0).entries()).hasSize(1);
 | 
					 | 
				
			||||||
    assertThat(results.classifications().get(0).entries().get(0).timestampMs())
 | 
					 | 
				
			||||||
        .isEqualTo(timestampMs);
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static void assertCategoriesAre(
 | 
					  private static void assertCategoriesAre(
 | 
				
			||||||
      ImageClassificationResult results, List<Category> categories) {
 | 
					      ImageClassifierResult results, List<Category> categories) {
 | 
				
			||||||
    assertThat(results.classifications().get(0).entries().get(0).categories())
 | 
					    assertThat(results.classificationResult().classifications().get(0).categories())
 | 
				
			||||||
        .hasSize(categories.size());
 | 
					        .isEqualTo(categories);
 | 
				
			||||||
    for (int i = 0; i < categories.size(); i++) {
 | 
					 | 
				
			||||||
      assertThat(results.classifications().get(0).entries().get(0).categories().get(i))
 | 
					 | 
				
			||||||
          .isEqualTo(categories.get(i));
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private static void assertImageSizeIsExpected(MPImage inputImage) {
 | 
					  private static void assertImageSizeIsExpected(MPImage inputImage) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -174,12 +174,10 @@ class AudioClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      self.assertIsInstance(classifier, _AudioClassifier)
 | 
					      self.assertIsInstance(classifier, _AudioClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
					  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
				
			||||||
    # Invalid empty model path.
 | 
					 | 
				
			||||||
    with self.assertRaisesRegex(
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
        ValueError,
 | 
					        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
 | 
				
			||||||
        r"ExternalFile must specify at least one of 'file_content', "
 | 
					      base_options = _BaseOptions(
 | 
				
			||||||
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
 | 
					          model_asset_path='/path/to/invalid/model.tflite')
 | 
				
			||||||
      base_options = _BaseOptions(model_asset_path='')
 | 
					 | 
				
			||||||
      options = _AudioClassifierOptions(base_options=base_options)
 | 
					      options = _AudioClassifierOptions(base_options=base_options)
 | 
				
			||||||
      _AudioClassifier.create_from_options(options)
 | 
					      _AudioClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -154,12 +154,10 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      self.assertIsInstance(classifier, _TextClassifier)
 | 
					      self.assertIsInstance(classifier, _TextClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
					  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
				
			||||||
    # Invalid empty model path.
 | 
					 | 
				
			||||||
    with self.assertRaisesRegex(
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
        ValueError,
 | 
					        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
 | 
				
			||||||
        r"ExternalFile must specify at least one of 'file_content', "
 | 
					      base_options = _BaseOptions(
 | 
				
			||||||
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
 | 
					          model_asset_path='/path/to/invalid/model.tflite')
 | 
				
			||||||
      base_options = _BaseOptions(model_asset_path='')
 | 
					 | 
				
			||||||
      options = _TextClassifierOptions(base_options=base_options)
 | 
					      options = _TextClassifierOptions(base_options=base_options)
 | 
				
			||||||
      _TextClassifier.create_from_options(options)
 | 
					      _TextClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -147,12 +147,10 @@ class ImageClassifierTest(parameterized.TestCase):
 | 
				
			||||||
      self.assertIsInstance(classifier, _ImageClassifier)
 | 
					      self.assertIsInstance(classifier, _ImageClassifier)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
					  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
				
			||||||
    # Invalid empty model path.
 | 
					 | 
				
			||||||
    with self.assertRaisesRegex(
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
        ValueError,
 | 
					        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
 | 
				
			||||||
        r"ExternalFile must specify at least one of 'file_content', "
 | 
					      base_options = _BaseOptions(
 | 
				
			||||||
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
 | 
					          model_asset_path='/path/to/invalid/model.tflite')
 | 
				
			||||||
      base_options = _BaseOptions(model_asset_path='')
 | 
					 | 
				
			||||||
      options = _ImageClassifierOptions(base_options=base_options)
 | 
					      options = _ImageClassifierOptions(base_options=base_options)
 | 
				
			||||||
      _ImageClassifier.create_from_options(options)
 | 
					      _ImageClassifier.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -97,12 +97,10 @@ class ImageSegmenterTest(parameterized.TestCase):
 | 
				
			||||||
      self.assertIsInstance(segmenter, _ImageSegmenter)
 | 
					      self.assertIsInstance(segmenter, _ImageSegmenter)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
					  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
				
			||||||
    # Invalid empty model path.
 | 
					 | 
				
			||||||
    with self.assertRaisesRegex(
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
        ValueError,
 | 
					        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
 | 
				
			||||||
        r"ExternalFile must specify at least one of 'file_content', "
 | 
					      base_options = _BaseOptions(
 | 
				
			||||||
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
 | 
					          model_asset_path='/path/to/invalid/model.tflite')
 | 
				
			||||||
      base_options = _BaseOptions(model_asset_path='')
 | 
					 | 
				
			||||||
      options = _ImageSegmenterOptions(base_options=base_options)
 | 
					      options = _ImageSegmenterOptions(base_options=base_options)
 | 
				
			||||||
      _ImageSegmenter.create_from_options(options)
 | 
					      _ImageSegmenter.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -119,12 +119,10 @@ class ObjectDetectorTest(parameterized.TestCase):
 | 
				
			||||||
      self.assertIsInstance(detector, _ObjectDetector)
 | 
					      self.assertIsInstance(detector, _ObjectDetector)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
					  def test_create_from_options_fails_with_invalid_model_path(self):
 | 
				
			||||||
    # Invalid empty model path.
 | 
					 | 
				
			||||||
    with self.assertRaisesRegex(
 | 
					    with self.assertRaisesRegex(
 | 
				
			||||||
        ValueError,
 | 
					        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
 | 
				
			||||||
        r"ExternalFile must specify at least one of 'file_content', "
 | 
					      base_options = _BaseOptions(
 | 
				
			||||||
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
 | 
					          model_asset_path='/path/to/invalid/model.tflite')
 | 
				
			||||||
      base_options = _BaseOptions(model_asset_path='')
 | 
					 | 
				
			||||||
      options = _ObjectDetectorOptions(base_options=base_options)
 | 
					      options = _ObjectDetectorOptions(base_options=base_options)
 | 
				
			||||||
      _ObjectDetector.create_from_options(options)
 | 
					      _ObjectDetector.create_from_options(options)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										6
									
								
								mediapipe/tasks/testdata/text/BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								mediapipe/tasks/testdata/text/BUILD
									
									
									
									
										vendored
									
									
								
							| 
						 | 
					@ -28,6 +28,7 @@ mediapipe_files(srcs = [
 | 
				
			||||||
    "bert_text_classifier.tflite",
 | 
					    "bert_text_classifier.tflite",
 | 
				
			||||||
    "mobilebert_embedding_with_metadata.tflite",
 | 
					    "mobilebert_embedding_with_metadata.tflite",
 | 
				
			||||||
    "mobilebert_with_metadata.tflite",
 | 
					    "mobilebert_with_metadata.tflite",
 | 
				
			||||||
 | 
					    "regex_one_embedding_with_metadata.tflite",
 | 
				
			||||||
    "test_model_text_classifier_bool_output.tflite",
 | 
					    "test_model_text_classifier_bool_output.tflite",
 | 
				
			||||||
    "test_model_text_classifier_with_regex_tokenizer.tflite",
 | 
					    "test_model_text_classifier_with_regex_tokenizer.tflite",
 | 
				
			||||||
    "universal_sentence_encoder_qa_with_metadata.tflite",
 | 
					    "universal_sentence_encoder_qa_with_metadata.tflite",
 | 
				
			||||||
| 
						 | 
					@ -92,6 +93,11 @@ filegroup(
 | 
				
			||||||
    srcs = ["mobilebert_embedding_with_metadata.tflite"],
 | 
					    srcs = ["mobilebert_embedding_with_metadata.tflite"],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					filegroup(
 | 
				
			||||||
 | 
					    name = "regex_embedding_with_metadata",
 | 
				
			||||||
 | 
					    srcs = ["regex_one_embedding_with_metadata.tflite"],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
filegroup(
 | 
					filegroup(
 | 
				
			||||||
    name = "universal_sentence_encoder_qa",
 | 
					    name = "universal_sentence_encoder_qa",
 | 
				
			||||||
    data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
 | 
					    data = ["universal_sentence_encoder_qa_with_metadata.tflite"],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -119,8 +119,8 @@ export class AudioClassifier extends TaskRunner {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  async setOptions(options: AudioClassifierOptions): Promise<void> {
 | 
					  async setOptions(options: AudioClassifierOptions): Promise<void> {
 | 
				
			||||||
    if (options.baseOptions) {
 | 
					    if (options.baseOptions) {
 | 
				
			||||||
      const baseOptionsProto =
 | 
					      const baseOptionsProto = await convertBaseOptionsToProto(
 | 
				
			||||||
          await convertBaseOptionsToProto(options.baseOptions);
 | 
					          options.baseOptions, this.options.getBaseOptions());
 | 
				
			||||||
      this.options.setBaseOptions(baseOptionsProto);
 | 
					      this.options.setBaseOptions(baseOptionsProto);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -26,6 +26,8 @@ mediapipe_ts_library(
 | 
				
			||||||
    name = "base_options",
 | 
					    name = "base_options",
 | 
				
			||||||
    srcs = ["base_options.ts"],
 | 
					    srcs = ["base_options.ts"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
 | 
					        "//mediapipe/calculators/tensor:inference_calculator_jspb_proto",
 | 
				
			||||||
 | 
					        "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
 | 
					        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
 | 
				
			||||||
        "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
 | 
					        "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto",
 | 
				
			||||||
        "//mediapipe/tasks/web/core",
 | 
					        "//mediapipe/tasks/web/core",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,6 +14,8 @@
 | 
				
			||||||
 * limitations under the License.
 | 
					 * limitations under the License.
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import {InferenceCalculatorOptions} from '../../../../calculators/tensor/inference_calculator_pb';
 | 
				
			||||||
 | 
					import {Acceleration} from '../../../../tasks/cc/core/proto/acceleration_pb';
 | 
				
			||||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 | 
					import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
 | 
				
			||||||
import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
 | 
					import {ExternalFile} from '../../../../tasks/cc/core/proto/external_file_pb';
 | 
				
			||||||
import {BaseOptions} from '../../../../tasks/web/core/base_options';
 | 
					import {BaseOptions} from '../../../../tasks/web/core/base_options';
 | 
				
			||||||
| 
						 | 
					@ -25,26 +27,60 @@ import {BaseOptions} from '../../../../tasks/web/core/base_options';
 | 
				
			||||||
 * Converts a BaseOptions API object to its Protobuf representation.
 | 
					 * Converts a BaseOptions API object to its Protobuf representation.
 | 
				
			||||||
 * @throws If neither a model assset path or buffer is provided
 | 
					 * @throws If neither a model assset path or buffer is provided
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
export async function convertBaseOptionsToProto(baseOptions: BaseOptions):
 | 
					export async function convertBaseOptionsToProto(
 | 
				
			||||||
    Promise<BaseOptionsProto> {
 | 
					    updatedOptions: BaseOptions,
 | 
				
			||||||
  if (baseOptions.modelAssetPath && baseOptions.modelAssetBuffer) {
 | 
					    currentOptions?: BaseOptionsProto): Promise<BaseOptionsProto> {
 | 
				
			||||||
 | 
					  const result =
 | 
				
			||||||
 | 
					      currentOptions ? currentOptions.clone() : new BaseOptionsProto();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  await configureExternalFile(updatedOptions, result);
 | 
				
			||||||
 | 
					  configureAcceleration(updatedOptions, result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * Configues the `externalFile` option and validates that a single model is
 | 
				
			||||||
 | 
					 * provided.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					async function configureExternalFile(
 | 
				
			||||||
 | 
					    options: BaseOptions, proto: BaseOptionsProto) {
 | 
				
			||||||
 | 
					  const externalFile = proto.getModelAsset() || new ExternalFile();
 | 
				
			||||||
 | 
					  proto.setModelAsset(externalFile);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (options.modelAssetPath || options.modelAssetBuffer) {
 | 
				
			||||||
 | 
					    if (options.modelAssetPath && options.modelAssetBuffer) {
 | 
				
			||||||
      throw new Error(
 | 
					      throw new Error(
 | 
				
			||||||
          'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
 | 
					          'Cannot set both baseOptions.modelAssetPath and baseOptions.modelAssetBuffer');
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  if (!baseOptions.modelAssetPath && !baseOptions.modelAssetBuffer) {
 | 
					
 | 
				
			||||||
 | 
					    let modelAssetBuffer = options.modelAssetBuffer;
 | 
				
			||||||
 | 
					    if (!modelAssetBuffer) {
 | 
				
			||||||
 | 
					      const response = await fetch(options.modelAssetPath!.toString());
 | 
				
			||||||
 | 
					      modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    externalFile.setFileContent(modelAssetBuffer);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (!externalFile.hasFileContent()) {
 | 
				
			||||||
    throw new Error(
 | 
					    throw new Error(
 | 
				
			||||||
        'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
 | 
					        'Either baseOptions.modelAssetPath or baseOptions.modelAssetBuffer must be set');
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
  let modelAssetBuffer = baseOptions.modelAssetBuffer;
 | 
					
 | 
				
			||||||
  if (!modelAssetBuffer) {
 | 
					/** Configues the `acceleration` option. */
 | 
				
			||||||
    const response = await fetch(baseOptions.modelAssetPath!.toString());
 | 
					function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
 | 
				
			||||||
    modelAssetBuffer = new Uint8Array(await response.arrayBuffer());
 | 
					  if ('delegate' in options) {
 | 
				
			||||||
  }
 | 
					    const acceleration = new Acceleration();
 | 
				
			||||||
 | 
					    if (options.delegate === 'cpu') {
 | 
				
			||||||
  const proto = new BaseOptionsProto();
 | 
					      acceleration.setXnnpack(
 | 
				
			||||||
  const externalFile = new ExternalFile();
 | 
					          new InferenceCalculatorOptions.Delegate.Xnnpack());
 | 
				
			||||||
  externalFile.setFileContent(modelAssetBuffer);
 | 
					      proto.setAcceleration(acceleration);
 | 
				
			||||||
  proto.setModelAsset(externalFile);
 | 
					    } else if (options.delegate === 'gpu') {
 | 
				
			||||||
  return proto;
 | 
					      acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
 | 
				
			||||||
 | 
					      proto.setAcceleration(acceleration);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      proto.clearAcceleration();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										8
									
								
								mediapipe/tasks/web/core/base_options.d.ts
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								mediapipe/tasks/web/core/base_options.d.ts
									
									
									
									
										vendored
									
									
								
							| 
						 | 
					@ -22,10 +22,14 @@ export interface BaseOptions {
 | 
				
			||||||
   * The model path to the model asset file. Only one of `modelAssetPath` or
 | 
					   * The model path to the model asset file. Only one of `modelAssetPath` or
 | 
				
			||||||
   * `modelAssetBuffer` can be set.
 | 
					   * `modelAssetBuffer` can be set.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  modelAssetPath?: string;
 | 
					  modelAssetPath?: string|undefined;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * A buffer containing the model aaset. Only one of `modelAssetPath` or
 | 
					   * A buffer containing the model aaset. Only one of `modelAssetPath` or
 | 
				
			||||||
   * `modelAssetBuffer` can be set.
 | 
					   * `modelAssetBuffer` can be set.
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  modelAssetBuffer?: Uint8Array;
 | 
					  modelAssetBuffer?: Uint8Array|undefined;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Overrides the default backend to use for the provided model. */
 | 
				
			||||||
 | 
					  delegate?: 'cpu'|'gpu'|undefined;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -111,8 +111,8 @@ export class TextClassifier extends TaskRunner {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  async setOptions(options: TextClassifierOptions): Promise<void> {
 | 
					  async setOptions(options: TextClassifierOptions): Promise<void> {
 | 
				
			||||||
    if (options.baseOptions) {
 | 
					    if (options.baseOptions) {
 | 
				
			||||||
      const baseOptionsProto =
 | 
					      const baseOptionsProto = await convertBaseOptionsToProto(
 | 
				
			||||||
          await convertBaseOptionsToProto(options.baseOptions);
 | 
					          options.baseOptions, this.options.getBaseOptions());
 | 
				
			||||||
      this.options.setBaseOptions(baseOptionsProto);
 | 
					      this.options.setBaseOptions(baseOptionsProto);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -171,8 +171,8 @@ export class GestureRecognizer extends TaskRunner {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  async setOptions(options: GestureRecognizerOptions): Promise<void> {
 | 
					  async setOptions(options: GestureRecognizerOptions): Promise<void> {
 | 
				
			||||||
    if (options.baseOptions) {
 | 
					    if (options.baseOptions) {
 | 
				
			||||||
      const baseOptionsProto =
 | 
					      const baseOptionsProto = await convertBaseOptionsToProto(
 | 
				
			||||||
          await convertBaseOptionsToProto(options.baseOptions);
 | 
					          options.baseOptions, this.options.getBaseOptions());
 | 
				
			||||||
      this.options.setBaseOptions(baseOptionsProto);
 | 
					      this.options.setBaseOptions(baseOptionsProto);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -114,8 +114,8 @@ export class ImageClassifier extends TaskRunner {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  async setOptions(options: ImageClassifierOptions): Promise<void> {
 | 
					  async setOptions(options: ImageClassifierOptions): Promise<void> {
 | 
				
			||||||
    if (options.baseOptions) {
 | 
					    if (options.baseOptions) {
 | 
				
			||||||
      const baseOptionsProto =
 | 
					      const baseOptionsProto = await convertBaseOptionsToProto(
 | 
				
			||||||
          await convertBaseOptionsToProto(options.baseOptions);
 | 
					          options.baseOptions, this.options.getBaseOptions());
 | 
				
			||||||
      this.options.setBaseOptions(baseOptionsProto);
 | 
					      this.options.setBaseOptions(baseOptionsProto);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -112,8 +112,8 @@ export class ObjectDetector extends TaskRunner {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  async setOptions(options: ObjectDetectorOptions): Promise<void> {
 | 
					  async setOptions(options: ObjectDetectorOptions): Promise<void> {
 | 
				
			||||||
    if (options.baseOptions) {
 | 
					    if (options.baseOptions) {
 | 
				
			||||||
      const baseOptionsProto =
 | 
					      const baseOptionsProto = await convertBaseOptionsToProto(
 | 
				
			||||||
          await convertBaseOptionsToProto(options.baseOptions);
 | 
					          options.baseOptions, this.options.getBaseOptions());
 | 
				
			||||||
      this.options.setBaseOptions(baseOptionsProto);
 | 
					      this.options.setBaseOptions(baseOptionsProto);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										26
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										26
									
								
								third_party/external_files.bzl
									
									
									
									
										vendored
									
									
								
							| 
						 | 
					@ -402,8 +402,8 @@ def external_files():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    http_file(
 | 
					    http_file(
 | 
				
			||||||
        name = "com_google_mediapipe_labels_txt",
 | 
					        name = "com_google_mediapipe_labels_txt",
 | 
				
			||||||
        sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
 | 
					        sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
 | 
				
			||||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667855388142641"],
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667888034706429"],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    http_file(
 | 
					    http_file(
 | 
				
			||||||
| 
						 | 
					@ -552,8 +552,14 @@ def external_files():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    http_file(
 | 
					    http_file(
 | 
				
			||||||
        name = "com_google_mediapipe_movie_review_json",
 | 
					        name = "com_google_mediapipe_movie_review_json",
 | 
				
			||||||
        sha256 = "89ad347ad1cb7c587da144de6efbadec1d3e8ff0cd13e379dd16661a8186fbb5",
 | 
					        sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e",
 | 
				
			||||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667855392734031"],
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667888039053188"],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http_file(
 | 
				
			||||||
 | 
					        name = "com_google_mediapipe_movie_review_labels_txt",
 | 
				
			||||||
 | 
					        sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
 | 
				
			||||||
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667888041670721"],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    http_file(
 | 
					    http_file(
 | 
				
			||||||
| 
						 | 
					@ -688,6 +694,18 @@ def external_files():
 | 
				
			||||||
        urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"],
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/README.md?generation=1661875904887163"],
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http_file(
 | 
				
			||||||
 | 
					        name = "com_google_mediapipe_regex_one_embedding_with_metadata_tflite",
 | 
				
			||||||
 | 
					        sha256 = "b8f5d6d090c2c73984b2b92cd2fda27e5630562741a93d127b9a744d60505bc0",
 | 
				
			||||||
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/regex_one_embedding_with_metadata.tflite?generation=1667888045310541"],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http_file(
 | 
				
			||||||
 | 
					        name = "com_google_mediapipe_regex_vocab_txt",
 | 
				
			||||||
 | 
					        sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
 | 
				
			||||||
 | 
					        urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667888047885461"],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    http_file(
 | 
					    http_file(
 | 
				
			||||||
        name = "com_google_mediapipe_right_hands_jpg",
 | 
					        name = "com_google_mediapipe_right_hands_jpg",
 | 
				
			||||||
        sha256 = "240c082e80128ff1ca8a83ce645e2ba4d8bc30f0967b7991cf5fa375bab489e1",
 | 
					        sha256 = "240c082e80128ff1ca8a83ce645e2ba4d8bc30f0967b7991cf5fa375bab489e1",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user