Internal change

PiperOrigin-RevId: 495613573
This commit is contained in:
Jiuqiang Tang 2022-12-15 09:20:22 -08:00 committed by Copybara-Service
parent 6db5eabe0b
commit 299aa03302
14 changed files with 239 additions and 6 deletions

View File

@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi {
TaskRunner.create(
context,
TaskInfo.<AudioClassifierOptions>builder()
.setTaskName(AudioClassifier.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -200,6 +200,8 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
TaskRunner.create(
context,
TaskInfo.<AudioEmbedderOptions>builder()
.setTaskName(AudioEmbedder.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -22,6 +22,7 @@ android_library(
],
manifest = "AndroidManifest.xml",
deps = [
":logging",
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite",
@ -37,6 +38,17 @@ android_library(
],
)
android_library(
name = "logging",
srcs = glob(
["logging/*.java"],
),
deps = [
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar")
mediapipe_tasks_core_aar(

View File

@ -32,6 +32,12 @@ public abstract class TaskInfo<T extends TaskOptions> {
/** Builder for {@link TaskInfo}. */
@AutoValue.Builder
public abstract static class Builder<T extends TaskOptions> {
/** Sets the MediaPipe task name. */
public abstract Builder<T> setTaskName(String value);
/** Sets the MediaPipe task running mode name. */
public abstract Builder<T> setTaskRunningModeName(String value);
/** Sets the MediaPipe task graph name. */
public abstract Builder<T> setTaskGraphName(String value);
@ -71,6 +77,10 @@ public abstract class TaskInfo<T extends TaskOptions> {
}
}
abstract String taskName();
abstract String taskRunningModeName();
abstract String taskGraphName();
abstract T taskOptions();
@ -82,7 +92,7 @@ public abstract class TaskInfo<T extends TaskOptions> {
abstract Boolean enableFlowLimiting();
public static <T extends TaskOptions> Builder<T> builder() {
return new AutoValue_TaskInfo.Builder<T>();
return new AutoValue_TaskInfo.Builder<T>().setTaskName("").setTaskRunningModeName("");
}
/* Returns a list of the output stream names without the stream tags. */

View File

@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.tasks.core.logging.TasksStatsLogger;
import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable {
private final Graph graph;
private final ModelResourcesCache modelResourcesCache;
private final AndroidPacketCreator packetCreator;
private final TasksStatsLogger statsLogger;
private long lastSeenTimestamp = Long.MIN_VALUE;
private ErrorListener errorListener;
@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable {
Context context,
TaskInfo<? extends TaskOptions> taskInfo,
OutputHandler<? extends TaskResult, ?> outputHandler) {
TasksStatsLogger statsLogger =
TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName());
AndroidAssetUtil.initializeNativeAssetManager(context);
Graph mediapipeGraph = new Graph();
mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable {
mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
mediapipeGraph.addMultiStreamCallback(
taskInfo.outputStreamNames(),
outputHandler::run,
packets -> {
outputHandler.run(packets);
statsLogger.recordInvocationEnd(packets.get(0).getTimestamp());
},
/* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges());
mediapipeGraph.startRunningGraph();
// Waits until all calculators are opened and the graph is fully started.
mediapipeGraph.waitUntilGraphIdle();
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger);
}
/**
@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable {
* @param inputs a map contains (input stream {@link String}, data {@link Packet}) pairs.
*/
public synchronized TaskResult process(Map<String, Packet> inputs) {
addPackets(inputs, generateSyntheticTimestamp());
long syntheticInputTimestamp = generateSyntheticTimestamp();
// TODO: Support recording GPU input arrival.
statsLogger.recordCpuInputArrival(syntheticInputTimestamp);
addPackets(inputs, syntheticInputTimestamp);
graph.waitUntilGraphIdle();
lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
return outputHandler.retrieveCachedTaskResult();
@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable {
*/
public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimstamp(inputTimestamp);
statsLogger.recordCpuInputArrival(inputTimestamp);
addPackets(inputs, inputTimestamp);
graph.waitUntilGraphIdle();
return outputHandler.retrieveCachedTaskResult();
@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable {
*/
public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimstamp(inputTimestamp);
statsLogger.recordCpuInputArrival(inputTimestamp);
addPackets(inputs, inputTimestamp);
}
@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
statsLogger.logSessionEnd();
} catch (MediaPipeException e) {
reportError(e);
}
@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable {
// Waits until all calculators are opened and the graph is fully restarted.
graph.waitUntilGraphIdle();
graphStarted.set(true);
statsLogger.logSessionStart();
} catch (MediaPipeException e) {
reportError(e);
}
@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
statsLogger.logSessionEnd();
if (modelResourcesCache != null) {
modelResourcesCache.release();
}
@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable {
private TaskRunner(
Graph graph,
ModelResourcesCache modelResourcesCache,
OutputHandler<? extends TaskResult, ?> outputHandler) {
OutputHandler<? extends TaskResult, ?> outputHandler,
TasksStatsLogger statsLogger) {
this.outputHandler = outputHandler;
this.graph = graph;
this.modelResourcesCache = modelResourcesCache;
this.packetCreator = new AndroidPacketCreator(graph);
this.statsLogger = statsLogger;
graphStarted.set(true);
this.statsLogger.logSessionStart();
}
/** Reports error. */

View File

@ -0,0 +1,78 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core.logging;
import android.content.Context;
/** A dummy MediaPipe Tasks stats logger that has all methods as no-ops. */
public class TasksStatsDummyLogger implements TasksStatsLogger {
/**
* Creates the MediaPipe Tasks stats dummy logger.
*
* @param context a {@link Context}.
* @param taskNameStr the task api name.
* @param taskRunningModeStr the task running mode string representation.
*/
public static TasksStatsDummyLogger create(
Context context, String taskNameStr, String taskRunningModeStr) {
return new TasksStatsDummyLogger();
}
private TasksStatsDummyLogger() {}
/** Logs the start of a MediaPipe Tasks API session. */
@Override
public void logSessionStart() {}
/**
* Records MediaPipe Tasks API receiving CPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
@Override
public void recordCpuInputArrival(long packetTimestamp) {}
/**
* Records MediaPipe Tasks API receiving GPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
@Override
public void recordGpuInputArrival(long packetTimestamp) {}
/**
* Records the end of a Mediapipe Tasks API invocation.
*
* @param packetTimestamp the output packet timestamp that acts as the identifier of the api
* invocation.
*/
@Override
public void recordInvocationEnd(long packetTimestamp) {}
/** Logs the MediaPipe Tasks API periodic invocation report. */
@Override
public void logInvocationReport(StatsSnapshot stats) {}
/** Logs the Tasks API session end event. */
@Override
public void logSessionEnd() {}
/** Logs the MediaPipe Tasks API initialization error. */
@Override
public void logInitError() {}
}

View File

@ -0,0 +1,98 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core.logging;
import com.google.auto.value.AutoValue;
/** The stats logger interface that defines what MediaPipe Tasks events to log. */
public interface TasksStatsLogger {
/** Task stats snapshot. */
@AutoValue
abstract static class StatsSnapshot {
static StatsSnapshot create(
int cpuInputCount,
int gpuInputCount,
int finishedCount,
int droppedCount,
long totalLatencyMs,
long peakLatencyMs,
long elapsedTimeMs) {
return new AutoValue_TasksStatsLogger_StatsSnapshot(
cpuInputCount,
gpuInputCount,
finishedCount,
droppedCount,
totalLatencyMs,
peakLatencyMs,
elapsedTimeMs);
}
static StatsSnapshot createDefault() {
return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0);
}
abstract int cpuInputCount();
abstract int gpuInputCount();
abstract int finishedCount();
abstract int droppedCount();
abstract long totalLatencyMs();
abstract long peakLatencyMs();
abstract long elapsedTimeMs();
}
/** Logs the start of a MediaPipe Tasks API session. */
public void logSessionStart();
/**
* Records MediaPipe Tasks API receiving CPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
public void recordCpuInputArrival(long packetTimestamp);
/**
* Records MediaPipe Tasks API receiving GPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
public void recordGpuInputArrival(long packetTimestamp);
/**
* Records the end of a Mediapipe Tasks API invocation.
*
* @param packetTimestamp the output packet timestamp that acts as the identifier of the api
* invocation.
*/
public void recordInvocationEnd(long packetTimestamp);
/** Logs the MediaPipe Tasks API periodic invocation report. */
public void logInvocationReport(StatsSnapshot stats);
/** Logs the Tasks API session end event. */
public void logSessionEnd();
/** Logs the MediaPipe Tasks API initialization error. */
public void logInitError();
// TODO: Logs more error types.
}

View File

@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable {
TaskRunner.create(
context,
TaskInfo.<TextClassifierOptions>builder()
.setTaskName(TextClassifier.class.getSimpleName())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -159,6 +159,7 @@ public final class TextEmbedder implements AutoCloseable {
TaskRunner.create(
context,
TaskInfo.<TextEmbedderOptions>builder()
.setTaskName(TextEmbedder.class.getSimpleName())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -194,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<GestureRecognizerOptions>builder()
.setTaskName(GestureRecognizer.class.getSimpleName())
.setTaskRunningModeName(recognizerOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -183,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<HandLandmarkerOptions>builder()
.setTaskName(HandLandmarker.class.getSimpleName())
.setTaskRunningModeName(landmarkerOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<ImageClassifierOptions>builder()
.setTaskName(ImageClassifier.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -180,6 +180,8 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<ImageEmbedderOptions>builder()
.setTaskName(ImageEmbedder.class.getSimpleName())
.setTaskRunningModeName(options.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)

View File

@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
TaskRunner.create(
context,
TaskInfo.<ObjectDetectorOptions>builder()
.setTaskName(ObjectDetector.class.getSimpleName())
.setTaskRunningModeName(detectorOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)