Merge branch 'master' into ios-hand-landmarker-tests
This commit is contained in:
commit
086798e677
|
@ -87,6 +87,8 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
||||||
packet.Set<double>();
|
packet.Set<double>();
|
||||||
} else if (packet_options.has_time_series_header_value()) {
|
} else if (packet_options.has_time_series_header_value()) {
|
||||||
packet.Set<TimeSeriesHeader>();
|
packet.Set<TimeSeriesHeader>();
|
||||||
|
} else if (packet_options.has_int64_value()) {
|
||||||
|
packet.Set<int64_t>();
|
||||||
} else {
|
} else {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"None of supported values were specified in options.");
|
"None of supported values were specified in options.");
|
||||||
|
@ -124,6 +126,8 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
||||||
} else if (packet_options.has_time_series_header_value()) {
|
} else if (packet_options.has_time_series_header_value()) {
|
||||||
packet.Set(MakePacket<TimeSeriesHeader>(
|
packet.Set(MakePacket<TimeSeriesHeader>(
|
||||||
packet_options.time_series_header_value()));
|
packet_options.time_series_header_value()));
|
||||||
|
} else if (packet_options.has_int64_value()) {
|
||||||
|
packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
|
||||||
} else {
|
} else {
|
||||||
return absl::InvalidArgumentError(
|
return absl::InvalidArgumentError(
|
||||||
"None of supported values were specified in options.");
|
"None of supported values were specified in options.");
|
||||||
|
|
|
@ -29,13 +29,14 @@ message ConstantSidePacketCalculatorOptions {
|
||||||
message ConstantSidePacket {
|
message ConstantSidePacket {
|
||||||
oneof value {
|
oneof value {
|
||||||
int32 int_value = 1;
|
int32 int_value = 1;
|
||||||
|
uint64 uint64_value = 5;
|
||||||
|
int64 int64_value = 11;
|
||||||
float float_value = 2;
|
float float_value = 2;
|
||||||
|
double double_value = 9;
|
||||||
bool bool_value = 3;
|
bool bool_value = 3;
|
||||||
string string_value = 4;
|
string string_value = 4;
|
||||||
uint64 uint64_value = 5;
|
|
||||||
ClassificationList classification_list_value = 6;
|
ClassificationList classification_list_value = 6;
|
||||||
LandmarkList landmark_list_value = 7;
|
LandmarkList landmark_list_value = 7;
|
||||||
double double_value = 9;
|
|
||||||
TimeSeriesHeader time_series_header_value = 10;
|
TimeSeriesHeader time_series_header_value = 10;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
|
||||||
DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
|
DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
|
||||||
DoTestSingleSidePacket("{ bool_value: true }", true);
|
DoTestSingleSidePacket("{ bool_value: true }", true);
|
||||||
DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
|
DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
|
||||||
|
DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
|
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
|
||||||
|
|
|
@ -228,7 +228,6 @@ cc_library(
|
||||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -280,7 +279,6 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
|
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
|
||||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
|
@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
|
||||||
input_tensors.reserve(kNumInputTensorsForBert);
|
input_tensors.reserve(kNumInputTensorsForBert);
|
||||||
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
|
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
|
||||||
input_tensors.push_back(
|
input_tensors.push_back(
|
||||||
{Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
|
{Tensor::ElementType::kInt32,
|
||||||
|
Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)});
|
||||||
}
|
}
|
||||||
std::memcpy(input_tensors[input_ids_tensor_index_]
|
std::memcpy(input_tensors[input_ids_tensor_index_]
|
||||||
.GetCpuWriteView()
|
.GetCpuWriteView()
|
||||||
|
|
|
@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
|
||||||
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
|
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
|
||||||
// Read CPU input into tensors.
|
// Read CPU input into tensors.
|
||||||
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
|
RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
|
||||||
|
|
||||||
|
// If the input tensors have dynamic shape, then the tensors need to be
|
||||||
|
// resized and reallocated before we can copy the tensor values.
|
||||||
|
bool resized_tensor_shapes = false;
|
||||||
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
|
if (input_tensors[i].shape().is_dynamic) {
|
||||||
|
interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims);
|
||||||
|
resized_tensor_shapes = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Reallocation is needed for memory sanity.
|
||||||
|
if (resized_tensor_shapes) interpreter_->AllocateTensors();
|
||||||
|
|
||||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||||
const TfLiteType input_tensor_type =
|
const TfLiteType input_tensor_type =
|
||||||
interpreter_->tensor(interpreter_->inputs()[i])->type;
|
interpreter_->tensor(interpreter_->inputs()[i])->type;
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "absl/status/statusor.h"
|
|
||||||
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
|
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
|
||||||
// not found in the tokenizer vocab.
|
// not found in the tokenizer vocab.
|
||||||
std::vector<Tensor> result;
|
std::vector<Tensor> result;
|
||||||
result.push_back(
|
result.push_back(
|
||||||
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
|
{Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})});
|
||||||
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
|
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
|
||||||
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
|
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
|
||||||
kTensorsOut(cc).Send(std::move(result));
|
kTensorsOut(cc).Send(std::move(result));
|
||||||
|
|
|
@ -1077,6 +1077,7 @@ cc_test(
|
||||||
linkstatic = 1,
|
linkstatic = 1,
|
||||||
deps = [
|
deps = [
|
||||||
":tensor_to_image_frame_calculator",
|
":tensor_to_image_frame_calculator",
|
||||||
|
":tensor_to_image_frame_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
|
|
@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
float scale_factor_;
|
float scale_factor_;
|
||||||
|
bool scale_per_frame_min_max_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
|
||||||
|
@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
|
||||||
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
|
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
|
||||||
scale_factor_ =
|
scale_factor_ =
|
||||||
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
|
cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
|
||||||
|
scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
|
||||||
|
.scale_per_frame_min_max();
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
|
||||||
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
|
auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
|
||||||
const int32_t total_size = height * width * depth;
|
const int32_t total_size = height * width * depth;
|
||||||
|
|
||||||
|
if (scale_per_frame_min_max_) {
|
||||||
|
RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
|
||||||
|
<< "Setting scale_per_frame_min_max requires FLOAT input tensors.";
|
||||||
|
}
|
||||||
::std::unique_ptr<const ImageFrame> output;
|
::std::unique_ptr<const ImageFrame> output;
|
||||||
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
|
if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
|
||||||
// Allocate buffer with alignments.
|
// Allocate buffer with alignments.
|
||||||
std::unique_ptr<uint8_t[]> buffer(
|
std::unique_ptr<uint8_t[]> buffer(
|
||||||
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
|
new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
|
||||||
auto data = input_tensor.flat<float>().data();
|
auto data = input_tensor.flat<float>().data();
|
||||||
|
float min = 1e23;
|
||||||
|
float max = -1e23;
|
||||||
|
if (scale_per_frame_min_max_) {
|
||||||
|
for (int i = 0; i < total_size; ++i) {
|
||||||
|
float d = scale_factor_ * data[i];
|
||||||
|
if (d < min) {
|
||||||
|
min = d;
|
||||||
|
}
|
||||||
|
if (d > max) {
|
||||||
|
max = d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int i = 0; i < total_size; ++i) {
|
for (int i = 0; i < total_size; ++i) {
|
||||||
float d = scale_factor_ * data[i];
|
float d = data[i];
|
||||||
if (d < 0) d = 0;
|
if (scale_per_frame_min_max_) {
|
||||||
if (d > 255) d = 255;
|
d = 255 * (d - min) / (max - min + 1e-9);
|
||||||
|
} else {
|
||||||
|
d = scale_factor_ * d;
|
||||||
|
if (d < 0) d = 0;
|
||||||
|
if (d > 255) d = 255;
|
||||||
|
}
|
||||||
buffer[i] = d;
|
buffer[i] = d;
|
||||||
}
|
}
|
||||||
output = ::absl::make_unique<ImageFrame>(
|
output = ::absl::make_unique<ImageFrame>(
|
||||||
|
|
|
@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
|
||||||
// Multiples floating point tensor outputs by this value before converting to
|
// Multiples floating point tensor outputs by this value before converting to
|
||||||
// uint8. This is useful for converting from range [0, 1] to [0, 255]
|
// uint8. This is useful for converting from range [0, 1] to [0, 255]
|
||||||
optional float scale_factor = 1 [default = 1.0];
|
optional float scale_factor = 1 [default = 1.0];
|
||||||
|
|
||||||
|
// If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
|
||||||
|
// per frame. This overrides any explicit scale_factor.
|
||||||
|
optional bool scale_per_frame_min_max = 2 [default = false];
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,9 @@
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
|
||||||
template <class TypeParam>
|
template <class TypeParam>
|
||||||
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
class TensorToImageFrameCalculatorTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void SetUpRunner() {
|
void SetUpRunner(bool scale_per_frame_min_max = false) {
|
||||||
CalculatorGraphConfig::Node config;
|
CalculatorGraphConfig::Node config;
|
||||||
config.set_calculator("TensorToImageFrameCalculator");
|
config.set_calculator("TensorToImageFrameCalculator");
|
||||||
config.add_input_stream("TENSOR:input_tensor");
|
config.add_input_stream("TENSOR:input_tensor");
|
||||||
config.add_output_stream("IMAGE:output_image");
|
config.add_output_stream("IMAGE:output_image");
|
||||||
|
config.mutable_options()
|
||||||
|
->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
|
||||||
|
->set_scale_per_frame_min_max(scale_per_frame_min_max);
|
||||||
runner_ = absl::make_unique<CalculatorRunner>(config);
|
runner_ = absl::make_unique<CalculatorRunner>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TYPED_TEST(TensorToImageFrameCalculatorTest,
|
||||||
|
Converts3DTensorToImageFrame2DGrayWithScaling) {
|
||||||
|
this->SetUpRunner(true);
|
||||||
|
auto& runner = this->runner_;
|
||||||
|
constexpr int kWidth = 16;
|
||||||
|
constexpr int kHeight = 8;
|
||||||
|
const tf::TensorShape tensor_shape{kHeight, kWidth};
|
||||||
|
auto tensor = absl::make_unique<tf::Tensor>(
|
||||||
|
tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
|
||||||
|
auto tensor_vec = tensor->template flat<TypeParam>().data();
|
||||||
|
|
||||||
|
// Writing sequence of integers as floats which we want normalized.
|
||||||
|
tensor_vec[0] = 255;
|
||||||
|
for (int i = 1; i < kWidth * kHeight; ++i) {
|
||||||
|
tensor_vec[i] = 200;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t time = 1234;
|
||||||
|
runner->MutableInputs()->Tag(kTensor).packets.push_back(
|
||||||
|
Adopt(tensor.release()).At(Timestamp(time)));
|
||||||
|
|
||||||
|
if (!std::is_same<TypeParam, float>::value) {
|
||||||
|
EXPECT_FALSE(runner->Run().ok());
|
||||||
|
return; // Short circuit because does not apply to other types.
|
||||||
|
} else {
|
||||||
|
EXPECT_TRUE(runner->Run().ok());
|
||||||
|
const std::vector<Packet>& output_packets =
|
||||||
|
runner->Outputs().Tag(kImage).packets;
|
||||||
|
EXPECT_EQ(1, output_packets.size());
|
||||||
|
EXPECT_EQ(time, output_packets[0].Timestamp().Value());
|
||||||
|
const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
|
||||||
|
EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
|
||||||
|
EXPECT_EQ(kWidth, output_image.Width());
|
||||||
|
EXPECT_EQ(kHeight, output_image.Height());
|
||||||
|
|
||||||
|
EXPECT_EQ(255, output_image.PixelData()[0]);
|
||||||
|
for (int i = 1; i < kWidth * kHeight; ++i) {
|
||||||
|
const uint8_t pixel_value = output_image.PixelData()[i];
|
||||||
|
ASSERT_EQ(0, pixel_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -1355,6 +1355,22 @@ cc_test(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "calculator_graph_summary_packet_test",
|
||||||
|
srcs = ["calculator_graph_summary_packet_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":calculator_framework",
|
||||||
|
":packet",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
|
||||||
|
"//mediapipe/framework/tool:sink",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "calculator_runner_test",
|
name = "calculator_runner_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
|
|
@ -109,9 +109,20 @@ class CalculatorContext {
|
||||||
// use OutputStream::SetOffset() directly.
|
// use OutputStream::SetOffset() directly.
|
||||||
void SetOffset(TimestampDiff offset);
|
void SetOffset(TimestampDiff offset);
|
||||||
|
|
||||||
// Returns the status of the graph run.
|
// DEPRECATED: This was intended to get graph run status during
|
||||||
|
// `CalculatorBase::Close` call. However, `Close` can run simultaneously with
|
||||||
|
// other calculators `CalculatorBase::Process`, hence the actual graph
|
||||||
|
// status may change any time and returned graph status here does not
|
||||||
|
// necessarily reflect the actual graph status.
|
||||||
//
|
//
|
||||||
// NOTE: This method should only be called during CalculatorBase::Close().
|
// As an alternative, instead of checking graph status in `Close` and doing
|
||||||
|
// work for "done" state, you can enable timestamp bound processing for your
|
||||||
|
// calculator (`CalculatorContract::SetProcessTimestampBounds`) to trigger
|
||||||
|
// `Process` on timestamp bound updates and handle "done" state there.
|
||||||
|
// Check examples in:
|
||||||
|
// mediapipe/framework/calculator_graph_summary_packet_test.cc.
|
||||||
|
//
|
||||||
|
ABSL_DEPRECATED("Does not reflect the actual graph status.")
|
||||||
absl::Status GraphStatus() const { return graph_status_; }
|
absl::Status GraphStatus() const { return graph_status_; }
|
||||||
|
|
||||||
ProfilingContext* GetProfilingContext() const {
|
ProfilingContext* GetProfilingContext() const {
|
||||||
|
|
327
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
327
mediapipe/framework/calculator_graph_summary_packet_test.cc
Normal file
|
@ -0,0 +1,327 @@
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/packet.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Node;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::Eq;
|
||||||
|
using ::testing::IsEmpty;
|
||||||
|
using ::testing::Value;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
MATCHER_P2(IntPacket, value, timestamp, "") {
|
||||||
|
*result_listener << "where object is (value: " << arg.template Get<int>()
|
||||||
|
<< ", timestamp: " << arg.Timestamp() << ")";
|
||||||
|
return Value(arg.template Get<int>(), Eq(value)) &&
|
||||||
|
Value(arg.Timestamp(), Eq(timestamp));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculates and produces sum of all passed inputs when no more packets can be
|
||||||
|
// expected on the input stream.
|
||||||
|
class SummaryPacketCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"SUMMARY"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||||
|
// Makes sure there are no automatic timestamp bound updates when Process
|
||||||
|
// is called.
|
||||||
|
cc->SetTimestampOffset(TimestampDiff::Unset());
|
||||||
|
// Currently, only ImmediateInputStreamHandler supports "done" timestamp
|
||||||
|
// bound update. (ImmediateInputStreamhandler handles multiple input
|
||||||
|
// streams differently, so, in that case, calculator adjustments may be
|
||||||
|
// required.)
|
||||||
|
// TODO: update all input stream handlers to support "done"
|
||||||
|
// timestamp bound update.
|
||||||
|
cc->SetInputStreamHandler("ImmediateInputStreamHandler");
|
||||||
|
// Enables processing timestamp bound updates. For this use case we are
|
||||||
|
// specifically interested in "done" timestamp bound update. (E.g. when
|
||||||
|
// all input packet sources are closed.)
|
||||||
|
cc->SetProcessTimestampBounds(true);
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
if (!kIn(cc).IsEmpty()) {
|
||||||
|
value_ += kIn(cc).Get();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kOut(cc).IsClosed()) {
|
||||||
|
// This can happen:
|
||||||
|
// 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
|
||||||
|
// source calculator finished generating packets sent to kIn) and
|
||||||
|
// HasNextAllowedInStream() == true (which is an often case).
|
||||||
|
// 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
|
||||||
|
// invoke Process() with Timestamp::Max to indicate "Done" timestamp
|
||||||
|
// bound update.
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: input stream holding a packet with timestamp that has
|
||||||
|
// no next timestamp allowed in stream should always result in
|
||||||
|
// InputStream::IsDone() == true.
|
||||||
|
if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
|
||||||
|
// kOut(cc).Send(value_) can be used here as well, however in the case of
|
||||||
|
// source calculator sending inputs into kIn the resulting timestamp is
|
||||||
|
// not well defined (e.g. it can be the last packet timestamp or
|
||||||
|
// Timestamp::Max())
|
||||||
|
// TODO: last packet from source should always result in
|
||||||
|
// InputStream::IsDone() == true.
|
||||||
|
kOut(cc).Send(value_, Timestamp::Max());
|
||||||
|
kOut(cc).Close();
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int value_ = 0;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnClosingAllPacketSources) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp(10));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
send_packet(20, Timestamp(11));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp(10));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
send_packet(20, Timestamp::Max());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnPreStreamTimestamp) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp::PreStream());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnPostStreamTimestamp) {
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: 'input'
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: 'IN:input'
|
||||||
|
output_stream: 'SUMMARY:output'
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
auto send_packet = [&graph](int value, Timestamp timestamp) {
|
||||||
|
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||||
|
"input", MakePacket<int>(value).At(timestamp)));
|
||||||
|
};
|
||||||
|
|
||||||
|
send_packet(10, Timestamp::PostStream());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
class IntGeneratorCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final {
|
||||||
|
kOut(cc).Send(20, Timestamp(0));
|
||||||
|
kOut(cc).Send(10, Timestamp(1000));
|
||||||
|
return tool::StatusStop();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnSourceCalculatorCompletion) {
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
CalculatorGraphConfig graph_config =
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
node {
|
||||||
|
calculator: "IntGeneratorCalculator"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
}
|
||||||
|
|
||||||
|
class EmitOnCloseCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<int> kIn{"IN"};
|
||||||
|
static constexpr Output<int> kOut{"INT"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||||
|
|
||||||
|
absl::Status Close(CalculatorContext* cc) final {
|
||||||
|
kOut(cc).Send(20, Timestamp(0));
|
||||||
|
kOut(cc).Send(10, Timestamp(1000));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
|
||||||
|
|
||||||
|
TEST(SummaryPacketCalculatorUseCaseTest,
|
||||||
|
ProducesSummaryPacketOnAnotherCalculatorClosure) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input"
|
||||||
|
node {
|
||||||
|
calculator: "EmitOnCloseCalculator"
|
||||||
|
input_stream: "IN:input"
|
||||||
|
output_stream: "INT:int_value"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
calculator: "SummaryPacketCalculator"
|
||||||
|
input_stream: "IN:int_value"
|
||||||
|
output_stream: "SUMMARY:output"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("output", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
|
||||||
|
MP_ASSERT_OK(graph.CloseInputStream("input"));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
|
||||||
|
|
||||||
|
output_packets.clear();
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
EXPECT_THAT(output_packets, IsEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -117,11 +117,18 @@ class Tensor {
|
||||||
Shape() = default;
|
Shape() = default;
|
||||||
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
||||||
Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
|
Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
|
||||||
|
Shape(std::initializer_list<int> dimensions, bool is_dynamic)
|
||||||
|
: dims(dimensions), is_dynamic(is_dynamic) {}
|
||||||
|
Shape(const std::vector<int>& dimensions, bool is_dynamic)
|
||||||
|
: dims(dimensions), is_dynamic(is_dynamic) {}
|
||||||
int num_elements() const {
|
int num_elements() const {
|
||||||
return std::accumulate(dims.begin(), dims.end(), 1,
|
return std::accumulate(dims.begin(), dims.end(), 1,
|
||||||
std::multiplies<int>());
|
std::multiplies<int>());
|
||||||
}
|
}
|
||||||
std::vector<int> dims;
|
std::vector<int> dims;
|
||||||
|
// The Tensor has dynamic rather than static shape so the TFLite interpreter
|
||||||
|
// needs to be reallocated. Only relevant for CPU.
|
||||||
|
bool is_dynamic = false;
|
||||||
};
|
};
|
||||||
// Quantization parameters corresponding to the zero_point and scale value
|
// Quantization parameters corresponding to the zero_point and scale value
|
||||||
// made available by TfLite quantized (uint8/int8) tensors.
|
// made available by TfLite quantized (uint8/int8) tensors.
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
@ -34,6 +35,17 @@ TEST(General, TestDataTypes) {
|
||||||
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
|
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(General, TestDynamic) {
|
||||||
|
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape({1, 2, 3, 4}, true));
|
||||||
|
EXPECT_EQ(t1.shape().num_elements(), 1 * 2 * 3 * 4);
|
||||||
|
EXPECT_TRUE(t1.shape().is_dynamic);
|
||||||
|
|
||||||
|
std::vector<int> t2_dims = {4, 3, 2, 3};
|
||||||
|
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape(t2_dims, true));
|
||||||
|
EXPECT_EQ(t2.shape().num_elements(), 4 * 3 * 2 * 3);
|
||||||
|
EXPECT_TRUE(t2.shape().is_dynamic);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Cpu, TestMemoryAllocation) {
|
TEST(Cpu, TestMemoryAllocation) {
|
||||||
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
|
Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
|
||||||
auto v1 = t1.GetCpuWriteView();
|
auto v1 = t1.GetCpuWriteView();
|
||||||
|
|
|
@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
|
||||||
return *this + 1;
|
return *this + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Timestamp::HasNextAllowedInStream() const {
|
||||||
|
if (*this >= Max() || *this == PreStream()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
Timestamp Timestamp::PreviousAllowedInStream() const {
|
Timestamp Timestamp::PreviousAllowedInStream() const {
|
||||||
if (*this <= Min() || *this == PostStream()) {
|
if (*this <= Min() || *this == PostStream()) {
|
||||||
// Indicates that no previous timestamps may occur.
|
// Indicates that no previous timestamps may occur.
|
||||||
|
|
|
@ -186,6 +186,10 @@ class Timestamp {
|
||||||
// CHECKs that this->IsAllowedInStream().
|
// CHECKs that this->IsAllowedInStream().
|
||||||
Timestamp NextAllowedInStream() const;
|
Timestamp NextAllowedInStream() const;
|
||||||
|
|
||||||
|
// Returns true if there's a next timestamp in the range [Min .. Max] after
|
||||||
|
// this one.
|
||||||
|
bool HasNextAllowedInStream() const;
|
||||||
|
|
||||||
// Returns the previous timestamp in the range [Min .. Max], or
|
// Returns the previous timestamp in the range [Min .. Max], or
|
||||||
// Unstarted() if no Packets may preceed one with this timestamp.
|
// Unstarted() if no Packets may preceed one with this timestamp.
|
||||||
Timestamp PreviousAllowedInStream() const;
|
Timestamp PreviousAllowedInStream() const;
|
||||||
|
|
|
@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
|
||||||
Timestamp::PostStream().NextAllowedInStream());
|
Timestamp::PostStream().NextAllowedInStream());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TimestampTest, HasNextAllowedInStream) {
|
||||||
|
EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
|
||||||
|
EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
|
||||||
|
|
||||||
|
EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
|
||||||
|
EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TimestampTest, SpecialValueDifferences) {
|
TEST(TimestampTest, SpecialValueDifferences) {
|
||||||
{ // Lower range
|
{ // Lower range
|
||||||
const std::vector<Timestamp> timestamps = {
|
const std::vector<Timestamp> timestamps = {
|
||||||
|
|
|
@ -34,7 +34,7 @@ objc_library(
|
||||||
"-x objective-c++",
|
"-x objective-c++",
|
||||||
],
|
],
|
||||||
data = [
|
data = [
|
||||||
"//mediapipe/tasks/testdata/vision:test_models",
|
"//mediapipe/tasks/testdata/vision:hand_landmarker.task",
|
||||||
"//mediapipe/tasks/testdata/vision:test_images",
|
"//mediapipe/tasks/testdata/vision:test_images",
|
||||||
"//mediapipe/tasks/testdata/vision:test_protos",
|
"//mediapipe/tasks/testdata/vision:test_protos",
|
||||||
],
|
],
|
||||||
|
|
|
@ -37,7 +37,6 @@ static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
|
||||||
static ResourceFileInfo *const kExpectedPointingUpRotatedLandmarksFile =
|
static ResourceFileInfo *const kExpectedPointingUpRotatedLandmarksFile =
|
||||||
@{@"name" : @"pointing_up_rotated_landmarks", @"type" : kPbFileExtension};
|
@{@"name" : @"pointing_up_rotated_landmarks", @"type" : kPbFileExtension};
|
||||||
|
|
||||||
|
|
||||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
static const float kLandmarksErrorTolerance = 0.03f;
|
static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
|
|
||||||
|
@ -54,8 +53,8 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
|
@"hand index = %d landmark index j = %d", handIndex, landmarkIndex);
|
||||||
|
|
||||||
#define AssertHandLandmarkerResultIsEmpty(handLandmarkerResult) \
|
#define AssertHandLandmarkerResultIsEmpty(handLandmarkerResult) \
|
||||||
XCTAssertTrue(handLandmarkerResult.handedness.count == 0); \
|
XCTAssertTrue(handLandmarkerResult.handedness.count == 0); \
|
||||||
XCTAssertTrue(handLandmarkerResult.landmarks.count == 0); \
|
XCTAssertTrue(handLandmarkerResult.landmarks.count == 0); \
|
||||||
XCTAssertTrue(handLandmarkerResult.worldLandmarks.count == 0);
|
XCTAssertTrue(handLandmarkerResult.worldLandmarks.count == 0);
|
||||||
|
|
||||||
@interface MPPHandLandmarkerTests : XCTestCase {
|
@interface MPPHandLandmarkerTests : XCTestCase {
|
||||||
|
@ -70,28 +69,25 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
|
|
||||||
+ (MPPHandLandmarkerResult *)emptyHandLandmarkerResult {
|
+ (MPPHandLandmarkerResult *)emptyHandLandmarkerResult {
|
||||||
return [[MPPHandLandmarkerResult alloc] initWithLandmarks:@[]
|
return [[MPPHandLandmarkerResult alloc] initWithLandmarks:@[]
|
||||||
worldLandmarks:@[]
|
worldLandmarks:@[]
|
||||||
handedness:@[]
|
handedness:@[]
|
||||||
|
|
||||||
timestampInMilliseconds:0];
|
timestampInMilliseconds:0];
|
||||||
}
|
}
|
||||||
|
|
||||||
+ (MPPHandLandmarkerResult *)thumbUpHandLandmarkerResult {
|
+ (MPPHandLandmarkerResult *)thumbUpHandLandmarkerResult {
|
||||||
NSString *filePath =
|
NSString *filePath = [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
|
||||||
[MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
|
|
||||||
|
|
||||||
return [MPPHandLandmarkerResult
|
return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
|
||||||
handLandmarkerResultFromTextEncodedProtobufFileWithName:filePath
|
shouldRemoveZPosition:YES];
|
||||||
shouldRemoveZPosition:YES];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
+ (MPPHandLandmarkerResult *)pointingUpRotatedHandLandmarkerResult {
|
+ (MPPHandLandmarkerResult *)pointingUpRotatedHandLandmarkerResult {
|
||||||
NSString *filePath =
|
NSString *filePath =
|
||||||
[MPPHandLandmarkerTests filePathWithFileInfo:kExpectedPointingUpRotatedLandmarksFile];
|
[MPPHandLandmarkerTests filePathWithFileInfo:kExpectedPointingUpRotatedLandmarksFile];
|
||||||
|
|
||||||
return [MPPHandLandmarkerResult
|
return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
|
||||||
handLandmarkerResultFromTextEncodedProtobufFileWithName:filePath
|
shouldRemoveZPosition:YES];
|
||||||
shouldRemoveZPosition:YES];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
|
- (void)assertMultiHandLandmarks:(NSArray<NSArray<MPPNormalizedLandmark *> *> *)multiHandLandmarks
|
||||||
|
@ -133,8 +129,7 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)assertHandLandmarkerResult:(MPPHandLandmarkerResult *)handLandmarkerResult
|
- (void)assertHandLandmarkerResult:(MPPHandLandmarkerResult *)handLandmarkerResult
|
||||||
isApproximatelyEqualToExpectedResult:
|
isApproximatelyEqualToExpectedResult:(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
|
||||||
(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
|
|
||||||
[self assertMultiHandLandmarks:handLandmarkerResult.landmarks
|
[self assertMultiHandLandmarks:handLandmarkerResult.landmarks
|
||||||
areApproximatelyEqualToExpectedMultiHandLandmarks:expectedHandLandmarkerResult.landmarks];
|
areApproximatelyEqualToExpectedMultiHandLandmarks:expectedHandLandmarkerResult.landmarks];
|
||||||
[self assertMultiHandWorldLandmarks:handLandmarkerResult.worldLandmarks
|
[self assertMultiHandWorldLandmarks:handLandmarkerResult.worldLandmarks
|
||||||
|
@ -146,7 +141,7 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
|
|
||||||
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
|
+ (NSString *)filePathWithFileInfo:(ResourceFileInfo *)fileInfo {
|
||||||
NSString *filePath = [MPPHandLandmarkerTests filePathWithName:fileInfo[@"name"]
|
NSString *filePath = [MPPHandLandmarkerTests filePathWithName:fileInfo[@"name"]
|
||||||
extension:fileInfo[@"type"]];
|
extension:fileInfo[@"type"]];
|
||||||
return filePath;
|
return filePath;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,8 +156,7 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
- (MPPHandLandmarkerOptions *)handLandmarkerOptionsWithModelFileInfo:
|
- (MPPHandLandmarkerOptions *)handLandmarkerOptionsWithModelFileInfo:
|
||||||
(ResourceFileInfo *)modelFileInfo {
|
(ResourceFileInfo *)modelFileInfo {
|
||||||
NSString *modelPath = [MPPHandLandmarkerTests filePathWithFileInfo:modelFileInfo];
|
NSString *modelPath = [MPPHandLandmarkerTests filePathWithFileInfo:modelFileInfo];
|
||||||
MPPHandLandmarkerOptions *handLandmarkerOptions =
|
MPPHandLandmarkerOptions *handLandmarkerOptions = [[MPPHandLandmarkerOptions alloc] init];
|
||||||
[[MPPHandLandmarkerOptions alloc] init];
|
|
||||||
handLandmarkerOptions.baseOptions.modelAssetPath = modelPath;
|
handLandmarkerOptions.baseOptions.modelAssetPath = modelPath;
|
||||||
|
|
||||||
return handLandmarkerOptions;
|
return handLandmarkerOptions;
|
||||||
|
@ -170,21 +164,22 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
|
|
||||||
- (MPPHandLandmarker *)createHandLandmarkerWithOptionsSucceeds:
|
- (MPPHandLandmarker *)createHandLandmarkerWithOptionsSucceeds:
|
||||||
(MPPHandLandmarkerOptions *)handLandmarkerOptions {
|
(MPPHandLandmarkerOptions *)handLandmarkerOptions {
|
||||||
|
NSError* error;
|
||||||
MPPHandLandmarker *handLandmarker =
|
MPPHandLandmarker *handLandmarker =
|
||||||
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:nil];
|
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
|
||||||
XCTAssertNotNil(handLandmarker);
|
XCTAssertNotNil(handLandmarker);
|
||||||
|
XCTAssertNil(error);
|
||||||
|
|
||||||
return handLandmarker;
|
return handLandmarker;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)assertCreateHandLandmarkerWithOptions:
|
- (void)assertCreateHandLandmarkerWithOptions:(MPPHandLandmarkerOptions *)handLandmarkerOptions
|
||||||
(MPPHandLandmarkerOptions *)handLandmarkerOptions
|
failsWithExpectedError:(NSError *)expectedError {
|
||||||
failsWithExpectedError:(NSError *)expectedError {
|
|
||||||
NSError *error = nil;
|
NSError *error = nil;
|
||||||
MPPHandLandmarker *handLandmarker =
|
MPPHandLandmarker *handLandmarker =
|
||||||
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
|
[[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
|
||||||
|
|
||||||
XCTAssertNil(handLandmarkerOptions);
|
XCTAssertNil(handLandmarker);
|
||||||
AssertEqualErrors(error, expectedError);
|
AssertEqualErrors(error, expectedError);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,22 +206,20 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (MPPHandLandmarkerResult *)detectInImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
|
- (MPPHandLandmarkerResult *)detectInImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
|
||||||
usingHandLandmarker:
|
usingHandLandmarker:(MPPHandLandmarker *)handLandmarker {
|
||||||
(MPPHandLandmarker *)handLandmarker {
|
|
||||||
MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
|
MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
|
||||||
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage
|
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
|
||||||
error:nil];
|
|
||||||
XCTAssertNotNil(handLandmarkerResult);
|
XCTAssertNotNil(handLandmarkerResult);
|
||||||
|
|
||||||
return handLandmarkerResult;
|
return handLandmarkerResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)assertResultsOfDetectInImageWithFileInfo:(ResourceFileInfo *)fileInfo
|
- (void)assertResultsOfDetectInImageWithFileInfo:(ResourceFileInfo *)fileInfo
|
||||||
usingHandLandmarker:(MPPHandLandmarker *)handLandmarker
|
usingHandLandmarker:(MPPHandLandmarker *)handLandmarker
|
||||||
approximatelyEqualsHandLandmarkerResult:
|
approximatelyEqualsHandLandmarkerResult:
|
||||||
(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
|
(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
|
||||||
MPPHandLandmarkerResult *handLandmarkerResult =
|
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:fileInfo
|
||||||
[self detectInImageWithFileInfo:fileInfo usingHandLandmarker:handLandmarker];
|
usingHandLandmarker:handLandmarker];
|
||||||
[self assertHandLandmarkerResult:handLandmarkerResult
|
[self assertHandLandmarkerResult:handLandmarkerResult
|
||||||
isApproximatelyEqualToExpectedResult:expectedHandLandmarkerResult];
|
isApproximatelyEqualToExpectedResult:expectedHandLandmarkerResult];
|
||||||
}
|
}
|
||||||
|
@ -236,14 +229,14 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
- (void)testDetectWithModelPathSucceeds {
|
- (void)testDetectWithModelPathSucceeds {
|
||||||
NSString *modelPath =
|
NSString *modelPath =
|
||||||
[MPPHandLandmarkerTests filePathWithFileInfo:kHandLandmarkerBundleAssetFile];
|
[MPPHandLandmarkerTests filePathWithFileInfo:kHandLandmarkerBundleAssetFile];
|
||||||
MPPHandLandmarker *handLandmarker =
|
MPPHandLandmarker *handLandmarker = [[MPPHandLandmarker alloc] initWithModelPath:modelPath
|
||||||
[[MPPHandLandmarker alloc] initWithModelPath:modelPath error:nil];
|
error:nil];
|
||||||
XCTAssertNotNil(handLandmarker);
|
XCTAssertNotNil(handLandmarker);
|
||||||
|
|
||||||
[self assertResultsOfDetectInImageWithFileInfo:kThumbUpImage
|
[self assertResultsOfDetectInImageWithFileInfo:kThumbUpImage
|
||||||
usingHandLandmarker:handLandmarker
|
usingHandLandmarker:handLandmarker
|
||||||
approximatelyEqualsHandLandmarkerResult:[MPPHandLandmarkerTests
|
approximatelyEqualsHandLandmarkerResult:[MPPHandLandmarkerTests
|
||||||
thumbUpHandLandmarkerResult]];
|
thumbUpHandLandmarkerResult]];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testDetectWithEmptyResultsSucceeds {
|
- (void)testDetectWithEmptyResultsSucceeds {
|
||||||
|
@ -253,8 +246,8 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
MPPHandLandmarker *handLandmarker =
|
MPPHandLandmarker *handLandmarker =
|
||||||
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
|
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
|
||||||
|
|
||||||
MPPHandLandmarkerResult *handLandmarkerResult =
|
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kNoHandsImage
|
||||||
[self detectInImageWithFileInfo:kNoHandsImage usingHandLandmarker:handLandmarker];
|
usingHandLandmarker:handLandmarker];
|
||||||
AssertHandLandmarkerResultIsEmpty(handLandmarkerResult);
|
AssertHandLandmarkerResultIsEmpty(handLandmarkerResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -268,8 +261,8 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
MPPHandLandmarker *handLandmarker =
|
MPPHandLandmarker *handLandmarker =
|
||||||
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
|
[self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
|
||||||
|
|
||||||
MPPHandLandmarkerResult *handLandmarkerResult =
|
MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kTwoHandsImage
|
||||||
[self detectInImageWithFileInfo:kTwoHandsImage usingHandLandmarker:handLandmarker];
|
usingHandLandmarker:handLandmarker];
|
||||||
|
|
||||||
XCTAssertTrue(handLandmarkerResult.handedness.count == numHands);
|
XCTAssertTrue(handLandmarkerResult.handedness.count == numHands);
|
||||||
}
|
}
|
||||||
|
@ -284,12 +277,11 @@ static const float kLandmarksErrorTolerance = 0.03f;
|
||||||
MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
|
MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
|
||||||
orientation:UIImageOrientationRight];
|
orientation:UIImageOrientationRight];
|
||||||
|
|
||||||
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage
|
MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
|
||||||
error:nil];
|
|
||||||
|
|
||||||
[self assertHandLandmarkerResult:handLandmarkerResult
|
[self assertHandLandmarkerResult:handLandmarkerResult
|
||||||
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests pointingUpRotatedHandLandmarkerResult]];
|
isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests
|
||||||
|
pointingUpRotatedHandLandmarkerResult]];
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma mark Running Mode Tests
|
#pragma mark Running Mode Tests
|
||||||
|
|
|
@ -12,10 +12,11 @@ objc_library(
|
||||||
"-x objective-c++",
|
"-x objective-c++",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
|
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
|
||||||
"//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
|
|
||||||
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
"//mediapipe/tasks/ios/common/utils:NSStringHelpers",
|
||||||
|
"//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
|
||||||
|
"//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarkerResult",
|
||||||
|
"//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,9 +18,8 @@
|
||||||
NS_ASSUME_NONNULL_BEGIN
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
@interface MPPHandLandmarkerResult (ProtobufHelpers)
|
@interface MPPHandLandmarkerResult (ProtobufHelpers)
|
||||||
|
|
||||||
+ (MPPHandLandmarkerResult *)
|
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
|
||||||
handLandmarkerResultFromTextEncodedProtobufFileWithName:(NSString *)fileName
|
shouldRemoveZPosition:(BOOL)removeZPosition;
|
||||||
shouldRemoveZPosition:(BOOL)removeZPosition;
|
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
|
|
@ -31,9 +31,8 @@ using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
|
||||||
|
|
||||||
@implementation MPPHandLandmarkerResult (ProtobufHelpers)
|
@implementation MPPHandLandmarkerResult (ProtobufHelpers)
|
||||||
|
|
||||||
+ (MPPHandLandmarkerResult *)
|
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
|
||||||
handLandmarkerResultFromTextEncodedProtobufFileWithName:(NSString *)fileName
|
shouldRemoveZPosition:(BOOL)removeZPosition {
|
||||||
shouldRemoveZPosition:(BOOL)removeZPosition {
|
|
||||||
LandmarksDetectionResultProto landmarkDetectionResultProto;
|
LandmarksDetectionResultProto landmarkDetectionResultProto;
|
||||||
|
|
||||||
if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
|
if (!get_proto_from_pbtxt(fileName.cppString, landmarkDetectionResultProto).ok()) {
|
||||||
|
@ -51,9 +50,9 @@ using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
|
||||||
|
|
||||||
return [MPPHandLandmarkerResult
|
return [MPPHandLandmarkerResult
|
||||||
handLandmarkerResultWithLandmarksProto:{landmarkDetectionResultProto.landmarks()}
|
handLandmarkerResultWithLandmarksProto:{landmarkDetectionResultProto.landmarks()}
|
||||||
worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
|
worldLandmarksProto:{landmarkDetectionResultProto.world_landmarks()}
|
||||||
handednessProto:{landmarkDetectionResultProto.classifications()}
|
handednessProto:{landmarkDetectionResultProto.classifications()}
|
||||||
timestampInMilliSeconds:0];
|
timestampInMilliSeconds:0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -41,12 +41,12 @@ PYBIND11_MODULE(_pywrap_flatbuffers, m) {
|
||||||
self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
|
self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
|
||||||
contents.length());
|
contents.length());
|
||||||
});
|
});
|
||||||
m.def("generate_text_file", &flatbuffers::GenerateTextFile);
|
m.def("generate_text_file", &flatbuffers::GenTextFile);
|
||||||
m.def("generate_text",
|
m.def("generate_text",
|
||||||
[](const flatbuffers::Parser& parser,
|
[](const flatbuffers::Parser& parser,
|
||||||
const std::string& buffer) -> std::string {
|
const std::string& buffer) -> std::string {
|
||||||
std::string text;
|
std::string text;
|
||||||
const char* result = flatbuffers::GenerateText(
|
const char* result = flatbuffers::GenText(
|
||||||
parser, reinterpret_cast<const void*>(buffer.c_str()), &text);
|
parser, reinterpret_cast<const void*>(buffer.c_str()), &text);
|
||||||
if (result) {
|
if (result) {
|
||||||
return "";
|
return "";
|
||||||
|
|
|
@ -38,7 +38,7 @@ mediapipe_files(srcs = [
|
||||||
])
|
])
|
||||||
|
|
||||||
rollup_bundle(
|
rollup_bundle(
|
||||||
name = "audio_bundle",
|
name = "audio_bundle_mjs",
|
||||||
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
|
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
|
||||||
entry_point = "index.ts",
|
entry_point = "index.ts",
|
||||||
format = "esm",
|
format = "esm",
|
||||||
|
@ -69,6 +69,29 @@ rollup_bundle(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
genrule(
|
||||||
|
name = "audio_sources",
|
||||||
|
srcs = [
|
||||||
|
":audio_bundle_cjs",
|
||||||
|
":audio_bundle_mjs",
|
||||||
|
],
|
||||||
|
outs = [
|
||||||
|
"audio_bundle.cjs",
|
||||||
|
"audio_bundle.cjs.map",
|
||||||
|
"audio_bundle.mjs",
|
||||||
|
"audio_bundle.mjs.map",
|
||||||
|
],
|
||||||
|
cmd = (
|
||||||
|
"for FILE in $(SRCS); do " +
|
||||||
|
" OUT_FILE=$(GENDIR)/mediapipe/tasks/web/audio/$$(" +
|
||||||
|
" basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
|
||||||
|
" ); " +
|
||||||
|
" echo $$FILE ; echo $$OUT_FILE ; " +
|
||||||
|
" cp $$FILE $$OUT_FILE ; " +
|
||||||
|
"done;"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
genrule(
|
genrule(
|
||||||
name = "package_json",
|
name = "package_json",
|
||||||
srcs = ["//mediapipe/tasks/web:package.json"],
|
srcs = ["//mediapipe/tasks/web:package.json"],
|
||||||
|
@ -91,8 +114,7 @@ pkg_npm(
|
||||||
"wasm/audio_wasm_internal.wasm",
|
"wasm/audio_wasm_internal.wasm",
|
||||||
"wasm/audio_wasm_nosimd_internal.js",
|
"wasm/audio_wasm_nosimd_internal.js",
|
||||||
"wasm/audio_wasm_nosimd_internal.wasm",
|
"wasm/audio_wasm_nosimd_internal.wasm",
|
||||||
":audio_bundle",
|
":audio_sources",
|
||||||
":audio_bundle_cjs",
|
|
||||||
":package_json",
|
":package_json",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,11 +2,12 @@
|
||||||
"name": "@mediapipe/tasks-__NAME__",
|
"name": "@mediapipe/tasks-__NAME__",
|
||||||
"version": "__VERSION__",
|
"version": "__VERSION__",
|
||||||
"description": "__DESCRIPTION__",
|
"description": "__DESCRIPTION__",
|
||||||
"main": "__NAME___bundle_cjs.js",
|
"main": "__NAME___bundle.cjs",
|
||||||
"browser": "__NAME___bundle.js",
|
"browser": "__NAME___bundle.mjs",
|
||||||
"module": "__NAME___bundle.js",
|
"module": "__NAME___bundle.mjs",
|
||||||
"author": "mediapipe@google.com",
|
"author": "mediapipe@google.com",
|
||||||
"license": "Apache-2.0",
|
"license": "Apache-2.0",
|
||||||
|
"type": "module",
|
||||||
"types": "__TYPES__",
|
"types": "__TYPES__",
|
||||||
"homepage": "http://mediapipe.dev",
|
"homepage": "http://mediapipe.dev",
|
||||||
"keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ]
|
"keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ]
|
||||||
|
|
|
@ -39,7 +39,7 @@ mediapipe_ts_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
rollup_bundle(
|
rollup_bundle(
|
||||||
name = "text_bundle",
|
name = "text_bundle_mjs",
|
||||||
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
|
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
|
||||||
entry_point = "index.ts",
|
entry_point = "index.ts",
|
||||||
format = "esm",
|
format = "esm",
|
||||||
|
@ -70,6 +70,29 @@ rollup_bundle(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
genrule(
|
||||||
|
name = "text_sources",
|
||||||
|
srcs = [
|
||||||
|
":text_bundle_cjs",
|
||||||
|
":text_bundle_mjs",
|
||||||
|
],
|
||||||
|
outs = [
|
||||||
|
"text_bundle.cjs",
|
||||||
|
"text_bundle.cjs.map",
|
||||||
|
"text_bundle.mjs",
|
||||||
|
"text_bundle.mjs.map",
|
||||||
|
],
|
||||||
|
cmd = (
|
||||||
|
"for FILE in $(SRCS); do " +
|
||||||
|
" OUT_FILE=$(GENDIR)/mediapipe/tasks/web/text/$$(" +
|
||||||
|
" basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
|
||||||
|
" ); " +
|
||||||
|
" echo $$FILE ; echo $$OUT_FILE ; " +
|
||||||
|
" cp $$FILE $$OUT_FILE ; " +
|
||||||
|
"done;"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
genrule(
|
genrule(
|
||||||
name = "package_json",
|
name = "package_json",
|
||||||
srcs = ["//mediapipe/tasks/web:package.json"],
|
srcs = ["//mediapipe/tasks/web:package.json"],
|
||||||
|
@ -93,7 +116,6 @@ pkg_npm(
|
||||||
"wasm/text_wasm_nosimd_internal.js",
|
"wasm/text_wasm_nosimd_internal.js",
|
||||||
"wasm/text_wasm_nosimd_internal.wasm",
|
"wasm/text_wasm_nosimd_internal.wasm",
|
||||||
":package_json",
|
":package_json",
|
||||||
":text_bundle",
|
":text_sources",
|
||||||
":text_bundle_cjs",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -50,7 +50,7 @@ mediapipe_ts_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
rollup_bundle(
|
rollup_bundle(
|
||||||
name = "vision_bundle",
|
name = "vision_bundle_mjs",
|
||||||
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
|
config_file = "//mediapipe/tasks/web:rollup.config.mjs",
|
||||||
entry_point = "index.ts",
|
entry_point = "index.ts",
|
||||||
format = "esm",
|
format = "esm",
|
||||||
|
@ -81,6 +81,29 @@ rollup_bundle(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
genrule(
|
||||||
|
name = "vision_sources",
|
||||||
|
srcs = [
|
||||||
|
":vision_bundle_cjs",
|
||||||
|
":vision_bundle_mjs",
|
||||||
|
],
|
||||||
|
outs = [
|
||||||
|
"vision_bundle.cjs",
|
||||||
|
"vision_bundle.cjs.map",
|
||||||
|
"vision_bundle.mjs",
|
||||||
|
"vision_bundle.mjs.map",
|
||||||
|
],
|
||||||
|
cmd = (
|
||||||
|
"for FILE in $(SRCS); do " +
|
||||||
|
" OUT_FILE=$(GENDIR)/mediapipe/tasks/web/vision/$$(" +
|
||||||
|
" basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
|
||||||
|
" ); " +
|
||||||
|
" echo $$FILE ; echo $$OUT_FILE ; " +
|
||||||
|
" cp $$FILE $$OUT_FILE ; " +
|
||||||
|
"done;"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
genrule(
|
genrule(
|
||||||
name = "package_json",
|
name = "package_json",
|
||||||
srcs = ["//mediapipe/tasks/web:package.json"],
|
srcs = ["//mediapipe/tasks/web:package.json"],
|
||||||
|
@ -104,7 +127,6 @@ pkg_npm(
|
||||||
"wasm/vision_wasm_nosimd_internal.js",
|
"wasm/vision_wasm_nosimd_internal.js",
|
||||||
"wasm/vision_wasm_nosimd_internal.wasm",
|
"wasm/vision_wasm_nosimd_internal.wasm",
|
||||||
":package_json",
|
":package_json",
|
||||||
":vision_bundle",
|
":vision_sources",
|
||||||
":vision_bundle_cjs",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,7 +26,6 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/flags/flag.h"
|
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
|
@ -35,23 +34,14 @@
|
||||||
#include "mediapipe/framework/port/integral_types.h"
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
|
|
||||||
ABSL_FLAG(std::string, system_cpu_max_freq_file,
|
|
||||||
"/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq",
|
|
||||||
"The file pattern for CPU max frequencies, where $0 will be replaced "
|
|
||||||
"with the CPU id.");
|
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
constexpr uint32_t kBufferLength = 64;
|
constexpr uint32_t kBufferLength = 64;
|
||||||
|
|
||||||
absl::StatusOr<std::string> GetFilePath(int cpu) {
|
absl::StatusOr<std::string> GetFilePath(int cpu) {
|
||||||
if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) {
|
return absl::Substitute(
|
||||||
return absl::InvalidArgumentError(
|
"/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq", cpu);
|
||||||
absl::StrCat("Invalid frequency file: ",
|
|
||||||
absl::GetFlag(FLAGS_system_cpu_max_freq_file)));
|
|
||||||
}
|
|
||||||
return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<uint64_t> GetCpuMaxFrequency(int cpu) {
|
absl::StatusOr<uint64_t> GetCpuMaxFrequency(int cpu) {
|
||||||
|
|
|
@ -147,6 +147,22 @@ absl::Status ReconcileMetadataImages(const std::string& prefix,
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reconciles metadata for all images.
|
||||||
|
absl::Status ReconcileMetadataImages(tensorflow::SequenceExample* sequence) {
|
||||||
|
RET_CHECK_OK(ReconcileMetadataImages("", sequence));
|
||||||
|
for (const auto& key_value : sequence->feature_lists().feature_list()) {
|
||||||
|
const auto& key = key_value.first;
|
||||||
|
if (::absl::StrContains(key, kImageTimestampKey)) {
|
||||||
|
std::string prefix = "";
|
||||||
|
if (key != kImageTimestampKey) {
|
||||||
|
prefix = key.substr(0, key.size() - sizeof(kImageTimestampKey));
|
||||||
|
}
|
||||||
|
RET_CHECK_OK(ReconcileMetadataImages(prefix, sequence));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
// Sets the values of "feature/${TAG}/dimensions", and
|
// Sets the values of "feature/${TAG}/dimensions", and
|
||||||
// "feature/${TAG}/frame_rate" for each float list feature TAG. If the
|
// "feature/${TAG}/frame_rate" for each float list feature TAG. If the
|
||||||
// dimensions are already present as a context feature, this method verifies
|
// dimensions are already present as a context feature, this method verifies
|
||||||
|
@ -545,10 +561,7 @@ absl::Status ReconcileMetadata(bool reconcile_bbox_annotations,
|
||||||
bool reconcile_region_annotations,
|
bool reconcile_region_annotations,
|
||||||
tensorflow::SequenceExample* sequence) {
|
tensorflow::SequenceExample* sequence) {
|
||||||
RET_CHECK_OK(ReconcileAnnotationIndicesByImageTimestamps(sequence));
|
RET_CHECK_OK(ReconcileAnnotationIndicesByImageTimestamps(sequence));
|
||||||
RET_CHECK_OK(ReconcileMetadataImages("", sequence));
|
RET_CHECK_OK(ReconcileMetadataImages(sequence));
|
||||||
RET_CHECK_OK(ReconcileMetadataImages(kForwardFlowPrefix, sequence));
|
|
||||||
RET_CHECK_OK(ReconcileMetadataImages(kClassSegmentationPrefix, sequence));
|
|
||||||
RET_CHECK_OK(ReconcileMetadataImages(kInstanceSegmentationPrefix, sequence));
|
|
||||||
RET_CHECK_OK(ReconcileMetadataFeatureFloats(sequence));
|
RET_CHECK_OK(ReconcileMetadataFeatureFloats(sequence));
|
||||||
if (reconcile_bbox_annotations) {
|
if (reconcile_bbox_annotations) {
|
||||||
RET_CHECK_OK(ReconcileMetadataBoxAnnotations("", sequence));
|
RET_CHECK_OK(ReconcileMetadataBoxAnnotations("", sequence));
|
||||||
|
|
8
third_party/flatbuffers/workspace.bzl
vendored
8
third_party/flatbuffers/workspace.bzl
vendored
|
@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
|
||||||
def repo():
|
def repo():
|
||||||
third_party_http_archive(
|
third_party_http_archive(
|
||||||
name = "flatbuffers",
|
name = "flatbuffers",
|
||||||
strip_prefix = "flatbuffers-23.5.8",
|
strip_prefix = "flatbuffers-23.5.26",
|
||||||
sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88",
|
sha256 = "1cce06b17cddd896b6d73cc047e36a254fb8df4d7ea18a46acf16c4c0cd3f3f3",
|
||||||
urls = [
|
urls = [
|
||||||
"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
|
"https://github.com/google/flatbuffers/archive/v23.5.26.tar.gz",
|
||||||
"https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
|
"https://github.com/google/flatbuffers/archive/v23.5.26.tar.gz",
|
||||||
],
|
],
|
||||||
build_file = "//third_party/flatbuffers:BUILD.bazel",
|
build_file = "//third_party/flatbuffers:BUILD.bazel",
|
||||||
delete = ["build_defs.bzl", "BUILD.bazel"],
|
delete = ["build_defs.bzl", "BUILD.bazel"],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user