Project import generated by Copybara.

GitOrigin-RevId: 7e1d382a1788ebd8412c5626581b4c4cf2fe75ea
This commit is contained in:
MediaPipe Team 2021-11-12 20:10:43 -08:00 committed by chuoling
parent f4e7f6cc48
commit cf101e62a9
39 changed files with 1102 additions and 233 deletions

View File

@ -350,6 +350,9 @@ maven_install(
"com.google.auto.value:auto-value:1.8.1",
"com.google.auto.value:auto-value-annotations:1.8.1",
"com.google.code.findbugs:jsr305:latest.release",
"com.google.android.datatransport:transport-api:3.0.0",
"com.google.android.datatransport:transport-backend-cct:3.1.0",
"com.google.android.datatransport:transport-runtime:3.1.0",
"com.google.flogger:flogger-system-backend:0.6",
"com.google.flogger:flogger:0.6",
"com.google.guava:guava:27.0.1-android",

View File

@ -36,16 +36,6 @@ dependencies {
implementation 'com.google.mediapipe:facemesh:latest.release'
// Optional: MediaPipe Hands Solution.
implementation 'com.google.mediapipe:hands:latest.release'
// MediaPipe deps
implementation 'com.google.flogger:flogger:0.6'
implementation 'com.google.flogger:flogger-system-backend:0.6'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library
def camerax_version = "1.0.0-beta10"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation "androidx.camera:camera-lifecycle:$camerax_version"
}
```
@ -84,3 +74,58 @@ To build these apps:
by default. If needed, for example to run the apps on Android Emulator, set
the `RUN_ON_GPU` boolean variable to `false` in the app's
`MainActivity.java` to run the pipeline and model inference on CPU.
## MediaPipe Solution APIs Terms of Service
Last modified: November 12, 2021
Use of MediaPipe Solution APIs is subject to the
[Google APIs Terms of Service](https://developers.google.com/terms),
[Google API Services User Data Policy](https://developers.google.com/terms/api-services-user-data-policy),
and the terms below. Please check back from time to time as these terms and
policies are occasionally updated.
**Privacy**
When you use MediaPipe Solution APIs, processing of the input data (e.g. images,
video, text) fully happens on-device, and **MediaPipe does not send that input
data to Google servers**. As a result, you can use our APIs for processing data
that should not leave the device.
MediaPipe Android Solution APIs will contact Google servers from time to time in
order to receive things like bug fixes, updated models, and hardware accelerator
compatibility information. MediaPipe Android Solution APIs also send metrics
about the performance and utilization of the APIs in your app to Google. Google
uses this metrics data to measure performance, API usage, debug, maintain and
improve the APIs, and detect misuse or abuse, as further described in our
[Privacy Policy](https://policies.google.com/privacy).
**You are responsible for obtaining informed consent from your app users about
Google's processing of MediaPipe metrics data as required by applicable law.**
Data we collect may include the following, across all MediaPipe Android Solution
APIs:
- Device information (such as manufacturer, model, OS version and build) and
available ML hardware accelerators (GPU and DSP). Used for diagnostics and
usage analytics.
- App identification information (package name / bundle id, app version). Used
for diagnostics and usage analytics.
- API configuration (such as image format, resolution, and MediaPipe version
used). Used for diagnostics and usage analytics.
- Event type (such as initialize, download model, update, run, and detection).
Used for diagnostics and usage analytics.
- Error codes. Used for diagnostics.
- Performance metrics. Used for diagnostics.
- Per-installation identifiers that do not uniquely identify a user or
physical device. Used for operation of remote configuration and usage
analytics.
- Network request sender IP addresses. Used for remote configuration
diagnostics. Collected IP addresses are retained temporarily.

View File

@ -218,14 +218,13 @@ camera.start();
### Android Solution API
Please first follow general
[instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api)
to add MediaPipe Gradle dependencies, then try the Face Detection Solution API
in the companion
[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facedetection)
following
[these instructions](../getting_started/android_solutions.md#build-solution-example-apps-in-android-studio)
[instructions](../getting_started/android_solutions.md) to add MediaPipe Gradle
dependencies and try the Android Solution API in the companion
[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facedetection),
and learn more in the usage example below.
Supported configuration options:
* [staticImageMode](#static_image_mode)
* [modelSelection](#model_selection)

View File

@ -487,12 +487,9 @@ camera.start();
### Android Solution API
Please first follow general
[instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api)
to add MediaPipe Gradle dependencies, then try the Face Mesh Solution API in the
companion
[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facemesh)
following
[these instructions](../getting_started/android_solutions.md#build-solution-example-apps-in-android-studio)
[instructions](../getting_started/android_solutions.md) to add MediaPipe Gradle
dependencies and try the Android Solution API in the companion
[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/facemesh),
and learn more in the usage example below.
Supported configuration options:

View File

@ -398,13 +398,10 @@ camera.start();
### Android Solution API
Please first follow general
[instructions](../getting_started/android_solutions.md#integrate-mediapipe-android-solutions-api)
to add MediaPipe Gradle dependencies, then try the Hands Solution API in the
companion
[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/hands)
following
[these instructions](../getting_started/android_solutions.md#build-solution-example-apps-in-android-studio)
and learn more in usage example below.
[instructions](../getting_started/android_solutions.md) to add MediaPipe Gradle
dependencies and try the Android Solution API in the companion
[example Android Studio project](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/solutions/hands),
and learn more in the usage example below.
Supported configuration options:

View File

@ -1169,6 +1169,7 @@ cc_library(
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status",
"//mediapipe/util:rectangle_util",
"@com_google_absl//absl/memory",
],
alwayslink = 1,

View File

@ -26,20 +26,10 @@
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/rectangle_util.h"
namespace mediapipe {
// Scores how much two rectangles overlap using Intersection over Union (IoU):
// the area shared by both rectangles divided by the area covered by either.
// Yields 0 when the rectangles do not intersect or when the union area is not
// positive.
inline float OverlapSimilarity(const Rectangle_f& rect1,
                               const Rectangle_f& rect2) {
  if (rect1.Intersects(rect2)) {
    // Intersect a copy so that neither input rectangle is modified.
    const float shared_area = Rectangle_f(rect1).Intersect(rect2).Area();
    const float union_area = rect1.Area() + rect2.Area() - shared_area;
    if (union_area > 0.0f) {
      return shared_area / union_area;
    }
  }
  return 0.0f;
}
// AssociationCalculator<T> accepts multiple inputs of vectors of type T that can
// be converted to Rectangle_f. The output is a vector of type T that contains
// elements from the input vectors that don't overlap with each other. When
@ -187,7 +177,7 @@ class AssociationCalculator : public CalculatorBase {
for (auto uit = current->begin(); uit != current->end();) {
ASSIGN_OR_RETURN(auto prev_rect, GetRectangle(*uit));
if (OverlapSimilarity(cur_rect, prev_rect) >
if (CalculateIou(cur_rect, prev_rect) >
options_.min_similarity_threshold()) {
std::pair<bool, int> prev_id = GetId(*uit);
// If prev_id.first is false when some element doesn't have an ID,
@ -232,7 +222,7 @@ class AssociationCalculator : public CalculatorBase {
}
const Rectangle_f& prev_rect = get_prev_rectangle.value();
if (OverlapSimilarity(cur_rect, prev_rect) >
if (CalculateIou(cur_rect, prev_rect) >
options_.min_similarity_threshold()) {
std::pair<bool, int> prev_id = GetId(prev_input_vec[ui]);
// If prev_id.first is false when some element doesn't have an ID,

View File

@ -35,17 +35,7 @@ dependencies {
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
// MediaPipe Face Detection Solution components.
// MediaPipe Face Detection Solution.
implementation 'com.google.mediapipe:solution-core:latest.release'
implementation 'com.google.mediapipe:facedetection:latest.release'
// MediaPipe deps
implementation 'com.google.flogger:flogger:0.6'
implementation 'com.google.flogger:flogger-system-backend:0.6'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library
def camerax_version = "1.0.0-beta10"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation "androidx.camera:camera-lifecycle:$camerax_version"
}

View File

@ -11,6 +11,9 @@
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<!-- For logging solution events -->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<application
android:allowBackup="true"

View File

@ -35,17 +35,7 @@ dependencies {
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
// MediaPipe Face Mesh Solution components.
// MediaPipe Face Mesh Solution.
implementation 'com.google.mediapipe:solution-core:latest.release'
implementation 'com.google.mediapipe:facemesh:latest.release'
// MediaPipe deps
implementation 'com.google.flogger:flogger:0.6'
implementation 'com.google.flogger:flogger-system-backend:0.6'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library
def camerax_version = "1.0.0-beta10"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation "androidx.camera:camera-lifecycle:$camerax_version"
}

View File

@ -11,6 +11,9 @@
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<!-- For logging solution events -->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<application
android:allowBackup="true"

View File

@ -35,17 +35,7 @@ dependencies {
testImplementation 'junit:junit:4.+'
androidTestImplementation 'androidx.test.ext:junit:1.1.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'
// MediaPipe Hands Solution components.
// MediaPipe Hands Solution.
implementation 'com.google.mediapipe:solution-core:latest.release'
implementation 'com.google.mediapipe:hands:latest.release'
// MediaPipe deps
implementation 'com.google.flogger:flogger:0.6'
implementation 'com.google.flogger:flogger-system-backend:0.6'
implementation 'com.google.guava:guava:27.0.1-android'
implementation 'com.google.protobuf:protobuf-java:3.11.4'
// CameraX core library
def camerax_version = "1.0.0-beta10"
implementation "androidx.camera:camera-core:$camerax_version"
implementation "androidx.camera:camera-camera2:$camerax_version"
implementation "androidx.camera:camera-lifecycle:$camerax_version"
}

View File

@ -11,6 +11,9 @@
<!-- For using the camera -->
<uses-permission android:name="android.permission.CAMERA" />
<uses-feature android:name="android.hardware.camera" />
<!-- For logging solution events -->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<application
android:allowBackup="true"

View File

@ -74,6 +74,7 @@ cc_library(
":content_zooming_calculator_state",
"//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:location_data_cc_proto",

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
@ -57,6 +58,8 @@ constexpr float kFieldOfView = 60;
constexpr char kStateCache[] = "STATE_CACHE";
// Tolerance for zooming out recentering.
constexpr float kPixelTolerance = 3;
// Tag of the optional bool output stream that carries 'true' when the camera
// is moving (pan/tilt/zoom) and 'false' when there is no motion.
constexpr char kCameraActive[] = "CAMERA_ACTIVE";
namespace mediapipe {
namespace autoflip {
@ -181,6 +184,9 @@ absl::Status ContentZoomingCalculator::GetContract(
if (cc->InputSidePackets().HasTag(kStateCache)) {
cc->InputSidePackets().Tag(kStateCache).Set<StateCacheType*>();
}
if (cc->Outputs().HasTag(kCameraActive)) {
cc->Outputs().Tag(kCameraActive).Set<bool>();
}
return absl::OkStatus();
}
@ -649,6 +655,13 @@ absl::Status ContentZoomingCalculator::Process(
path_solver_tilt_->ClearHistory();
path_solver_zoom_->ClearHistory();
}
const bool camera_active =
is_animating || pan_state || tilt_state || zoom_state;
if (cc->Outputs().HasTag(kCameraActive)) {
cc->Outputs()
.Tag(kCameraActive)
.AddPacket(MakePacket<bool>(camera_active).At(cc->InputTimestamp()));
}
// Compute smoothed zoom camera path.
MP_RETURN_IF_ERROR(path_solver_zoom_->AddObservation(

View File

@ -75,6 +75,9 @@ class TagIndexMap {
std::map<std::string, std::vector<std::unique_ptr<T>>> map_;
};
class Graph;
class NodeBase;
// These structs are used internally to store information about the endpoints
// of a connection.
struct SourceBase;
@ -109,7 +112,7 @@ class MultiPort : public Single {
// These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API.
template <bool AllowMultiple, bool IsSide, typename T = internal::Generic>
template <bool IsSide, typename T = internal::Generic>
class DestinationImpl {
public:
using Base = DestinationBase;
@ -121,13 +124,12 @@ class DestinationImpl {
};
template <bool IsSide, typename T>
class DestinationImpl<true, IsSide, T>
: public MultiPort<DestinationImpl<false, IsSide, T>> {
class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> {
public:
using MultiPort<DestinationImpl<false, IsSide, T>>::MultiPort;
using MultiPort<DestinationImpl<IsSide, T>>::MultiPort;
};
template <bool AllowMultiple, bool IsSide, typename T = internal::Generic>
template <bool IsSide, typename T = internal::Generic>
class SourceImpl {
public:
using Base = SourceBase;
@ -135,9 +137,9 @@ class SourceImpl {
// Src is used as the return type of fluent methods below. Since these are
// single-port methods, it is desirable to always decay to a reference to the
// single-port superclass, even if they are called on a multiport.
using Src = SourceImpl<false, IsSide, T>;
using Src = SourceImpl<IsSide, T>;
template <typename U>
using Dst = DestinationImpl<false, IsSide, U>;
using Dst = DestinationImpl<IsSide, U>;
// clang-format off
template <typename U>
@ -173,10 +175,9 @@ class SourceImpl {
};
template <bool IsSide, typename T>
class SourceImpl<true, IsSide, T>
: public MultiPort<SourceImpl<false, IsSide, T>> {
class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
public:
using MultiPort<SourceImpl<false, IsSide, T>>::MultiPort;
using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
};
// A source and a destination correspond to an output/input stream on a node,
@ -185,14 +186,23 @@ class SourceImpl<true, IsSide, T>
// For graph inputs/outputs, however, the inputs are sources, and the outputs
// are destinations. This is because graph ports are connected "from inside"
// when building the graph.
template <bool AllowMultiple = false, typename T = internal::Generic>
using Source = SourceImpl<AllowMultiple, false, T>;
template <bool AllowMultiple = false, typename T = internal::Generic>
using SideSource = SourceImpl<AllowMultiple, true, T>;
template <bool AllowMultiple = false, typename T = internal::Generic>
using Destination = DestinationImpl<AllowMultiple, false, T>;
template <bool AllowMultiple = false, typename T = internal::Generic>
using SideDestination = DestinationImpl<AllowMultiple, true, T>;
template <typename T = internal::Generic>
using Source = SourceImpl<false, T>;
template <typename T = internal::Generic>
using MultiSource = MultiSourceImpl<false, T>;
template <typename T = internal::Generic>
using SideSource = SourceImpl<true, T>;
template <typename T = internal::Generic>
using MultiSideSource = MultiSourceImpl<true, T>;
template <typename T = internal::Generic>
using Destination = DestinationImpl<false, T>;
template <typename T = internal::Generic>
using SideDestination = DestinationImpl<true, T>;
template <typename T = internal::Generic>
using MultiDestination = MultiDestinationImpl<false, T>;
template <typename T = internal::Generic>
using MultiSideDestination = MultiDestinationImpl<true, T>;
class NodeBase {
public:
@ -202,45 +212,67 @@ class NodeBase {
// of its entries by index. However, for nodes without visible contracts we
// can't know whether a tag is indexable or not, so we would need the
// multi-port to also be usable as a port directly (representing index 0).
Source<true> Out(const std::string& tag) {
return Source<true>(&out_streams_[tag]);
MultiSource<> Out(const std::string& tag) {
return MultiSource<>(&out_streams_[tag]);
}
Destination<true> In(const std::string& tag) {
return Destination<true>(&in_streams_[tag]);
MultiDestination<> In(const std::string& tag) {
return MultiDestination<>(&in_streams_[tag]);
}
SideSource<true> SideOut(const std::string& tag) {
return SideSource<true>(&out_sides_[tag]);
MultiSideSource<> SideOut(const std::string& tag) {
return MultiSideSource<>(&out_sides_[tag]);
}
SideDestination<true> SideIn(const std::string& tag) {
return SideDestination<true>(&in_sides_[tag]);
MultiSideDestination<> SideIn(const std::string& tag) {
return MultiSideDestination<>(&in_sides_[tag]);
}
template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
auto operator[](const PortCommon<B, T, kIsOptional, kIsMultiple>& port) {
using PayloadT =
typename PortCommon<B, T, kIsOptional, kIsMultiple>::PayloadT;
if constexpr (std::is_same_v<B, OutputBase>) {
return Source<kIsMultiple, T>(&out_streams_[port.Tag()]);
auto* base = &out_streams_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiSource<PayloadT>(base);
} else {
return Source<PayloadT>(base);
}
} else if constexpr (std::is_same_v<B, InputBase>) {
return Destination<kIsMultiple, T>(&in_streams_[port.Tag()]);
auto* base = &in_streams_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiDestination<PayloadT>(base);
} else {
return Destination<PayloadT>(base);
}
} else if constexpr (std::is_same_v<B, SideOutputBase>) {
return SideSource<kIsMultiple, T>(&out_sides_[port.Tag()]);
auto* base = &out_sides_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiSideSource<PayloadT>(base);
} else {
return SideSource<PayloadT>(base);
}
} else if constexpr (std::is_same_v<B, SideInputBase>) {
return SideDestination<kIsMultiple, T>(&in_sides_[port.Tag()]);
auto* base = &in_sides_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiSideDestination<PayloadT>(base);
} else {
return SideDestination<PayloadT>(base);
}
} else {
static_assert(dependent_false<B>::value, "Type not supported.");
}
}
// Convenience methods for accessing purely index-based ports.
Source<false> Out(int index) { return Out("")[index]; }
Source<> Out(int index) { return Out("")[index]; }
Destination<false> In(int index) { return In("")[index]; }
Destination<> In(int index) { return In("")[index]; }
SideSource<false> SideOut(int index) { return SideOut("")[index]; }
SideSource<> SideOut(int index) { return SideOut("")[index]; }
SideDestination<false> SideIn(int index) { return SideIn("")[index]; }
SideDestination<> SideIn(int index) { return SideIn("")[index]; }
template <typename T>
T& GetOptions() {
@ -277,11 +309,6 @@ class Node<internal::Generic> : public NodeBase {
using GenericNode = Node<internal::Generic>;
template <template <bool, class> class BP, class Port, class TagIndexMapT>
auto MakeBuilderPort(const Port& port, TagIndexMapT& streams) {
return BP<Port::kMultiple, typename Port::PayloadT>(&streams[port.Tag()]);
}
template <class Calc>
class Node : public NodeBase {
public:
@ -298,25 +325,25 @@ class Node : public NodeBase {
template <class Tag>
auto Out(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedOutputs::get(tag);
return MakeBuilderPort<Source>(port, out_streams_);
return NodeBase::operator[](port);
}
template <class Tag>
auto In(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedInputs::get(tag);
return MakeBuilderPort<Destination>(port, in_streams_);
return NodeBase::operator[](port);
}
template <class Tag>
auto SideOut(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedSideOutputs::get(tag);
return MakeBuilderPort<SideSource>(port, out_sides_);
return NodeBase::operator[](port);
}
template <class Tag>
auto SideIn(Tag tag) {
constexpr auto& port = Calc::Contract::TaggedSideInputs::get(tag);
return MakeBuilderPort<SideDestination>(port, in_sides_);
return NodeBase::operator[](port);
}
// We could allow using the non-checked versions with typed nodes too, but
@ -332,17 +359,17 @@ class PacketGenerator {
public:
PacketGenerator(std::string type) : type_(std::move(type)) {}
SideSource<true> SideOut(const std::string& tag) {
return SideSource<true>(&out_sides_[tag]);
MultiSideSource<> SideOut(const std::string& tag) {
return MultiSideSource<>(&out_sides_[tag]);
}
SideDestination<true> SideIn(const std::string& tag) {
return SideDestination<true>(&in_sides_[tag]);
MultiSideDestination<> SideIn(const std::string& tag) {
return MultiSideDestination<>(&in_sides_[tag]);
}
// Convenience methods for accessing purely index-based ports.
SideSource<false> SideOut(int index) { return SideOut("")[index]; }
SideDestination<false> SideIn(int index) { return SideIn("")[index]; }
SideSource<> SideOut(int index) { return SideOut("")[index]; }
SideDestination<> SideIn(int index) { return SideIn("")[index]; }
template <typename T>
T& GetOptions() {
@ -402,70 +429,85 @@ class Graph {
}
// Graph ports, non-typed.
Source<true> In(const std::string& graph_input) {
MultiSource<> In(const std::string& graph_input) {
return graph_boundary_.Out(graph_input);
}
Destination<true> Out(const std::string& graph_output) {
MultiDestination<> Out(const std::string& graph_output) {
return graph_boundary_.In(graph_output);
}
SideSource<true> SideIn(const std::string& graph_input) {
MultiSideSource<> SideIn(const std::string& graph_input) {
return graph_boundary_.SideOut(graph_input);
}
SideDestination<true> SideOut(const std::string& graph_output) {
MultiSideDestination<> SideOut(const std::string& graph_output) {
return graph_boundary_.SideIn(graph_output);
}
// Convenience methods for accessing purely index-based ports.
Source<false> In(int index) { return In("")[0]; }
Source<> In(int index) { return In("")[index]; }
Destination<false> Out(int index) { return Out("")[0]; }
Destination<> Out(int index) { return Out("")[index]; }
SideSource<false> SideIn(int index) { return SideIn("")[0]; }
SideSource<> SideIn(int index) { return SideIn("")[index]; }
SideDestination<false> SideOut(int index) { return SideOut("")[0]; }
SideDestination<> SideOut(int index) { return SideOut("")[index]; }
// Graph ports, typed.
// TODO: make graph_boundary_ a typed node!
template <class PortT, class Payload = typename PortT::PayloadT,
class Src = Source<PortT::kMultiple, Payload>>
Src In(const PortT& graph_input) {
return Src(&graph_boundary_.out_streams_[graph_input.Tag()]);
template <class PortT, class Payload = typename PortT::PayloadT>
auto In(const PortT& graph_input) {
return (*this)[graph_input];
}
template <class PortT, class Payload = typename PortT::PayloadT,
class Dst = Destination<PortT::kMultiple, Payload>>
Dst Out(const PortT& graph_output) {
return Dst(&graph_boundary_.in_streams_[graph_output.Tag()]);
template <class PortT, class Payload = typename PortT::PayloadT>
auto Out(const PortT& graph_output) {
return (*this)[graph_output];
}
template <class PortT, class Payload = typename PortT::PayloadT,
class Src = SideSource<PortT::kMultiple, Payload>>
Src SideIn(const PortT& graph_input) {
return Src(&graph_boundary_.out_sides_[graph_input.Tag()]);
template <class PortT, class Payload = typename PortT::PayloadT>
auto SideIn(const PortT& graph_input) {
return (*this)[graph_input];
}
template <class PortT, class Payload = typename PortT::PayloadT,
class Dst = SideDestination<PortT::kMultiple, Payload>>
Dst SideOut(const PortT& graph_output) {
return Dst(&graph_boundary_.in_sides_[graph_output.Tag()]);
template <class PortT, class Payload = typename PortT::PayloadT>
auto SideOut(const PortT& graph_output) {
return (*this)[graph_output];
}
template <typename B, typename T, bool kIsOptional, bool kIsMultiple>
auto operator[](const PortCommon<B, T, kIsOptional, kIsMultiple>& port) {
using PayloadT =
typename PortCommon<B, T, kIsOptional, kIsMultiple>::PayloadT;
if constexpr (std::is_same_v<B, OutputBase>) {
return Destination<kIsMultiple, T>(
&graph_boundary_.in_streams_[port.Tag()]);
auto* base = &graph_boundary_.in_streams_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiDestination<PayloadT>(base);
} else {
return Destination<PayloadT>(base);
}
} else if constexpr (std::is_same_v<B, InputBase>) {
return Source<kIsMultiple, T>(&graph_boundary_.out_streams_[port.Tag()]);
auto* base = &graph_boundary_.out_streams_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiSource<PayloadT>(base);
} else {
return Source<PayloadT>(base);
}
} else if constexpr (std::is_same_v<B, SideOutputBase>) {
return SideDestination<kIsMultiple, T>(
&graph_boundary_.in_sides_[port.Tag()]);
auto* base = &graph_boundary_.in_sides_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiSideDestination<PayloadT>(base);
} else {
return SideDestination<PayloadT>(base);
}
} else if constexpr (std::is_same_v<B, SideInputBase>) {
return SideSource<kIsMultiple, T>(
&graph_boundary_.out_sides_[port.Tag()]);
auto* base = &graph_boundary_.out_sides_[port.Tag()];
if constexpr (kIsMultiple) {
return MultiSideSource<PayloadT>(base);
} else {
return SideSource<PayloadT>(base);
}
} else {
static_assert(dependent_false<B>::value, "Type not supported.");
}

View File

@ -50,21 +50,21 @@ TEST(BuilderTest, BuildGraph) {
TEST(BuilderTest, CopyableSource) {
builder::Graph graph;
builder::Source<false, int> a = graph[Input<int>("A")];
builder::Source<int> a = graph[Input<int>("A")];
a.SetName("a");
builder::Source<false, int> b = graph[Input<int>("B")];
builder::Source<int> b = graph[Input<int>("B")];
b.SetName("b");
builder::SideSource<false, float> side_a = graph[SideInput<float>("SIDE_A")];
builder::SideSource<float> side_a = graph[SideInput<float>("SIDE_A")];
side_a.SetName("side_a");
builder::SideSource<false, float> side_b = graph[SideInput<float>("SIDE_B")];
builder::SideSource<float> side_b = graph[SideInput<float>("SIDE_B")];
side_b.SetName("side_b");
builder::Destination<false, int> out = graph[Output<int>("OUT")];
builder::SideDestination<false, float> side_out =
builder::Destination<int> out = graph[Output<int>("OUT")];
builder::SideDestination<float> side_out =
graph[SideOutput<float>("SIDE_OUT")];
builder::Source<false, int> input = a;
builder::Source<int> input = a;
input = b;
builder::SideSource<false, float> side_input = side_b;
builder::SideSource<float> side_input = side_b;
side_input = side_a;
input >> out;
@ -85,27 +85,26 @@ TEST(BuilderTest, CopyableSource) {
TEST(BuilderTest, BuildGraphWithFunctions) {
builder::Graph graph;
builder::Source<false, int> base = graph[Input<int>("IN")];
builder::Source<int> base = graph[Input<int>("IN")];
base.SetName("base");
builder::SideSource<false, float> side = graph[SideInput<float>("SIDE")];
builder::SideSource<float> side = graph[SideInput<float>("SIDE")];
side.SetName("side");
auto foo_fn = [](builder::Source<false, int> base,
builder::SideSource<false, float> side,
auto foo_fn = [](builder::Source<int> base, builder::SideSource<float> side,
builder::Graph& graph) {
auto& foo = graph.AddNode("Foo");
base >> foo[Input<int>("BASE")];
side >> foo[SideInput<float>("SIDE")];
return foo[Output<double>("OUT")];
};
builder::Source<false, double> foo_out = foo_fn(base, side, graph);
builder::Source<double> foo_out = foo_fn(base, side, graph);
auto bar_fn = [](builder::Source<false, double> in, builder::Graph& graph) {
auto bar_fn = [](builder::Source<double> in, builder::Graph& graph) {
auto& bar = graph.AddNode("Bar");
in >> bar[Input<double>("IN")];
return bar[Output<double>("OUT")];
};
builder::Source<false, double> bar_out = bar_fn(foo_out, graph);
builder::Source<double> bar_out = bar_fn(foo_out, graph);
bar_out.SetName("out");
bar_out >> graph[Output<double>("OUT")];
@ -298,6 +297,34 @@ TEST(BuilderTest, EmptyTag) {
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
// Checks that the index-based graph port accessors (graph.In(i) / graph.Out(i))
// address the port at the requested index, by wiring graph ports to a node in
// a shuffled order and comparing the resulting config.
TEST(BuilderTest, GraphIndexes) {
builder::Graph graph;
auto& foo = graph.AddNode("Foo");
// Graph inputs 0, 1, 2 are named "a", "c", "b" and feed Foo inputs 0, 2, 1.
graph.In(0).SetName("a") >> foo.In("")[0];
graph.In(1).SetName("c") >> foo.In("")[2];
graph.In(2).SetName("b") >> foo.In("")[1];
// Foo outputs 0 ("x") and 1 ("y") feed graph outputs 1 and 0 respectively.
foo.Out("")[0].SetName("x") >> graph.Out(1);
foo.Out("")[1].SetName("y") >> graph.Out(0);
// Graph-level streams are expected in graph port index order (a, c, b / y, x),
// node-level streams in node port index order (a, b, c / x, y).
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "a"
input_stream: "c"
input_stream: "b"
output_stream: "y"
output_stream: "x"
node {
calculator: "Foo"
input_stream: "a"
input_stream: "b"
input_stream: "c"
output_stream: "x"
output_stream: "y"
}
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
} // namespace test
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,28 @@
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_
#include <cstdint>
namespace mediapipe {
// Generates a view id at compile time from the use site's __FILE__ and
// __LINE__.
//
// Expands to a declaration of `static constexpr uint64_t kId` whose value is
// the FnvHash64 of __FILE__, seeded with the FnvHash64 of the stringized
// __LINE__. Only file and line feed the hash, so two expansions on the same
// source line produce the same value. The expansion carries no trailing
// semicolon; the use site supplies it.
#define TENSOR_UNIQUE_VIEW_ID() \
static constexpr uint64_t kId = tensor_internal::FnvHash64( \
__FILE__, tensor_internal::FnvHash64(TENSOR_INT_TO_STRING(__LINE__)))
namespace tensor_internal {
// Two-level stringification. TENSOR_INT_TO_STRING2 applies `#` directly to its
// argument; TENSOR_INT_TO_STRING forwards through it so that a macro argument
// such as __LINE__ is expanded to its value before being turned into a string
// literal.
#define TENSOR_INT_TO_STRING2(x) #x
#define TENSOR_INT_TO_STRING(x) TENSOR_INT_TO_STRING2(x)
// Compile-time hash function (64-bit FNV-1a: XOR the byte in, then multiply).
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
constexpr uint64_t kFnvPrime = 0x00000100000001B3;
constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325;

// Hashes the NUL-terminated string `str`, continuing from `hash`.
//
// `hash` defaults to the FNV offset basis; passing the result of a previous
// call chains hashes, so FnvHash64("b", FnvHash64("a")) == FnvHash64("ab").
// `str` must not be null. Usable in constant expressions.
constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) {
  // Iterate instead of recursing so that long inputs (e.g. a long __FILE__
  // path) cannot run into the compiler's constexpr recursion-depth limit.
  for (; *str != 0; ++str) {
    // Treat each character as an unsigned octet. Plain `char` may be signed,
    // in which case bytes >= 0x80 would be sign-extended and flip the upper
    // bits of the hash, making the result depend on the target's char
    // signedness.
    hash = (hash ^ static_cast<unsigned char>(*str)) * kFnvPrime;
  }
  return hash;
}
} // namespace tensor_internal
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_INTERNAL_H_

View File

@ -21,6 +21,7 @@
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/tool/type_util.h"
namespace mediapipe {
namespace packet_internal {
@ -105,6 +106,22 @@ std::string Packet::DebugString() const {
return result;
}
// Checks that this packet holds a payload of the type described by
// `type_info`.
//
// Returns an Internal error for an empty packet, an InvalidArgument error when
// the holder's type id differs from `type_info.hash_code()`, and OK otherwise.
absl::Status Packet::ValidateAsType(const tool::TypeInfo& type_info) const {
  // The empty case is handled first; the holder is only dereferenced below.
  if (ABSL_PREDICT_FALSE(IsEmpty())) {
    const std::string requested_type = MediaPipeTypeStringOrDemangled(type_info);
    return absl::InternalError(absl::StrCat("Expected a Packet of type: ",
                                            requested_type,
                                            ", but received an empty Packet."));
  }

  if (ABSL_PREDICT_FALSE(holder_->GetTypeId() != type_info.hash_code())) {
    return absl::InvalidArgumentError(absl::StrCat(
        "The Packet stores \"", holder_->DebugTypeName(), "\", but \"",
        MediaPipeTypeStringOrDemangled(type_info), "\" was requested."));
  }

  return absl::OkStatus();
}
absl::Status Packet::ValidateAsProtoMessageLite() const {
if (ABSL_PREDICT_FALSE(IsEmpty())) {
return absl::InternalError("Packet is empty.");

View File

@ -179,7 +179,9 @@ class Packet {
// Returns an error if the packet does not contain data of type T.
template <typename T>
absl::Status ValidateAsType() const;
absl::Status ValidateAsType() const {
return ValidateAsType(tool::TypeId<T>());
}
// Returns an error if the packet is not an instance of
// a protocol buffer message.
@ -218,6 +220,8 @@ class Packet {
friend std::shared_ptr<packet_internal::HolderBase>
packet_internal::GetHolderShared(Packet&& packet);
absl::Status ValidateAsType(const tool::TypeInfo& type_info) const;
std::shared_ptr<packet_internal::HolderBase> holder_;
class Timestamp timestamp_;
};
@ -770,21 +774,6 @@ inline const T& Packet::Get() const {
return holder->data();
}
template <typename T>
absl::Status Packet::ValidateAsType() const {
if (ABSL_PREDICT_FALSE(IsEmpty())) {
return absl::InternalError(absl::StrCat(
"Expected a Packet of type: ", MediaPipeTypeStringOrDemangled<T>(),
", but received an empty Packet."));
}
if (ABSL_PREDICT_FALSE(holder_->As<T>() == nullptr)) {
return absl::InvalidArgumentError(absl::StrCat(
"The Packet stores \"", holder_->DebugTypeName(), "\", but \"",
MediaPipeTypeStringOrDemangled<T>(), "\" was requested."));
}
return absl::OkStatus();
}
inline Timestamp Packet::Timestamp() const { return timestamp_; }
template <typename T>

View File

@ -84,6 +84,11 @@ class PacketType {
// Returns true iff this and other are consistent, meaning they do
// not expect different types. IsAny() is consistent with anything.
// IsNone() is only consistent with IsNone() and IsAny().
// Note: this is defined as a symmetric relationship, but within the
// framework, it is consistently invoked as:
// input_port_type.IsConsistentWith(connected_output_port_type)
// TODO: consider making this explicitly directional, and
// sharing some logic with the packet validation check.
bool IsConsistentWith(const PacketType& other) const;
// Returns OK if the packet contains an object of the appropriate type.

View File

@ -373,16 +373,22 @@ inline const std::string* MediaPipeTypeString() {
return MediaPipeTypeStringFromTypeId(tool::GetTypeHash<T>());
}
template <typename T>
const std::string MediaPipeTypeStringOrDemangled() {
const std::string* type_string = MediaPipeTypeString<T>();
inline std::string MediaPipeTypeStringOrDemangled(
const tool::TypeInfo& type_info) {
const std::string* type_string =
MediaPipeTypeStringFromTypeId(type_info.hash_code());
if (type_string) {
return *type_string;
} else {
return mediapipe::Demangle(tool::TypeId<T>().name());
return mediapipe::Demangle(type_info.name());
}
}
template <typename T>
std::string MediaPipeTypeStringOrDemangled() {
return MediaPipeTypeStringOrDemangled(tool::TypeId<T>());
}
// Returns type hash id of type identified by type_string or NULL if not
// registered.
inline const size_t* MediaPipeTypeId(const std::string& type_string) {

View File

@ -72,6 +72,17 @@ def mediapipe_aar(
visibility = ["//visibility:public"],
)
# When "--define ENABLE_STATS_LOGGING=1" is set in the build command,
# the solution stats logging component will be added into the AAR.
# This flag is for internal use only.
native.config_setting(
name = "enable_stats_logging",
define_values = {
"ENABLE_STATS_LOGGING": "1",
},
visibility = ["//visibility:public"],
)
_mediapipe_jni(
name = name + "_jni",
gen_libmediapipe = gen_libmediapipe,
@ -101,16 +112,23 @@ EOF
android_library(
name = name + "_android_lib",
srcs = srcs + [
"//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
"com/google/mediapipe/formats/proto/ClassificationProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/proto/CalculatorProto.java",
],
"//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
"com/google/mediapipe/formats/proto/ClassificationProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/proto/CalculatorProto.java",
] +
select({
"//conditions:default": [],
"enable_stats_logging": [
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
],
}),
manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
deps = [
@ -146,6 +164,13 @@ EOF
"//conditions:default": [":" + name + "_jni_opencv_cc_lib"],
"//mediapipe/framework/port:disable_opencv": [],
"exclude_opencv_so_lib": [],
}) + select({
"//conditions:default": [],
"enable_stats_logging": [
"@maven//:com_google_android_datatransport_transport_api",
"@maven//:com_google_android_datatransport_transport_backend_cct",
"@maven//:com_google_android_datatransport_transport_runtime",
],
}),
assets = assets,
assets_dir = assets_dir,
@ -159,6 +184,20 @@ def _mediapipe_proto(name):
Args:
name: the name of the target.
"""
_proto_java_src_generator(
name = "mediapipe_log_extension_proto",
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "mediapipe_logging_enums_proto",
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "calculator_proto",
proto_src = "mediapipe/framework/calculator.proto",

View File

@ -16,6 +16,15 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"])
# SolutionInfo.java packaged as its own target, so that ":logging" can depend
# on it without depending on ":solution_base" (which itself depends on
# ":logging").
android_library(
name = "solution_info",
srcs = ["SolutionInfo.java"],
deps = [
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "solution_base",
srcs = glob(
@ -30,6 +39,7 @@ android_library(
),
visibility = ["//visibility:public"],
deps = [
":logging",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/glutil",
"//third_party:autovalue",
@ -80,6 +90,19 @@ android_library(
],
)
# Solution stats logging classes: every Java source directly under logging/.
android_library(
name = "logging",
srcs = glob(
["logging/*.java"],
),
visibility = ["//visibility:public"],
deps = [
":solution_info",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
# Native dependencies of all MediaPipe solutions.
cc_binary(
name = "libmediapipe_jni.so",
@ -109,6 +132,6 @@ load("//mediapipe/java/com/google/mediapipe:mediapipe_aar.bzl", "mediapipe_aar")
mediapipe_aar(
name = "solution_core",
srcs = glob(["*.java"]),
srcs = glob(["**/*.java"]),
gen_libmediapipe = False,
)

View File

@ -137,8 +137,10 @@ public class ImageSolutionBase extends SolutionBase {
if (imageObj instanceof TextureFrame) {
imagePacket = packetCreator.createImage((TextureFrame) imageObj);
imageObj = null;
statsLogger.recordGpuInputArrival(timestamp);
} else if (imageObj instanceof Bitmap) {
imagePacket = packetCreator.createRgbaImage((Bitmap) imageObj);
statsLogger.recordCpuInputArrival(timestamp);
} else {
reportError(
"The input image type is not supported.",
@ -146,7 +148,6 @@ public class ImageSolutionBase extends SolutionBase {
MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(),
"The input image type is not supported."));
}
try {
// addConsumablePacketToInputStream allows the graph to take exclusive ownership of the
// packet, which may allow for more memory optimizations.

View File

@ -17,6 +17,7 @@ package com.google.mediapipe.solutioncore;
import android.util.Log;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.solutioncore.logging.SolutionStatsLogger;
import java.util.List;
/** Interface for handling MediaPipe solution graph outputs. */
@ -33,6 +34,9 @@ public class OutputHandler<T extends SolutionResult> {
private ResultListener<T> customResultListener;
// The user-defined error listener.
private ErrorListener customErrorListener;
// A logger that records the time when the output packets leave the graph and logs any error
// that occurs.
private SolutionStatsLogger statsLogger;
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
private boolean handleTimestampBoundChanges = false;
@ -54,6 +58,15 @@ public class OutputHandler<T extends SolutionResult> {
this.customResultListener = listener;
}
/**
* Sets the {@link SolutionStatsLogger} used to report invocation end events.
*
* <p>The logger is stored as-is, replacing any previously set logger; no null check is performed
* here. NOTE(review): the output callback appears to call the logger unconditionally, so this
* should presumably be set before the handler starts receiving packets - confirm.
*
* @param statsLogger a {@link SolutionStatsLogger}.
*/
public void setStatsLogger(SolutionStatsLogger statsLogger) {
this.statsLogger = statsLogger;
}
/**
* Sets a callback to be invoked when exceptions are thrown in the solution.
*
@ -82,6 +95,7 @@ public class OutputHandler<T extends SolutionResult> {
T solutionResult = null;
try {
solutionResult = outputConverter.convert(packets);
statsLogger.recordInvocationEnd(packets.get(0).getTimestamp());
customResultListener.run(solutionResult);
} catch (MediaPipeException e) {
if (customErrorListener != null) {

View File

@ -27,6 +27,8 @@ import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.solutioncore.logging.SolutionStatsLogger;
import com.google.mediapipe.solutioncore.logging.SolutionStatsDummyLogger;
import com.google.protobuf.Parser;
import java.io.File;
import java.util.List;
@ -35,7 +37,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nullable;
/** The base class of the MediaPipe solutions. */
public class SolutionBase {
public class SolutionBase implements AutoCloseable {
private static final String TAG = "SolutionBase";
protected Graph solutionGraph;
protected AndroidPacketCreator packetCreator;
@ -43,6 +45,7 @@ public class SolutionBase {
protected String imageInputStreamName;
protected long lastTimestamp = Long.MIN_VALUE;
protected final AtomicBoolean solutionGraphStarted = new AtomicBoolean(false);
protected SolutionStatsLogger statsLogger;
static {
System.loadLibrary("mediapipe_jni");
@ -63,6 +66,8 @@ public class SolutionBase {
SolutionInfo solutionInfo,
OutputHandler<? extends SolutionResult> outputHandler) {
this.imageInputStreamName = solutionInfo.imageInputStreamName();
this.statsLogger =
new SolutionStatsDummyLogger(context, this.getClass().getSimpleName(), solutionInfo);
try {
AndroidAssetUtil.initializeNativeAssetManager(context);
solutionGraph = new Graph();
@ -72,12 +77,14 @@ public class SolutionBase {
solutionGraph.loadBinaryGraph(
AndroidAssetUtil.getAssetBytes(context.getAssets(), solutionInfo.binaryGraphPath()));
}
outputHandler.setStatsLogger(statsLogger);
solutionGraph.addMultiStreamCallback(
solutionInfo.outputStreamNames(),
outputHandler::run,
/*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges());
packetCreator = new AndroidPacketCreator(solutionGraph);
} catch (MediaPipeException e) {
statsLogger.logInitError();
reportError("Error occurs while creating the MediaPipe solution graph.", e);
}
}
@ -113,10 +120,15 @@ public class SolutionBase {
if (inputSidePackets != null) {
solutionGraph.setInputSidePackets(inputSidePackets);
}
if (!solutionGraphStarted.getAndSet(true)) {
if (!solutionGraphStarted.get()) {
solutionGraph.startRunningGraph();
// Wait until all calculators are opened and the graph is truly started.
solutionGraph.waitUntilGraphIdle();
solutionGraphStarted.set(true);
statsLogger.logSessionStart();
}
} catch (MediaPipeException e) {
statsLogger.logInitError();
reportError("Error occurs while starting the MediaPipe solution graph.", e);
}
}
@ -131,11 +143,13 @@ public class SolutionBase {
}
/** Closes and cleans up the solution graph. */
@Override
public void close() {
if (solutionGraphStarted.get()) {
try {
solutionGraph.closeAllPacketSources();
solutionGraph.waitUntilGraphDone();
statsLogger.logSessionEnd();
} catch (MediaPipeException e) {
// Note: errors during Process are reported at the earliest opportunity,
// which may be addPacket or waitUntilDone, depending on timing. For consistency,

View File

@ -70,10 +70,26 @@ public class VideoInput {
}
}
/**
* The state of the MediaPlayer. See
* https://developer.android.com/reference/android/media/MediaPlayer#StateDiagram
*
* <p>This class tracks the state itself and uses it to guard calls on the MediaPlayer, so that
* each call is only issued from a state listed in the corresponding guard.
*/
private enum MediaPlayerState {
// Initial value of the tracked state, before any MediaPlayer has been set up.
IDLE,
// Set right before prepareAsync() is called.
PREPARING,
// Set by the on-prepared listener.
PREPARED,
// Set after MediaPlayer.start() has been called.
STARTED,
// Set after MediaPlayer.pause() has been called.
PAUSED,
// Set after MediaPlayer.stop() has been called.
STOPPED,
// Set by the on-completion listener.
PLAYBACK_COMPLETE,
// Set after MediaPlayer.release() has been called.
END,
}
private static final String TAG = "VideoInput";
private final SingleThreadHandlerExecutor executor;
private TextureFrameConsumer newFrameListener;
private MediaPlayer mediaPlayer;
private MediaPlayerState state = MediaPlayerState.IDLE;
private boolean looping = false;
private float audioVolume = 1.0f;
// {@link SurfaceTexture} where the video frames can be accessed.
@ -150,14 +166,19 @@ public class VideoInput {
converter.setConsumer(newFrameListener);
executor.execute(
() -> {
if (state != MediaPlayerState.IDLE && state != MediaPlayerState.END) {
return;
}
mediaPlayer = new MediaPlayer();
mediaPlayer.setLooping(looping);
mediaPlayer.setVolume(audioVolume, audioVolume);
mediaPlayer.setOnPreparedListener(
mp -> {
surfaceTexture.setDefaultBufferSize(mp.getVideoWidth(), mp.getVideoHeight());
unused -> {
surfaceTexture.setDefaultBufferSize(
mediaPlayer.getVideoWidth(), mediaPlayer.getVideoHeight());
// Calculates the optimal texture size by preserving the video aspect ratio.
float videoAspectRatio = (float) mp.getVideoWidth() / mp.getVideoHeight();
float videoAspectRatio =
(float) mediaPlayer.getVideoWidth() / mediaPlayer.getVideoHeight();
float displayAspectRatio = (float) displayWidth / displayHeight;
int textureWidth =
displayAspectRatio > videoAspectRatio
@ -168,22 +189,34 @@ public class VideoInput {
? displayHeight
: (int) (displayWidth / videoAspectRatio);
converter.setSurfaceTexture(surfaceTexture, textureWidth, textureHeight);
executor.execute(mp::start);
state = MediaPlayerState.PREPARED;
executor.execute(
() -> {
if (mediaPlayer != null && state == MediaPlayerState.PREPARED) {
mediaPlayer.start();
state = MediaPlayerState.STARTED;
}
});
});
mediaPlayer.setOnErrorListener(
(mp, what, extra) -> {
(unused, what, extra) -> {
Log.e(
TAG,
String.format(
"Error during mediaPlayer initialization. what: %s extra: %s",
what, extra));
reset(mp);
executor.execute(this::close);
return true;
});
mediaPlayer.setOnCompletionListener(this::reset);
mediaPlayer.setOnCompletionListener(
unused -> {
state = MediaPlayerState.PLAYBACK_COMPLETE;
executor.execute(this::close);
});
try {
mediaPlayer.setDataSource(activity, videoUri);
mediaPlayer.setSurface(new Surface(surfaceTexture));
state = MediaPlayerState.PREPARING;
mediaPlayer.prepareAsync();
} catch (IOException e) {
Log.e(TAG, "Failed to start MediaPlayer:", e);
@ -196,8 +229,10 @@ public class VideoInput {
public void pause() {
executor.execute(
() -> {
if (mediaPlayer != null) {
if (mediaPlayer != null
&& (state == MediaPlayerState.STARTED || state == MediaPlayerState.PAUSED)) {
mediaPlayer.pause();
state = MediaPlayerState.PAUSED;
}
});
}
@ -206,8 +241,9 @@ public class VideoInput {
public void resume() {
executor.execute(
() -> {
if (mediaPlayer != null) {
if (mediaPlayer != null && state == MediaPlayerState.PAUSED) {
mediaPlayer.start();
state = MediaPlayerState.STARTED;
}
});
}
@ -216,20 +252,38 @@ public class VideoInput {
public void stop() {
executor.execute(
() -> {
if (mediaPlayer != null) {
if (mediaPlayer != null
&& (state == MediaPlayerState.PREPARED
|| state == MediaPlayerState.STARTED
|| state == MediaPlayerState.PAUSED
|| state == MediaPlayerState.PLAYBACK_COMPLETE
|| state == MediaPlayerState.STOPPED)) {
mediaPlayer.stop();
state = MediaPlayerState.STOPPED;
}
});
}
/** Closes VideoInput and releases the {@link MediaPlayer} resources. */
/** Closes VideoInput and releases the resources. */
public void close() {
// The converter is closed synchronously, on whichever thread calls close().
if (converter != null) {
converter.close();
converter = null;
}
// The MediaPlayer and the EGL resources are released on the executor.
executor.execute(
() -> {
if (mediaPlayer != null) {
mediaPlayer.release();
// The mediaPlayer field itself is not cleared; only the tracked state records the release.
state = MediaPlayerState.END;
}
if (eglManager != null) {
destorySurfaceTexture();
eglManager.release();
eglManager = null;
}
});
// Restore the playback settings to the same values the fields are initialized with.
looping = false;
audioVolume = 1.0f;
}
private void createSurfaceTexture() {
@ -260,20 +314,4 @@ public class VideoInput {
eglManager.releaseSurface(tempEglSurface);
surfaceTexture = null;
}
private void reset(MediaPlayer mp) {
setNewFrameListener(null);
if (converter != null) {
converter.close();
converter = null;
}
if (eglManager != null) {
destorySurfaceTexture();
eglManager.release();
eglManager = null;
}
executor.execute(mp::release);
looping = false;
audioVolume = 1.0f;
}
}

View File

@ -0,0 +1,83 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.solutioncore.logging;
import android.content.Context;
import com.google.mediapipe.solutioncore.SolutionInfo;
/**
* A dummy solution stats logger that logs nothing.
*
* <p>Every method, including the constructor, has an empty body, so an instance can be used
* wherever a {@link SolutionStatsLogger} is required without recording or reporting anything.
*/
public class SolutionStatsDummyLogger implements SolutionStatsLogger {
/**
* Initializes the solution stats dummy logger. All arguments are accepted and ignored.
*
* @param context a {@link Context}.
* @param solutionNameStr the solution name.
* @param solutionInfo a {@link SolutionInfo}.
*/
public SolutionStatsDummyLogger(
Context context, String solutionNameStr, SolutionInfo solutionInfo) {}
/** Logs the solution session start event. No-op in this implementation. */
@Override
public void logSessionStart() {}
/**
* Records solution API receiving GPU input data. No-op in this implementation.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
@Override
public void recordGpuInputArrival(long packetTimestamp) {}
/**
* Records solution API receiving CPU input data. No-op in this implementation.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
@Override
public void recordCpuInputArrival(long packetTimestamp) {}
/**
* Records a solution api invocation end event. No-op in this implementation.
*
* @param packetTimestamp the output packet timestamp that acts as the identifier of the api
* invocation.
*/
@Override
public void recordInvocationEnd(long packetTimestamp) {}
/**
* Logs the solution invocation report event. No-op in this implementation.
*
* @param stats the stats snapshot to report; ignored.
*/
@Override
public void logInvocationReport(StatsSnapshot stats) {}
/** Logs the solution session end event. No-op in this implementation. */
@Override
public void logSessionEnd() {}
/** Logs the solution init error. No-op in this implementation. */
@Override
public void logInitError() {}
/** Logs the solution unsupported input error. No-op in this implementation. */
@Override
public void logUnsupportedInputError() {}
/** Logs the solution unsupported output error. No-op in this implementation. */
@Override
public void logUnsupportedOutputError() {}
}

View File

@ -0,0 +1,103 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.solutioncore.logging;
import com.google.auto.value.AutoValue;
/** The stats logger interface that defines what MediaPipe solution stats events to log. */
public interface SolutionStatsLogger {
/** Solution stats snapshot. */
@AutoValue
abstract static class StatsSnapshot {
static StatsSnapshot create(
int cpuInputCount,
int gpuInputCount,
int finishedCount,
int droppedCount,
long totalLatencyMs,
long peakLatencyMs,
long elapsedTimeMs) {
return new AutoValue_SolutionStatsLogger_StatsSnapshot(
cpuInputCount,
gpuInputCount,
finishedCount,
droppedCount,
totalLatencyMs,
peakLatencyMs,
elapsedTimeMs);
}
static StatsSnapshot createDefault() {
return new AutoValue_SolutionStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0);
}
abstract int cpuInputCount();
abstract int gpuInputCount();
abstract int finishedCount();
abstract int droppedCount();
abstract long totalLatencyMs();
abstract long peakLatencyMs();
abstract long elapsedTimeMs();
}
/** Logs the solution session start event. */
public void logSessionStart();
/**
* Records solution API receiving GPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
public void recordGpuInputArrival(long packetTimestamp);
/**
* Records solution API receiving CPU input data.
*
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
* invocation.
*/
public void recordCpuInputArrival(long packetTimestamp);
/**
* Records a solution api invocation end event.
*
* @param packetTimestamp the output packet timestamp that acts as the identifier of the api
* invocation.
*/
public void recordInvocationEnd(long packetTimestamp);
/** Logs the solution invocation report event. */
public void logInvocationReport(StatsSnapshot stats);
/** Logs the solution session end event. */
public void logSessionEnd();
/** Logs the solution init error. */
public void logInitError();
/** Logs the solution unsupported input error. */
public void logUnsupportedInputError();
/** Logs the solution unsupported output error. */
public void logUnsupportedOutputError();
}

View File

@ -225,6 +225,7 @@ objc_library(
testonly = 1,
srcs = [
"CFHolderTests.mm",
"MPPDisplayLinkWeakTargetTests.mm",
"MPPGraphTests.mm",
],
copts = [
@ -250,6 +251,7 @@ objc_library(
":MPPGraphTestBase",
":Weakify",
":mediapipe_framework_ios",
":mediapipe_input_sources_ios",
"//mediapipe/calculators/core:pass_through_calculator",
],
)

View File

@ -0,0 +1,68 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <XCTest/XCTest.h>
#import "mediapipe/objc/MPPDisplayLinkWeakTarget.h"
// Minimal callback target for the tests below: it only records whether
// -update: has been invoked.
@interface DummyTarget : NSObject
// Starts out NO and becomes YES once -update: has been called.
@property(nonatomic) BOOL updateCalled;
// The callback the tests register as the selector; the sender is ignored.
- (void)update:(id)sender;
@end
@implementation DummyTarget
@synthesize updateCalled = _updateCalled;
- (void)update:(id)sender {
_updateCalled = YES;
}
@end
// Tests for MPPDisplayLinkWeakTarget's forwarding of -displayLinkCallback: to
// its target.
@interface MPPDisplayLinkWeakTargetTests : XCTestCase
@end
@implementation MPPDisplayLinkWeakTargetTests {
// The only reference the test itself holds to the callback target.
DummyTarget *_dummyTarget;
}
// Creates a fresh target before each test.
- (void)setUp {
_dummyTarget = [[DummyTarget alloc] init];
}
// While the target is alive, the callback must be forwarded to -update:.
- (void)testCallingLiveTarget {
XCTAssertFalse(_dummyTarget.updateCalled);
MPPDisplayLinkWeakTarget *target =
[[MPPDisplayLinkWeakTarget alloc] initWithTarget:_dummyTarget
selector:@selector(update:)];
[target displayLinkCallback:nil];
XCTAssertTrue(_dummyTarget.updateCalled);
}
// Drops the test's reference to the target and then fires the callback; the
// test passes as long as that call returns without crashing.
// NOTE(review): this presumes MPPDisplayLinkWeakTarget does not retain its
// target, so that clearing the ivar deallocates it - confirm. The final
// assertion only re-checks the ivar that was just set to nil.
- (void)testDoesNotCrashWhenTargetIsDeallocated {
MPPDisplayLinkWeakTarget *target =
[[MPPDisplayLinkWeakTarget alloc] initWithTarget:_dummyTarget
selector:@selector(update:)];
_dummyTarget = nil;
[target displayLinkCallback:nil];
XCTAssertNil(_dummyTarget);
}
@end

View File

@ -34,6 +34,9 @@
// Forwards a display link tick to the stored target/selector pair. Does
// nothing when the target is nil.
- (void)displayLinkCallback:(CADisplayLink *)sender {
  // Copy the target into a strong local so the same object is used for both
  // the lookup and the call below.
  __strong id liveTarget = _target;
  if (liveTarget != nil) {
    typedef void (*DisplayLinkFunction)(id, SEL, CADisplayLink *);
    DisplayLinkFunction forward =
        (DisplayLinkFunction)[liveTarget methodForSelector:_selector];
    forward(liveTarget, _selector, sender);
  }
}

View File

@ -140,7 +140,11 @@ static CVReturn renderCallback(CVDisplayLinkRef displayLink, const CVTimeStamp*
}
CFRelease(pixelBuffer);
});
} else if (!_videoDisplayLink.paused && _videoPlayer.rate == 0) {
} else if (
#if !TARGET_OS_OSX
!_videoDisplayLink.paused &&
#endif
_videoPlayer.rate == 0) {
// The video might be paused by the operating system for other reasons not caught by the context
// of an interruption. If this condition happens the @c _videoDisplayLink will not have a
// paused state, while the _videoPlayer will have rate 0 AKA paused. In this scenario we restart

View File

@ -189,7 +189,28 @@ void PublicPacketGetters(pybind11::module* m) {
)doc");
m->def(
"get_int_list", &GetContent<std::vector<int>>,
"get_int_list",
[](const Packet& packet) {
if (packet.ValidateAsType<std::vector<int>>().ok()) {
auto int_list = packet.Get<std::vector<int>>();
return std::vector<int64>(int_list.begin(), int_list.end());
} else if (packet.ValidateAsType<std::vector<int8>>().ok()) {
auto int_list = packet.Get<std::vector<int8>>();
return std::vector<int64>(int_list.begin(), int_list.end());
} else if (packet.ValidateAsType<std::vector<int16>>().ok()) {
auto int_list = packet.Get<std::vector<int16>>();
return std::vector<int64>(int_list.begin(), int_list.end());
} else if (packet.ValidateAsType<std::vector<int32>>().ok()) {
auto int_list = packet.Get<std::vector<int32>>();
return std::vector<int64>(int_list.begin(), int_list.end());
} else if (packet.ValidateAsType<std::vector<int64>>().ok()) {
auto int_list = packet.Get<std::vector<int64>>();
return std::vector<int64>(int_list.begin(), int_list.end());
}
throw RaisePyError(PyExc_ValueError,
"Packet doesn't contain int, int8, int16, int32, or "
"int64 containers.");
},
R"doc(Get the content of a MediaPipe int vector Packet as an integer list.
Args:

View File

@ -321,3 +321,27 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
],
)
# Helpers for converting NormalizedRect protos to Rectangle_f and for
# computing rectangle overlap (Intersection over Union).
cc_library(
name = "rectangle_util",
srcs = ["rectangle_util.cc"],
hdrs = ["rectangle_util.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)
# Unit tests for ":rectangle_util".
cc_test(
name = "rectangle_util_test",
srcs = ["rectangle_util_test.cc"],
deps = [
":rectangle_util",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -0,0 +1,73 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/util/rectangle_util.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h"
namespace mediapipe {
// Builds a Rectangle_f from a NormalizedRect that is described by its center
// and its size. Returns InvalidArgument if any of x_center, y_center, width or
// height is unset, or if the width or height is negative.
absl::StatusOr<Rectangle_f> ToRectangle(
    const mediapipe::NormalizedRect& input) {
  const bool has_all_dimensions = input.has_x_center() &&
                                  input.has_y_center() && input.has_width() &&
                                  input.has_height();
  if (!has_all_dimensions) {
    return absl::InvalidArgumentError("Missing dimensions in NormalizedRect.");
  }

  const bool has_negative_size = input.width() < 0.0f || input.height() < 0.0f;
  if (has_negative_size) {
    return absl::InvalidArgumentError("Negative rectangle width or height.");
  }

  // Step back from the center by half the size to obtain the minimum corner.
  const float min_x = input.x_center() - input.width() / 2.0;
  const float min_y = input.y_center() - input.height() / 2.0;

  // TODO: Support rotation for rectangle.
  return Rectangle_f(min_x, min_y, input.width(), input.height());
}
// Returns true as soon as one rectangle in existing_rects has an IoU with
// new_rect that is strictly greater than min_similarity_threshold, and false
// if none does. A rectangle that ToRectangle() rejects makes the function
// return that error; rectangles after the first match are not examined.
absl::StatusOr<bool> DoesRectOverlap(
    const mediapipe::NormalizedRect& new_rect,
    absl::Span<const mediapipe::NormalizedRect> existing_rects,
    float min_similarity_threshold) {
  ASSIGN_OR_RETURN(Rectangle_f candidate, ToRectangle(new_rect));

  for (const mediapipe::NormalizedRect& other_rect : existing_rects) {
    ASSIGN_OR_RETURN(Rectangle_f other, ToRectangle(other_rect));
    const float iou = CalculateIou(other, candidate);
    if (iou > min_similarity_threshold) {
      return true;
    }
  }

  return false;
}
// Returns the Intersection over Union (IoU) of the two rectangles: the area of
// their intersection divided by the area of their union. Returns 0.0 when the
// rectangles do not intersect, or when the union area is not positive.
float CalculateIou(const Rectangle_f& rect1, const Rectangle_f& rect2) {
  if (!rect1.Intersects(rect2)) {
    return 0.0f;
  }

  const float overlap_area = Rectangle_f(rect1).Intersect(rect2).Area();
  // Union area = sum of both areas minus the part counted twice.
  const float union_area = rect1.Area() + rect2.Area() - overlap_area;
  if (union_area > 0.0f) {
    return overlap_area / union_area;
  }
  return 0.0f;
}
} // namespace mediapipe

View File

@ -0,0 +1,40 @@
// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_RECTANGLE_UTIL_H_
#define MEDIAPIPE_RECTANGLE_UTIL_H_
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/rectangle.h"
namespace mediapipe {
// Converts a NormalizedRect into a Rectangle_f.
//
// The input is described by its center and size, the result by its minimum
// corner and size: e.g. a rect centered at (0.2, 0.2) with width and height
// 0.2 becomes Rectangle_f(xmin=0.1, ymin=0.1, width=0.2, height=0.2).
//
// Returns a non-OK status when the input is invalid; the unit test expects
// this for a negative width or a negative height, and for a rect on which
// only x_center has been set.
absl::StatusOr<Rectangle_f> ToRectangle(const mediapipe::NormalizedRect& input);
// Reports whether new_rect overlaps with any rectangle in existing_rects.
//
// "Overlap" means that the IoU (see CalculateIou) between new_rect and at
// least one existing rectangle is strictly greater than
// min_similarity_threshold. Returns false when no rectangle qualifies, which
// includes the case of an empty existing_rects.
//
// Every rectangle is converted with ToRectangle first; a non-OK status from
// that conversion is propagated to the caller instead of a bool.
absl::StatusOr<bool> DoesRectOverlap(
const mediapipe::NormalizedRect& new_rect,
absl::Span<const mediapipe::NormalizedRect> existing_rects,
float min_similarity_threshold);
// Computes the Intersection over Union (IoU) between two rectangles.
//
// The result is bounded between [0.0, 1.0]: 0.0 means the rectangles do not
// intersect at all, and 1.0 means they are identical.
float CalculateIou(const Rectangle_f& rect1, const Rectangle_f& rect2);
} // namespace mediapipe
#endif // MEDIAPIPE_RECTANGLE_UTIL_H_

View File

@ -0,0 +1,180 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/util/rectangle_util.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
using ::testing::FloatNear;
// Test fixture providing six NormalizedRects (nr_0 .. nr_5) with known
// positions so that the expected overlaps can be worked out by hand.
class RectangleUtilTest : public testing::Test {
 protected:
  RectangleUtilTest() {
    // Layout of the rectangles (x on the horizontal axis, y on the vertical):
    //
    //  0.4                                         ================
    //                                              |    |    |    |
    //  0.3 =====================                   |   NR2   |    |
    //      |    |    |   NR1   |                   |    | NR4     |
    //  0.2 |   NR0   |    ===========              ================
    //      |    |    |    |    |    |
    //  0.1 =====|===============    |
    //           |    NR3  |    |    |
    //  0.0      ================    |
    //                     |   NR5   |
    // -0.1                ===========
    //     0.1  0.2  0.3  0.4  0.5  0.6  0.7  0.8  0.9  1.0  1.1  1.2

    // Fills one rect from its center coordinates and its size.
    const auto init_rect = [](mediapipe::NormalizedRect& rect, double x_center,
                              double y_center, double width, double height) {
      rect.set_x_center(x_center);
      rect.set_y_center(y_center);
      rect.set_width(width);
      rect.set_height(height);
    };

    init_rect(nr_0, /*x_center=*/0.2, /*y_center=*/0.2, /*width=*/0.2,
              /*height=*/0.2);
    init_rect(nr_1, /*x_center=*/0.4, /*y_center=*/0.2, /*width=*/0.2,
              /*height=*/0.2);
    init_rect(nr_2, /*x_center=*/1.0, /*y_center=*/0.3, /*width=*/0.2,
              /*height=*/0.2);
    init_rect(nr_3, /*x_center=*/0.35, /*y_center=*/0.15, /*width=*/0.3,
              /*height=*/0.3);
    init_rect(nr_4, /*x_center=*/1.1, /*y_center=*/0.3, /*width=*/0.2,
              /*height=*/0.2);
    init_rect(nr_5, /*x_center=*/0.5, /*y_center=*/0.05, /*width=*/0.2,
              /*height=*/0.3);
  }

  mediapipe::NormalizedRect nr_0;
  mediapipe::NormalizedRect nr_1;
  mediapipe::NormalizedRect nr_2;
  mediapipe::NormalizedRect nr_3;
  mediapipe::NormalizedRect nr_4;
  mediapipe::NormalizedRect nr_5;
};
// With a threshold of 0.15 only the pairs with a sizeable shared area count
// as overlapping.
TEST_F(RectangleUtilTest, OverlappingWithListLargeThreshold) {
  constexpr float kThreshold = 0.15;
  const std::vector<NormalizedRect> known_rects = {nr_0, nr_5, nr_2};

  // nr_3 vs nr_0: IoU = 0.02 / 0.11 ~= 0.18, above the threshold.
  EXPECT_THAT(DoesRectOverlap(nr_3, known_rects, kThreshold),
              IsOkAndHolds(true));
  // nr_4 vs nr_2: IoU = 0.02 / 0.06 ~= 0.33, above the threshold.
  EXPECT_THAT(DoesRectOverlap(nr_4, known_rects, kThreshold),
              IsOkAndHolds(true));
  // nr_1 shares at most an edge with nr_0, and its IoU with nr_5 is
  // 0.01 / 0.09 ~= 0.11, which stays below the threshold.
  EXPECT_THAT(DoesRectOverlap(nr_1, known_rects, kThreshold),
              IsOkAndHolds(false));
}
// Lowering the threshold to 0.1 additionally makes nr_1 count as overlapping.
TEST_F(RectangleUtilTest, OverlappingWithListSmallThreshold) {
  constexpr float kThreshold = 0.1;
  const std::vector<NormalizedRect> known_rects = {nr_0, nr_5, nr_2};

  // nr_3 vs nr_0: IoU ~= 0.18.
  EXPECT_THAT(DoesRectOverlap(nr_3, known_rects, kThreshold),
              IsOkAndHolds(true));
  // nr_4 vs nr_2: IoU ~= 0.33.
  EXPECT_THAT(DoesRectOverlap(nr_4, known_rects, kThreshold),
              IsOkAndHolds(true));
  // nr_1 vs nr_5: IoU = 0.01 / 0.09 ~= 0.11, now above the threshold.
  EXPECT_THAT(DoesRectOverlap(nr_1, known_rects, kThreshold),
              IsOkAndHolds(true));
}
// nr_2 and nr_4 lie far to the right of nr_0, nr_3 and nr_5, so neither of
// them overlaps with anything in the list.
TEST_F(RectangleUtilTest, NonOverlappingWithList) {
  constexpr float kThreshold = 0.1;
  const std::vector<NormalizedRect> known_rects = {nr_0, nr_3, nr_5};

  EXPECT_THAT(DoesRectOverlap(nr_2, known_rects, kThreshold),
              IsOkAndHolds(false));
  EXPECT_THAT(DoesRectOverlap(nr_4, known_rects, kThreshold),
              IsOkAndHolds(false));
}
// With nothing to compare against, no rect can be reported as overlapping.
TEST_F(RectangleUtilTest, OverlappingWithEmptyList) {
  constexpr float kThreshold = 0.1;
  const std::vector<NormalizedRect> no_rects{};

  EXPECT_THAT(DoesRectOverlap(nr_2, no_rects, kThreshold),
              IsOkAndHolds(false));
  EXPECT_THAT(DoesRectOverlap(nr_4, no_rects, kThreshold),
              IsOkAndHolds(false));
}
// nr_1 lies entirely inside nr_3: the shared area is 0.04 and the union is
// 0.09, hence an IoU of 4/9.
TEST_F(RectangleUtilTest, OverlapSimilarityOverlapping) {
  constexpr float kTolerance = 1e-4;
  constexpr float kWantIou = 4.0 / 9.0;

  const auto inner = ToRectangle(nr_1);
  MP_ASSERT_OK(inner);
  const auto outer = ToRectangle(nr_3);
  MP_ASSERT_OK(outer);

  EXPECT_THAT(CalculateIou(*inner, *outer), FloatNear(kWantIou, kTolerance));
}
// nr_1 and nr_2 are far apart, so their IoU must be zero.
TEST_F(RectangleUtilTest, OverlapSimilarityNotOverlapping) {
  constexpr float kTolerance = 1e-4;
  constexpr float kWantIou = 0.0;

  const auto left = ToRectangle(nr_1);
  MP_ASSERT_OK(left);
  const auto right = ToRectangle(nr_2);
  MP_ASSERT_OK(right);

  EXPECT_THAT(CalculateIou(*left, *right), FloatNear(kWantIou, kTolerance));
}
// A valid rect converts from center/size form to min-corner/size form: nr_0 is
// centered at (0.2, 0.2) with size 0.2 x 0.2, so its minimum corner is
// (0.1, 0.1).
TEST_F(RectangleUtilTest, NormRectToRectangleSuccess) {
  const Rectangle_f want(/*xmin=*/0.1, /*ymin=*/0.1,
                         /*width=*/0.2, /*height=*/0.2);

  EXPECT_THAT(ToRectangle(nr_0), IsOkAndHolds(want));
}
// Invalid rects must be rejected with a non-OK status. The same rect is
// mutated step by step, so the order of the statements below matters.
TEST_F(RectangleUtilTest, NormRectToRectangleFail) {
  mediapipe::NormalizedRect bad_rect;

  // Case 1: only x_center has been set.
  bad_rect.set_x_center(0.2);
  EXPECT_THAT(ToRectangle(bad_rect), testing::Not(IsOk()));

  // Case 2: all fields set, but the width is negative.
  bad_rect.set_y_center(0.2);
  bad_rect.set_width(-0.2);
  bad_rect.set_height(0.2);
  EXPECT_THAT(ToRectangle(bad_rect), testing::Not(IsOk()));

  // Case 3: the width is valid again, but the height is negative.
  bad_rect.set_width(0.2);
  bad_rect.set_height(-0.2);
  EXPECT_THAT(ToRectangle(bad_rect), testing::Not(IsOk()));
}
} // namespace
} // namespace mediapipe