Merge branch 'master' into ios-hand-landmarker-tests
This commit is contained in:
		
						commit
						086798e677
					
				| 
						 | 
				
			
			@ -87,6 +87,8 @@ class ConstantSidePacketCalculator : public CalculatorBase {
 | 
			
		|||
        packet.Set<double>();
 | 
			
		||||
      } else if (packet_options.has_time_series_header_value()) {
 | 
			
		||||
        packet.Set<TimeSeriesHeader>();
 | 
			
		||||
      } else if (packet_options.has_int64_value()) {
 | 
			
		||||
        packet.Set<int64_t>();
 | 
			
		||||
      } else {
 | 
			
		||||
        return absl::InvalidArgumentError(
 | 
			
		||||
            "None of supported values were specified in options.");
 | 
			
		||||
| 
						 | 
				
			
			@ -124,6 +126,8 @@ class ConstantSidePacketCalculator : public CalculatorBase {
 | 
			
		|||
      } else if (packet_options.has_time_series_header_value()) {
 | 
			
		||||
        packet.Set(MakePacket<TimeSeriesHeader>(
 | 
			
		||||
            packet_options.time_series_header_value()));
 | 
			
		||||
      } else if (packet_options.has_int64_value()) {
 | 
			
		||||
        packet.Set(MakePacket<int64_t>(packet_options.int64_value()));
 | 
			
		||||
      } else {
 | 
			
		||||
        return absl::InvalidArgumentError(
 | 
			
		||||
            "None of supported values were specified in options.");
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,13 +29,14 @@ message ConstantSidePacketCalculatorOptions {
 | 
			
		|||
  message ConstantSidePacket {
 | 
			
		||||
    oneof value {
 | 
			
		||||
      int32 int_value = 1;
 | 
			
		||||
      uint64 uint64_value = 5;
 | 
			
		||||
      int64 int64_value = 11;
 | 
			
		||||
      float float_value = 2;
 | 
			
		||||
      double double_value = 9;
 | 
			
		||||
      bool bool_value = 3;
 | 
			
		||||
      string string_value = 4;
 | 
			
		||||
      uint64 uint64_value = 5;
 | 
			
		||||
      ClassificationList classification_list_value = 6;
 | 
			
		||||
      LandmarkList landmark_list_value = 7;
 | 
			
		||||
      double double_value = 9;
 | 
			
		||||
      TimeSeriesHeader time_series_header_value = 10;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,6 +12,7 @@
 | 
			
		|||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +59,7 @@ TEST(ConstantSidePacketCalculatorTest, EveryPossibleType) {
 | 
			
		|||
  DoTestSingleSidePacket("{ float_value: 6.5f }", 6.5f);
 | 
			
		||||
  DoTestSingleSidePacket("{ bool_value: true }", true);
 | 
			
		||||
  DoTestSingleSidePacket<std::string>(R"({ string_value: "str" })", "str");
 | 
			
		||||
  DoTestSingleSidePacket<int64_t>("{ int64_value: 63 }", 63);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(ConstantSidePacketCalculatorTest, MultiplePackets) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -228,7 +228,6 @@ cc_library(
 | 
			
		|||
        "//mediapipe/tasks/metadata:metadata_schema_cc",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_set",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
| 
						 | 
				
			
			@ -280,7 +279,6 @@ cc_library(
 | 
			
		|||
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
 | 
			
		||||
        "//mediapipe/tasks/metadata:metadata_schema_cc",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
    ],
 | 
			
		||||
    alwayslink = 1,
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,7 +22,6 @@
 | 
			
		|||
 | 
			
		||||
#include "absl/container/flat_hash_set.h"
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/ascii.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "absl/strings/substitute.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -244,7 +243,8 @@ std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
 | 
			
		|||
  input_tensors.reserve(kNumInputTensorsForBert);
 | 
			
		||||
  for (int i = 0; i < kNumInputTensorsForBert; ++i) {
 | 
			
		||||
    input_tensors.push_back(
 | 
			
		||||
        {Tensor::ElementType::kInt32, Tensor::Shape({tensor_size})});
 | 
			
		||||
        {Tensor::ElementType::kInt32,
 | 
			
		||||
         Tensor::Shape({1, tensor_size}, has_dynamic_input_tensors_)});
 | 
			
		||||
  }
 | 
			
		||||
  std::memcpy(input_tensors[input_ids_tensor_index_]
 | 
			
		||||
                  .GetCpuWriteView()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -96,6 +96,19 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
 | 
			
		|||
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
 | 
			
		||||
  // Read CPU input into tensors.
 | 
			
		||||
  RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size());
 | 
			
		||||
 | 
			
		||||
  // If the input tensors have dynamic shape, then the tensors need to be
 | 
			
		||||
  // resized and reallocated before we can copy the tensor values.
 | 
			
		||||
  bool resized_tensor_shapes = false;
 | 
			
		||||
  for (int i = 0; i < input_tensors.size(); ++i) {
 | 
			
		||||
    if (input_tensors[i].shape().is_dynamic) {
 | 
			
		||||
      interpreter_->ResizeInputTensorStrict(i, input_tensors[i].shape().dims);
 | 
			
		||||
      resized_tensor_shapes = true;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  // Reallocation is needed for memory sanity.
 | 
			
		||||
  if (resized_tensor_shapes) interpreter_->AllocateTensors();
 | 
			
		||||
 | 
			
		||||
  for (int i = 0; i < input_tensors.size(); ++i) {
 | 
			
		||||
    const TfLiteType input_tensor_type =
 | 
			
		||||
        interpreter_->tensor(interpreter_->inputs()[i])->type;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,7 +20,6 @@
 | 
			
		|||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/api2/node.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +160,7 @@ absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
 | 
			
		|||
  // not found in the tokenizer vocab.
 | 
			
		||||
  std::vector<Tensor> result;
 | 
			
		||||
  result.push_back(
 | 
			
		||||
      {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
 | 
			
		||||
      {Tensor::ElementType::kInt32, Tensor::Shape({1, max_seq_len_})});
 | 
			
		||||
  std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
 | 
			
		||||
              input_tokens.data(), input_tokens.size() * sizeof(int32_t));
 | 
			
		||||
  kTensorsOut(cc).Send(std::move(result));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1077,6 +1077,7 @@ cc_test(
 | 
			
		|||
    linkstatic = 1,
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":tensor_to_image_frame_calculator",
 | 
			
		||||
        ":tensor_to_image_frame_calculator_cc_proto",
 | 
			
		||||
        "//mediapipe/framework:calculator_framework",
 | 
			
		||||
        "//mediapipe/framework:calculator_runner",
 | 
			
		||||
        "//mediapipe/framework/formats:image_frame",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -65,6 +65,7 @@ class TensorToImageFrameCalculator : public CalculatorBase {
 | 
			
		|||
 | 
			
		||||
 private:
 | 
			
		||||
  float scale_factor_;
 | 
			
		||||
  bool scale_per_frame_min_max_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
REGISTER_CALCULATOR(TensorToImageFrameCalculator);
 | 
			
		||||
| 
						 | 
				
			
			@ -88,6 +89,8 @@ absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) {
 | 
			
		|||
absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) {
 | 
			
		||||
  scale_factor_ =
 | 
			
		||||
      cc->Options<TensorToImageFrameCalculatorOptions>().scale_factor();
 | 
			
		||||
  scale_per_frame_min_max_ = cc->Options<TensorToImageFrameCalculatorOptions>()
 | 
			
		||||
                                 .scale_per_frame_min_max();
 | 
			
		||||
  cc->SetOffset(TimestampDiff(0));
 | 
			
		||||
  return absl::OkStatus();
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -109,16 +112,38 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
 | 
			
		|||
  auto format = (depth == 3 ? ImageFormat::SRGB : ImageFormat::GRAY8);
 | 
			
		||||
  const int32_t total_size = height * width * depth;
 | 
			
		||||
 | 
			
		||||
  if (scale_per_frame_min_max_) {
 | 
			
		||||
    RET_CHECK_EQ(input_tensor.dtype(), tensorflow::DT_FLOAT)
 | 
			
		||||
        << "Setting scale_per_frame_min_max requires FLOAT input tensors.";
 | 
			
		||||
  }
 | 
			
		||||
  ::std::unique_ptr<const ImageFrame> output;
 | 
			
		||||
  if (input_tensor.dtype() == tensorflow::DT_FLOAT) {
 | 
			
		||||
    // Allocate buffer with alignments.
 | 
			
		||||
    std::unique_ptr<uint8_t[]> buffer(
 | 
			
		||||
        new (std::align_val_t(EIGEN_MAX_ALIGN_BYTES)) uint8_t[total_size]);
 | 
			
		||||
    auto data = input_tensor.flat<float>().data();
 | 
			
		||||
    float min = 1e23;
 | 
			
		||||
    float max = -1e23;
 | 
			
		||||
    if (scale_per_frame_min_max_) {
 | 
			
		||||
      for (int i = 0; i < total_size; ++i) {
 | 
			
		||||
        float d = scale_factor_ * data[i];
 | 
			
		||||
        if (d < min) {
 | 
			
		||||
          min = d;
 | 
			
		||||
        }
 | 
			
		||||
        if (d > max) {
 | 
			
		||||
          max = d;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    for (int i = 0; i < total_size; ++i) {
 | 
			
		||||
      float d = data[i];
 | 
			
		||||
      if (scale_per_frame_min_max_) {
 | 
			
		||||
        d = 255 * (d - min) / (max - min + 1e-9);
 | 
			
		||||
      } else {
 | 
			
		||||
        d = scale_factor_ * d;
 | 
			
		||||
        if (d < 0) d = 0;
 | 
			
		||||
        if (d > 255) d = 255;
 | 
			
		||||
      }
 | 
			
		||||
      buffer[i] = d;
 | 
			
		||||
    }
 | 
			
		||||
    output = ::absl::make_unique<ImageFrame>(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,4 +26,8 @@ message TensorToImageFrameCalculatorOptions {
 | 
			
		|||
  // Multiples floating point tensor outputs by this value before converting to
 | 
			
		||||
  // uint8. This is useful for converting from range [0, 1] to [0, 255]
 | 
			
		||||
  optional float scale_factor = 1 [default = 1.0];
 | 
			
		||||
 | 
			
		||||
  // If true, scales any FLOAT tensor input of [min, max] to be between [0, 255]
 | 
			
		||||
  // per frame. This overrides any explicit scale_factor.
 | 
			
		||||
  optional bool scale_per_frame_min_max = 2 [default = false];
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,7 +11,9 @@
 | 
			
		|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
#include <type_traits>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.pb.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_runner.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image_frame.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -32,11 +34,14 @@ constexpr char kImage[] = "IMAGE";
 | 
			
		|||
template <class TypeParam>
 | 
			
		||||
class TensorToImageFrameCalculatorTest : public ::testing::Test {
 | 
			
		||||
 protected:
 | 
			
		||||
  void SetUpRunner() {
 | 
			
		||||
  void SetUpRunner(bool scale_per_frame_min_max = false) {
 | 
			
		||||
    CalculatorGraphConfig::Node config;
 | 
			
		||||
    config.set_calculator("TensorToImageFrameCalculator");
 | 
			
		||||
    config.add_input_stream("TENSOR:input_tensor");
 | 
			
		||||
    config.add_output_stream("IMAGE:output_image");
 | 
			
		||||
    config.mutable_options()
 | 
			
		||||
        ->MutableExtension(mediapipe::TensorToImageFrameCalculatorOptions::ext)
 | 
			
		||||
        ->set_scale_per_frame_min_max(scale_per_frame_min_max);
 | 
			
		||||
    runner_ = absl::make_unique<CalculatorRunner>(config);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -157,4 +162,47 @@ TYPED_TEST(TensorToImageFrameCalculatorTest,
 | 
			
		|||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TYPED_TEST(TensorToImageFrameCalculatorTest,
 | 
			
		||||
           Converts3DTensorToImageFrame2DGrayWithScaling) {
 | 
			
		||||
  this->SetUpRunner(true);
 | 
			
		||||
  auto& runner = this->runner_;
 | 
			
		||||
  constexpr int kWidth = 16;
 | 
			
		||||
  constexpr int kHeight = 8;
 | 
			
		||||
  const tf::TensorShape tensor_shape{kHeight, kWidth};
 | 
			
		||||
  auto tensor = absl::make_unique<tf::Tensor>(
 | 
			
		||||
      tf::DataTypeToEnum<TypeParam>::v(), tensor_shape);
 | 
			
		||||
  auto tensor_vec = tensor->template flat<TypeParam>().data();
 | 
			
		||||
 | 
			
		||||
  // Writing sequence of integers as floats which we want normalized.
 | 
			
		||||
  tensor_vec[0] = 255;
 | 
			
		||||
  for (int i = 1; i < kWidth * kHeight; ++i) {
 | 
			
		||||
    tensor_vec[i] = 200;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const int64_t time = 1234;
 | 
			
		||||
  runner->MutableInputs()->Tag(kTensor).packets.push_back(
 | 
			
		||||
      Adopt(tensor.release()).At(Timestamp(time)));
 | 
			
		||||
 | 
			
		||||
  if (!std::is_same<TypeParam, float>::value) {
 | 
			
		||||
    EXPECT_FALSE(runner->Run().ok());
 | 
			
		||||
    return;  // Short circuit because does not apply to other types.
 | 
			
		||||
  } else {
 | 
			
		||||
    EXPECT_TRUE(runner->Run().ok());
 | 
			
		||||
    const std::vector<Packet>& output_packets =
 | 
			
		||||
        runner->Outputs().Tag(kImage).packets;
 | 
			
		||||
    EXPECT_EQ(1, output_packets.size());
 | 
			
		||||
    EXPECT_EQ(time, output_packets[0].Timestamp().Value());
 | 
			
		||||
    const ImageFrame& output_image = output_packets[0].Get<ImageFrame>();
 | 
			
		||||
    EXPECT_EQ(ImageFormat::GRAY8, output_image.Format());
 | 
			
		||||
    EXPECT_EQ(kWidth, output_image.Width());
 | 
			
		||||
    EXPECT_EQ(kHeight, output_image.Height());
 | 
			
		||||
 | 
			
		||||
    EXPECT_EQ(255, output_image.PixelData()[0]);
 | 
			
		||||
    for (int i = 1; i < kWidth * kHeight; ++i) {
 | 
			
		||||
      const uint8_t pixel_value = output_image.PixelData()[i];
 | 
			
		||||
      ASSERT_EQ(0, pixel_value);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1355,6 +1355,22 @@ cc_test(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "calculator_graph_summary_packet_test",
 | 
			
		||||
    srcs = ["calculator_graph_summary_packet_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":calculator_framework",
 | 
			
		||||
        ":packet",
 | 
			
		||||
        "//mediapipe/framework/api2:node",
 | 
			
		||||
        "//mediapipe/framework/api2:packet",
 | 
			
		||||
        "//mediapipe/framework/api2:port",
 | 
			
		||||
        "//mediapipe/framework/port:gtest_main",
 | 
			
		||||
        "//mediapipe/framework/port:parse_text_proto",
 | 
			
		||||
        "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
 | 
			
		||||
        "//mediapipe/framework/tool:sink",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "calculator_runner_test",
 | 
			
		||||
    size = "medium",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -109,9 +109,20 @@ class CalculatorContext {
 | 
			
		|||
  // use OutputStream::SetOffset() directly.
 | 
			
		||||
  void SetOffset(TimestampDiff offset);
 | 
			
		||||
 | 
			
		||||
  // Returns the status of the graph run.
 | 
			
		||||
  // DEPRECATED: This was intended to get graph run status during
 | 
			
		||||
  // `CalculatorBase::Close` call. However, `Close` can run simultaneously with
 | 
			
		||||
  // other calculators `CalculatorBase::Process`, hence the actual graph
 | 
			
		||||
  // status may change any time and returned graph status here does not
 | 
			
		||||
  // necessarily reflect the actual graph status.
 | 
			
		||||
  //
 | 
			
		||||
  // NOTE: This method should only be called during CalculatorBase::Close().
 | 
			
		||||
  // As an alternative, instead of checking graph status in `Close` and doing
 | 
			
		||||
  // work for "done" state, you can enable timestamp bound processing for your
 | 
			
		||||
  // calculator (`CalculatorContract::SetProcessTimestampBounds`) to trigger
 | 
			
		||||
  // `Process` on timestamp bound updates and handle "done" state there.
 | 
			
		||||
  // Check examples in:
 | 
			
		||||
  // mediapipe/framework/calculator_graph_summary_packet_test.cc.
 | 
			
		||||
  //
 | 
			
		||||
  ABSL_DEPRECATED("Does not reflect the actual graph status.")
 | 
			
		||||
  absl::Status GraphStatus() const { return graph_status_; }
 | 
			
		||||
 | 
			
		||||
  ProfilingContext* GetProfilingContext() const {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										327
									
								
								mediapipe/framework/calculator_graph_summary_packet_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										327
									
								
								mediapipe/framework/calculator_graph_summary_packet_test.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,327 @@
 | 
			
		|||
#include "mediapipe/framework/api2/node.h"
 | 
			
		||||
#include "mediapipe/framework/api2/packet.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/packet.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_matchers.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::api2::Input;
 | 
			
		||||
using ::mediapipe::api2::Node;
 | 
			
		||||
using ::mediapipe::api2::Output;
 | 
			
		||||
using ::testing::ElementsAre;
 | 
			
		||||
using ::testing::Eq;
 | 
			
		||||
using ::testing::IsEmpty;
 | 
			
		||||
using ::testing::Value;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
MATCHER_P2(IntPacket, value, timestamp, "") {
 | 
			
		||||
  *result_listener << "where object is (value: " << arg.template Get<int>()
 | 
			
		||||
                   << ", timestamp: " << arg.Timestamp() << ")";
 | 
			
		||||
  return Value(arg.template Get<int>(), Eq(value)) &&
 | 
			
		||||
         Value(arg.Timestamp(), Eq(timestamp));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Calculates and produces sum of all passed inputs when no more packets can be
 | 
			
		||||
// expected on the input stream.
 | 
			
		||||
class SummaryPacketCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Input<int> kIn{"IN"};
 | 
			
		||||
  static constexpr Output<int> kOut{"SUMMARY"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
 | 
			
		||||
 | 
			
		||||
  static absl::Status UpdateContract(CalculatorContract* cc) {
 | 
			
		||||
    // Makes sure there are no automatic timestamp bound updates when Process
 | 
			
		||||
    // is called.
 | 
			
		||||
    cc->SetTimestampOffset(TimestampDiff::Unset());
 | 
			
		||||
    // Currently, only ImmediateInputStreamHandler supports "done" timestamp
 | 
			
		||||
    // bound update. (ImmediateInputStreamhandler handles multiple input
 | 
			
		||||
    // streams differently, so, in that case, calculator adjustments may be
 | 
			
		||||
    // required.)
 | 
			
		||||
    // TODO: update all input stream handlers to support "done"
 | 
			
		||||
    // timestamp bound update.
 | 
			
		||||
    cc->SetInputStreamHandler("ImmediateInputStreamHandler");
 | 
			
		||||
    // Enables processing timestamp bound updates. For this use case we are
 | 
			
		||||
    // specifically interested in "done" timestamp bound update. (E.g. when
 | 
			
		||||
    // all input packet sources are closed.)
 | 
			
		||||
    cc->SetProcessTimestampBounds(true);
 | 
			
		||||
    return absl::OkStatus();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final {
 | 
			
		||||
    if (!kIn(cc).IsEmpty()) {
 | 
			
		||||
      value_ += kIn(cc).Get();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (kOut(cc).IsClosed()) {
 | 
			
		||||
      // This can happen:
 | 
			
		||||
      // 1. If, during previous invocation, kIn(cc).IsDone() == true (e.g.
 | 
			
		||||
      //    source calculator finished generating packets sent to kIn) and
 | 
			
		||||
      //    HasNextAllowedInStream() == true (which is an often case).
 | 
			
		||||
      // 2. For Timestamp::PreStream, ImmediateInputStreamHandler will still
 | 
			
		||||
      //    invoke Process() with Timestamp::Max to indicate "Done" timestamp
 | 
			
		||||
      //    bound update.
 | 
			
		||||
      return absl::OkStatus();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // TODO: input stream holding a packet with timestamp that has
 | 
			
		||||
    // no next timestamp allowed in stream should always result in
 | 
			
		||||
    // InputStream::IsDone() == true.
 | 
			
		||||
    if (kIn(cc).IsDone() || !cc->InputTimestamp().HasNextAllowedInStream()) {
 | 
			
		||||
      // kOut(cc).Send(value_) can be used here as well, however in the case of
 | 
			
		||||
      // source calculator sending inputs into kIn the resulting timestamp is
 | 
			
		||||
      // not well defined (e.g. it can be the last packet timestamp or
 | 
			
		||||
      // Timestamp::Max())
 | 
			
		||||
      // TODO: last packet from source should always result in
 | 
			
		||||
      // InputStream::IsDone() == true.
 | 
			
		||||
      kOut(cc).Send(value_, Timestamp::Max());
 | 
			
		||||
      kOut(cc).Close();
 | 
			
		||||
    }
 | 
			
		||||
    return absl::OkStatus();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  int value_ = 0;
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(SummaryPacketCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnClosingAllPacketSources) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: 'input'
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: 'IN:input'
 | 
			
		||||
      output_stream: 'SUMMARY:output'
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp(10));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  send_packet(20, Timestamp(11));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest, ProducesSummaryPacketOnMaxTimestamp) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: 'input'
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: 'IN:input'
 | 
			
		||||
      output_stream: 'SUMMARY:output'
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp(10));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  send_packet(20, Timestamp::Max());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnPreStreamTimestamp) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: 'input'
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: 'IN:input'
 | 
			
		||||
      output_stream: 'SUMMARY:output'
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp::PreStream());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnPostStreamTimestamp) {
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  CalculatorGraphConfig graph_config =
 | 
			
		||||
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
        input_stream: 'input'
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "SummaryPacketCalculator"
 | 
			
		||||
          input_stream: 'IN:input'
 | 
			
		||||
          output_stream: 'SUMMARY:output'
 | 
			
		||||
        }
 | 
			
		||||
      )pb");
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  auto send_packet = [&graph](int value, Timestamp timestamp) {
 | 
			
		||||
    MP_ASSERT_OK(graph.AddPacketToInputStream(
 | 
			
		||||
        "input", MakePacket<int>(value).At(timestamp)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  send_packet(10, Timestamp::PostStream());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(10, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class IntGeneratorCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Output<int> kOut{"INT"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kOut);
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final {
 | 
			
		||||
    kOut(cc).Send(20, Timestamp(0));
 | 
			
		||||
    kOut(cc).Send(10, Timestamp(1000));
 | 
			
		||||
    return tool::StatusStop();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(IntGeneratorCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnSourceCalculatorCompletion) {
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  CalculatorGraphConfig graph_config =
 | 
			
		||||
      ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "IntGeneratorCalculator"
 | 
			
		||||
          output_stream: "INT:int_value"
 | 
			
		||||
        }
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "SummaryPacketCalculator"
 | 
			
		||||
          input_stream: "IN:int_value"
 | 
			
		||||
          output_stream: "SUMMARY:output"
 | 
			
		||||
        }
 | 
			
		||||
      )pb");
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_EXPECT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class EmitOnCloseCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Input<int> kIn{"IN"};
 | 
			
		||||
  static constexpr Output<int> kOut{"INT"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
 | 
			
		||||
 | 
			
		||||
  absl::Status Close(CalculatorContext* cc) final {
 | 
			
		||||
    kOut(cc).Send(20, Timestamp(0));
 | 
			
		||||
    kOut(cc).Send(10, Timestamp(1000));
 | 
			
		||||
    return absl::OkStatus();
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(EmitOnCloseCalculator);
 | 
			
		||||
 | 
			
		||||
TEST(SummaryPacketCalculatorUseCaseTest,
 | 
			
		||||
     ProducesSummaryPacketOnAnotherCalculatorClosure) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
 | 
			
		||||
    input_stream: "input"
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "EmitOnCloseCalculator"
 | 
			
		||||
      input_stream: "IN:input"
 | 
			
		||||
      output_stream: "INT:int_value"
 | 
			
		||||
    }
 | 
			
		||||
    node {
 | 
			
		||||
      calculator: "SummaryPacketCalculator"
 | 
			
		||||
      input_stream: "IN:int_value"
 | 
			
		||||
      output_stream: "SUMMARY:output"
 | 
			
		||||
    }
 | 
			
		||||
  )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("output", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_ASSERT_OK(graph.Initialize(graph_config, {}));
 | 
			
		||||
  MP_ASSERT_OK(graph.StartRun({}));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseInputStream("input"));
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilIdle());
 | 
			
		||||
  EXPECT_THAT(output_packets, ElementsAre(IntPacket(30, Timestamp::Max())));
 | 
			
		||||
 | 
			
		||||
  output_packets.clear();
 | 
			
		||||
  MP_ASSERT_OK(graph.CloseAllPacketSources());
 | 
			
		||||
  MP_ASSERT_OK(graph.WaitUntilDone());
 | 
			
		||||
  EXPECT_THAT(output_packets, IsEmpty());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			@ -117,11 +117,18 @@ class Tensor {
 | 
			
		|||
    Shape() = default;
 | 
			
		||||
    Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
 | 
			
		||||
    Shape(const std::vector<int>& dimensions) : dims(dimensions) {}
 | 
			
		||||
    Shape(std::initializer_list<int> dimensions, bool is_dynamic)
 | 
			
		||||
        : dims(dimensions), is_dynamic(is_dynamic) {}
 | 
			
		||||
    Shape(const std::vector<int>& dimensions, bool is_dynamic)
 | 
			
		||||
        : dims(dimensions), is_dynamic(is_dynamic) {}
 | 
			
		||||
    int num_elements() const {
 | 
			
		||||
      return std::accumulate(dims.begin(), dims.end(), 1,
 | 
			
		||||
                             std::multiplies<int>());
 | 
			
		||||
    }
 | 
			
		||||
    std::vector<int> dims;
 | 
			
		||||
    // The Tensor has dynamic rather than static shape so the TFLite interpreter
 | 
			
		||||
    // needs to be reallocated. Only relevant for CPU.
 | 
			
		||||
    bool is_dynamic = false;
 | 
			
		||||
  };
 | 
			
		||||
  // Quantization parameters corresponding to the zero_point and scale value
 | 
			
		||||
  // made available by TfLite quantized (uint8/int8) tensors.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,7 @@
 | 
			
		|||
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +35,17 @@ TEST(General, TestDataTypes) {
 | 
			
		|||
  EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(General, TestDynamic) {
 | 
			
		||||
  Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape({1, 2, 3, 4}, true));
 | 
			
		||||
  EXPECT_EQ(t1.shape().num_elements(), 1 * 2 * 3 * 4);
 | 
			
		||||
  EXPECT_TRUE(t1.shape().is_dynamic);
 | 
			
		||||
 | 
			
		||||
  std::vector<int> t2_dims = {4, 3, 2, 3};
 | 
			
		||||
  Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape(t2_dims, true));
 | 
			
		||||
  EXPECT_EQ(t2.shape().num_elements(), 4 * 3 * 2 * 3);
 | 
			
		||||
  EXPECT_TRUE(t2.shape().is_dynamic);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(Cpu, TestMemoryAllocation) {
 | 
			
		||||
  Tensor t1(Tensor::ElementType::kFloat32, Tensor::Shape{4, 3, 2, 3});
 | 
			
		||||
  auto v1 = t1.GetCpuWriteView();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -131,6 +131,13 @@ Timestamp Timestamp::NextAllowedInStream() const {
 | 
			
		|||
  return *this + 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Timestamp::HasNextAllowedInStream() const {
 | 
			
		||||
  if (*this >= Max() || *this == PreStream()) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Timestamp Timestamp::PreviousAllowedInStream() const {
 | 
			
		||||
  if (*this <= Min() || *this == PostStream()) {
 | 
			
		||||
    // Indicates that no previous timestamps may occur.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -186,6 +186,10 @@ class Timestamp {
 | 
			
		|||
  // CHECKs that this->IsAllowedInStream().
 | 
			
		||||
  Timestamp NextAllowedInStream() const;
 | 
			
		||||
 | 
			
		||||
  // Returns true if there's a next timestamp in the range [Min .. Max] after
 | 
			
		||||
  // this one.
 | 
			
		||||
  bool HasNextAllowedInStream() const;
 | 
			
		||||
 | 
			
		||||
  // Returns the previous timestamp in the range [Min .. Max], or
 | 
			
		||||
  // Unstarted() if no Packets may preceed one with this timestamp.
 | 
			
		||||
  Timestamp PreviousAllowedInStream() const;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -125,6 +125,22 @@ TEST(TimestampTest, NextAllowedInStream) {
 | 
			
		|||
            Timestamp::PostStream().NextAllowedInStream());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TimestampTest, HasNextAllowedInStream) {
 | 
			
		||||
  EXPECT_TRUE(Timestamp::Min().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE((Timestamp::Min() + 1).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE(Timestamp(-1000).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE(Timestamp(0).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE(Timestamp(1000).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE((Timestamp::Max() - 2).HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_TRUE((Timestamp::Max() - 1).HasNextAllowedInStream());
 | 
			
		||||
 | 
			
		||||
  EXPECT_FALSE(Timestamp::PreStream().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::Max().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::PostStream().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::OneOverPostStream().HasNextAllowedInStream());
 | 
			
		||||
  EXPECT_FALSE(Timestamp::Done().HasNextAllowedInStream());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TimestampTest, SpecialValueDifferences) {
 | 
			
		||||
  {  // Lower range
 | 
			
		||||
    const std::vector<Timestamp> timestamps = {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -34,7 +34,7 @@ objc_library(
 | 
			
		|||
        "-x objective-c++",
 | 
			
		||||
    ],
 | 
			
		||||
    data = [
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:test_models",
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:hand_landmarker.task",
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:test_images",
 | 
			
		||||
        "//mediapipe/tasks/testdata/vision:test_protos",
 | 
			
		||||
    ],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,7 +37,6 @@ static ResourceFileInfo *const kExpectedThumbUpLandmarksFile =
 | 
			
		|||
static ResourceFileInfo *const kExpectedPointingUpRotatedLandmarksFile =
 | 
			
		||||
    @{@"name" : @"pointing_up_rotated_landmarks", @"type" : kPbFileExtension};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
 | 
			
		||||
static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -77,11 +76,9 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)thumbUpHandLandmarkerResult {
 | 
			
		||||
  NSString *filePath =
 | 
			
		||||
      [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
 | 
			
		||||
  NSString *filePath = [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedThumbUpLandmarksFile];
 | 
			
		||||
 | 
			
		||||
  return [MPPHandLandmarkerResult
 | 
			
		||||
      handLandmarkerResultFromTextEncodedProtobufFileWithName:filePath
 | 
			
		||||
  return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
 | 
			
		||||
                                                         shouldRemoveZPosition:YES];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -89,8 +86,7 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
  NSString *filePath =
 | 
			
		||||
      [MPPHandLandmarkerTests filePathWithFileInfo:kExpectedPointingUpRotatedLandmarksFile];
 | 
			
		||||
 | 
			
		||||
  return [MPPHandLandmarkerResult
 | 
			
		||||
      handLandmarkerResultFromTextEncodedProtobufFileWithName:filePath
 | 
			
		||||
  return [MPPHandLandmarkerResult handLandmarkerResultFromProtobufFileWithName:filePath
 | 
			
		||||
                                                         shouldRemoveZPosition:YES];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -133,8 +129,7 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertHandLandmarkerResult:(MPPHandLandmarkerResult *)handLandmarkerResult
 | 
			
		||||
    isApproximatelyEqualToExpectedResult:
 | 
			
		||||
        (MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
 | 
			
		||||
    isApproximatelyEqualToExpectedResult:(MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
 | 
			
		||||
  [self assertMultiHandLandmarks:handLandmarkerResult.landmarks
 | 
			
		||||
      areApproximatelyEqualToExpectedMultiHandLandmarks:expectedHandLandmarkerResult.landmarks];
 | 
			
		||||
  [self assertMultiHandWorldLandmarks:handLandmarkerResult.worldLandmarks
 | 
			
		||||
| 
						 | 
				
			
			@ -161,8 +156,7 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
- (MPPHandLandmarkerOptions *)handLandmarkerOptionsWithModelFileInfo:
 | 
			
		||||
    (ResourceFileInfo *)modelFileInfo {
 | 
			
		||||
  NSString *modelPath = [MPPHandLandmarkerTests filePathWithFileInfo:modelFileInfo];
 | 
			
		||||
  MPPHandLandmarkerOptions *handLandmarkerOptions =
 | 
			
		||||
      [[MPPHandLandmarkerOptions alloc] init];
 | 
			
		||||
  MPPHandLandmarkerOptions *handLandmarkerOptions = [[MPPHandLandmarkerOptions alloc] init];
 | 
			
		||||
  handLandmarkerOptions.baseOptions.modelAssetPath = modelPath;
 | 
			
		||||
 | 
			
		||||
  return handLandmarkerOptions;
 | 
			
		||||
| 
						 | 
				
			
			@ -170,21 +164,22 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
 | 
			
		||||
- (MPPHandLandmarker *)createHandLandmarkerWithOptionsSucceeds:
 | 
			
		||||
    (MPPHandLandmarkerOptions *)handLandmarkerOptions {
 | 
			
		||||
  NSError* error;
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:nil];
 | 
			
		||||
      [[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
 | 
			
		||||
  XCTAssertNotNil(handLandmarker);
 | 
			
		||||
  XCTAssertNil(error);
 | 
			
		||||
 | 
			
		||||
  return handLandmarker;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
- (void)assertCreateHandLandmarkerWithOptions:
 | 
			
		||||
            (MPPHandLandmarkerOptions *)handLandmarkerOptions
 | 
			
		||||
- (void)assertCreateHandLandmarkerWithOptions:(MPPHandLandmarkerOptions *)handLandmarkerOptions
 | 
			
		||||
                       failsWithExpectedError:(NSError *)expectedError {
 | 
			
		||||
  NSError *error = nil;
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [[MPPHandLandmarker alloc] initWithOptions:handLandmarkerOptions error:&error];
 | 
			
		||||
 | 
			
		||||
  XCTAssertNil(handLandmarkerOptions);
 | 
			
		||||
  XCTAssertNil(handLandmarker);
 | 
			
		||||
  AssertEqualErrors(error, expectedError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -211,11 +206,9 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
- (MPPHandLandmarkerResult *)detectInImageWithFileInfo:(ResourceFileInfo *)imageFileInfo
 | 
			
		||||
                                    usingHandLandmarker:
 | 
			
		||||
                                        (MPPHandLandmarker *)handLandmarker {
 | 
			
		||||
                                   usingHandLandmarker:(MPPHandLandmarker *)handLandmarker {
 | 
			
		||||
  MPPImage *mppImage = [self imageWithFileInfo:imageFileInfo];
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage
 | 
			
		||||
                                                                                    error:nil];
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
 | 
			
		||||
  XCTAssertNotNil(handLandmarkerResult);
 | 
			
		||||
 | 
			
		||||
  return handLandmarkerResult;
 | 
			
		||||
| 
						 | 
				
			
			@ -225,8 +218,8 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
                             usingHandLandmarker:(MPPHandLandmarker *)handLandmarker
 | 
			
		||||
         approximatelyEqualsHandLandmarkerResult:
 | 
			
		||||
             (MPPHandLandmarkerResult *)expectedHandLandmarkerResult {
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult =
 | 
			
		||||
      [self detectInImageWithFileInfo:fileInfo usingHandLandmarker:handLandmarker];
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:fileInfo
 | 
			
		||||
                                                              usingHandLandmarker:handLandmarker];
 | 
			
		||||
  [self assertHandLandmarkerResult:handLandmarkerResult
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:expectedHandLandmarkerResult];
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -236,8 +229,8 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
- (void)testDetectWithModelPathSucceeds {
 | 
			
		||||
  NSString *modelPath =
 | 
			
		||||
      [MPPHandLandmarkerTests filePathWithFileInfo:kHandLandmarkerBundleAssetFile];
 | 
			
		||||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [[MPPHandLandmarker alloc] initWithModelPath:modelPath error:nil];
 | 
			
		||||
  MPPHandLandmarker *handLandmarker = [[MPPHandLandmarker alloc] initWithModelPath:modelPath
 | 
			
		||||
                                                                             error:nil];
 | 
			
		||||
  XCTAssertNotNil(handLandmarker);
 | 
			
		||||
 | 
			
		||||
  [self assertResultsOfDetectInImageWithFileInfo:kThumbUpImage
 | 
			
		||||
| 
						 | 
				
			
			@ -253,8 +246,8 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult =
 | 
			
		||||
      [self detectInImageWithFileInfo:kNoHandsImage usingHandLandmarker:handLandmarker];
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kNoHandsImage
 | 
			
		||||
                                                              usingHandLandmarker:handLandmarker];
 | 
			
		||||
  AssertHandLandmarkerResultIsEmpty(handLandmarkerResult);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -268,8 +261,8 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
  MPPHandLandmarker *handLandmarker =
 | 
			
		||||
      [self createHandLandmarkerWithOptionsSucceeds:handLandmarkerOptions];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult =
 | 
			
		||||
      [self detectInImageWithFileInfo:kTwoHandsImage usingHandLandmarker:handLandmarker];
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [self detectInImageWithFileInfo:kTwoHandsImage
 | 
			
		||||
                                                              usingHandLandmarker:handLandmarker];
 | 
			
		||||
 | 
			
		||||
  XCTAssertTrue(handLandmarkerResult.handedness.count == numHands);
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -284,12 +277,11 @@ static const float kLandmarksErrorTolerance = 0.03f;
 | 
			
		|||
  MPPImage *mppImage = [self imageWithFileInfo:kPointingUpRotatedImage
 | 
			
		||||
                                   orientation:UIImageOrientationRight];
 | 
			
		||||
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage
 | 
			
		||||
                                                                          error:nil];
 | 
			
		||||
  MPPHandLandmarkerResult *handLandmarkerResult = [handLandmarker detectInImage:mppImage error:nil];
 | 
			
		||||
 | 
			
		||||
  [self assertHandLandmarkerResult:handLandmarkerResult
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests pointingUpRotatedHandLandmarkerResult]];                                                                        
 | 
			
		||||
 | 
			
		||||
      isApproximatelyEqualToExpectedResult:[MPPHandLandmarkerTests
 | 
			
		||||
                                               pointingUpRotatedHandLandmarkerResult]];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#pragma mark Running Mode Tests
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,10 +12,11 @@ objc_library(
 | 
			
		|||
        "-x objective-c++",
 | 
			
		||||
    ],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
 | 
			
		||||
        "//mediapipe/framework/formats:classification_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
 | 
			
		||||
        "//mediapipe/tasks/ios/common/utils:NSStringHelpers",
 | 
			
		||||
        "//mediapipe/tasks/ios/test/vision/utils:parse_proto_utils",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/hand_landmarker:MPPHandLandmarkerResult",
 | 
			
		||||
        "//mediapipe/tasks/ios/vision/hand_landmarker/utils:MPPHandLandmarkerResultHelpers",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,8 +18,7 @@
 | 
			
		|||
NS_ASSUME_NONNULL_BEGIN
 | 
			
		||||
@interface MPPHandLandmarkerResult (ProtobufHelpers)
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)
 | 
			
		||||
    handLandmarkerResultFromTextEncodedProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
                                                    shouldRemoveZPosition:(BOOL)removeZPosition;
 | 
			
		||||
 | 
			
		||||
@end
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,8 +31,7 @@ using ::mediapipe::tasks::ios::test::vision::utils::get_proto_from_pbtxt;
 | 
			
		|||
 | 
			
		||||
@implementation MPPHandLandmarkerResult (ProtobufHelpers)
 | 
			
		||||
 | 
			
		||||
+ (MPPHandLandmarkerResult *)
 | 
			
		||||
    handLandmarkerResultFromTextEncodedProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
+ (MPPHandLandmarkerResult *)handLandmarkerResultFromProtobufFileWithName:(NSString *)fileName
 | 
			
		||||
                                                    shouldRemoveZPosition:(BOOL)removeZPosition {
 | 
			
		||||
  LandmarksDetectionResultProto landmarkDetectionResultProto;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,12 +41,12 @@ PYBIND11_MODULE(_pywrap_flatbuffers, m) {
 | 
			
		|||
        self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
 | 
			
		||||
                             contents.length());
 | 
			
		||||
      });
 | 
			
		||||
  m.def("generate_text_file", &flatbuffers::GenerateTextFile);
 | 
			
		||||
  m.def("generate_text_file", &flatbuffers::GenTextFile);
 | 
			
		||||
  m.def("generate_text",
 | 
			
		||||
        [](const flatbuffers::Parser& parser,
 | 
			
		||||
           const std::string& buffer) -> std::string {
 | 
			
		||||
          std::string text;
 | 
			
		||||
          const char* result = flatbuffers::GenerateText(
 | 
			
		||||
          const char* result = flatbuffers::GenText(
 | 
			
		||||
              parser, reinterpret_cast<const void*>(buffer.c_str()), &text);
 | 
			
		||||
          if (result) {
 | 
			
		||||
            return "";
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -38,7 +38,7 @@ mediapipe_files(srcs = [
 | 
			
		|||
])
 | 
			
		||||
 | 
			
		||||
rollup_bundle(
 | 
			
		||||
    name = "audio_bundle",
 | 
			
		||||
    name = "audio_bundle_mjs",
 | 
			
		||||
    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
 | 
			
		||||
    entry_point = "index.ts",
 | 
			
		||||
    format = "esm",
 | 
			
		||||
| 
						 | 
				
			
			@ -69,6 +69,29 @@ rollup_bundle(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
genrule(
 | 
			
		||||
    name = "audio_sources",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        ":audio_bundle_cjs",
 | 
			
		||||
        ":audio_bundle_mjs",
 | 
			
		||||
    ],
 | 
			
		||||
    outs = [
 | 
			
		||||
        "audio_bundle.cjs",
 | 
			
		||||
        "audio_bundle.cjs.map",
 | 
			
		||||
        "audio_bundle.mjs",
 | 
			
		||||
        "audio_bundle.mjs.map",
 | 
			
		||||
    ],
 | 
			
		||||
    cmd = (
 | 
			
		||||
        "for FILE in $(SRCS); do " +
 | 
			
		||||
        "  OUT_FILE=$(GENDIR)/mediapipe/tasks/web/audio/$$(" +
 | 
			
		||||
        "      basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
 | 
			
		||||
        "  ); " +
 | 
			
		||||
        "  echo $$FILE ; echo $$OUT_FILE ; " +
 | 
			
		||||
        "  cp $$FILE $$OUT_FILE ; " +
 | 
			
		||||
        "done;"
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
genrule(
 | 
			
		||||
    name = "package_json",
 | 
			
		||||
    srcs = ["//mediapipe/tasks/web:package.json"],
 | 
			
		||||
| 
						 | 
				
			
			@ -91,8 +114,7 @@ pkg_npm(
 | 
			
		|||
        "wasm/audio_wasm_internal.wasm",
 | 
			
		||||
        "wasm/audio_wasm_nosimd_internal.js",
 | 
			
		||||
        "wasm/audio_wasm_nosimd_internal.wasm",
 | 
			
		||||
        ":audio_bundle",
 | 
			
		||||
        ":audio_bundle_cjs",
 | 
			
		||||
        ":audio_sources",
 | 
			
		||||
        ":package_json",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,11 +2,12 @@
 | 
			
		|||
  "name": "@mediapipe/tasks-__NAME__",
 | 
			
		||||
  "version": "__VERSION__",
 | 
			
		||||
  "description": "__DESCRIPTION__",
 | 
			
		||||
  "main": "__NAME___bundle_cjs.js",
 | 
			
		||||
  "browser": "__NAME___bundle.js",
 | 
			
		||||
  "module": "__NAME___bundle.js",
 | 
			
		||||
  "main": "__NAME___bundle.cjs",
 | 
			
		||||
  "browser": "__NAME___bundle.mjs",
 | 
			
		||||
  "module": "__NAME___bundle.mjs",
 | 
			
		||||
  "author": "mediapipe@google.com",
 | 
			
		||||
  "license": "Apache-2.0",
 | 
			
		||||
  "type": "module",
 | 
			
		||||
  "types": "__TYPES__",
 | 
			
		||||
  "homepage": "http://mediapipe.dev",
 | 
			
		||||
  "keywords": [ "AR", "ML", "Augmented", "MediaPipe", "MediaPipe Tasks" ]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -39,7 +39,7 @@ mediapipe_ts_library(
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
rollup_bundle(
 | 
			
		||||
    name = "text_bundle",
 | 
			
		||||
    name = "text_bundle_mjs",
 | 
			
		||||
    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
 | 
			
		||||
    entry_point = "index.ts",
 | 
			
		||||
    format = "esm",
 | 
			
		||||
| 
						 | 
				
			
			@ -70,6 +70,29 @@ rollup_bundle(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
genrule(
 | 
			
		||||
    name = "text_sources",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        ":text_bundle_cjs",
 | 
			
		||||
        ":text_bundle_mjs",
 | 
			
		||||
    ],
 | 
			
		||||
    outs = [
 | 
			
		||||
        "text_bundle.cjs",
 | 
			
		||||
        "text_bundle.cjs.map",
 | 
			
		||||
        "text_bundle.mjs",
 | 
			
		||||
        "text_bundle.mjs.map",
 | 
			
		||||
    ],
 | 
			
		||||
    cmd = (
 | 
			
		||||
        "for FILE in $(SRCS); do " +
 | 
			
		||||
        "  OUT_FILE=$(GENDIR)/mediapipe/tasks/web/text/$$(" +
 | 
			
		||||
        "      basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
 | 
			
		||||
        "  ); " +
 | 
			
		||||
        "  echo $$FILE ; echo $$OUT_FILE ; " +
 | 
			
		||||
        "  cp $$FILE $$OUT_FILE ; " +
 | 
			
		||||
        "done;"
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
genrule(
 | 
			
		||||
    name = "package_json",
 | 
			
		||||
    srcs = ["//mediapipe/tasks/web:package.json"],
 | 
			
		||||
| 
						 | 
				
			
			@ -93,7 +116,6 @@ pkg_npm(
 | 
			
		|||
        "wasm/text_wasm_nosimd_internal.js",
 | 
			
		||||
        "wasm/text_wasm_nosimd_internal.wasm",
 | 
			
		||||
        ":package_json",
 | 
			
		||||
        ":text_bundle",
 | 
			
		||||
        ":text_bundle_cjs",
 | 
			
		||||
        ":text_sources",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,7 +50,7 @@ mediapipe_ts_library(
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
rollup_bundle(
 | 
			
		||||
    name = "vision_bundle",
 | 
			
		||||
    name = "vision_bundle_mjs",
 | 
			
		||||
    config_file = "//mediapipe/tasks/web:rollup.config.mjs",
 | 
			
		||||
    entry_point = "index.ts",
 | 
			
		||||
    format = "esm",
 | 
			
		||||
| 
						 | 
				
			
			@ -81,6 +81,29 @@ rollup_bundle(
 | 
			
		|||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
genrule(
 | 
			
		||||
    name = "vision_sources",
 | 
			
		||||
    srcs = [
 | 
			
		||||
        ":vision_bundle_cjs",
 | 
			
		||||
        ":vision_bundle_mjs",
 | 
			
		||||
    ],
 | 
			
		||||
    outs = [
 | 
			
		||||
        "vision_bundle.cjs",
 | 
			
		||||
        "vision_bundle.cjs.map",
 | 
			
		||||
        "vision_bundle.mjs",
 | 
			
		||||
        "vision_bundle.mjs.map",
 | 
			
		||||
    ],
 | 
			
		||||
    cmd = (
 | 
			
		||||
        "for FILE in $(SRCS); do " +
 | 
			
		||||
        "  OUT_FILE=$(GENDIR)/mediapipe/tasks/web/vision/$$(" +
 | 
			
		||||
        "      basename $$FILE | sed -E 's/_([cm])js\\.js/.\\1js/'" +
 | 
			
		||||
        "  ); " +
 | 
			
		||||
        "  echo $$FILE ; echo $$OUT_FILE ; " +
 | 
			
		||||
        "  cp $$FILE $$OUT_FILE ; " +
 | 
			
		||||
        "done;"
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
genrule(
 | 
			
		||||
    name = "package_json",
 | 
			
		||||
    srcs = ["//mediapipe/tasks/web:package.json"],
 | 
			
		||||
| 
						 | 
				
			
			@ -104,7 +127,6 @@ pkg_npm(
 | 
			
		|||
        "wasm/vision_wasm_nosimd_internal.js",
 | 
			
		||||
        "wasm/vision_wasm_nosimd_internal.wasm",
 | 
			
		||||
        ":package_json",
 | 
			
		||||
        ":vision_bundle",
 | 
			
		||||
        ":vision_bundle_cjs",
 | 
			
		||||
        ":vision_sources",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,7 +26,6 @@
 | 
			
		|||
#include <fstream>
 | 
			
		||||
 | 
			
		||||
#include "absl/algorithm/container.h"
 | 
			
		||||
#include "absl/flags/flag.h"
 | 
			
		||||
#include "absl/strings/match.h"
 | 
			
		||||
#include "absl/strings/numbers.h"
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -35,23 +34,14 @@
 | 
			
		|||
#include "mediapipe/framework/port/integral_types.h"
 | 
			
		||||
#include "mediapipe/framework/port/statusor.h"
 | 
			
		||||
 | 
			
		||||
ABSL_FLAG(std::string, system_cpu_max_freq_file,
 | 
			
		||||
          "/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq",
 | 
			
		||||
          "The file pattern for CPU max frequencies, where $0 will be replaced "
 | 
			
		||||
          "with the CPU id.");
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
constexpr uint32_t kBufferLength = 64;
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::string> GetFilePath(int cpu) {
 | 
			
		||||
  if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) {
 | 
			
		||||
    return absl::InvalidArgumentError(
 | 
			
		||||
        absl::StrCat("Invalid frequency file: ",
 | 
			
		||||
                     absl::GetFlag(FLAGS_system_cpu_max_freq_file)));
 | 
			
		||||
  }
 | 
			
		||||
  return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu);
 | 
			
		||||
  return absl::Substitute(
 | 
			
		||||
      "/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq", cpu);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<uint64_t> GetCpuMaxFrequency(int cpu) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -147,6 +147,22 @@ absl::Status ReconcileMetadataImages(const std::string& prefix,
 | 
			
		|||
  return absl::OkStatus();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reconciles metadata for all images.
 | 
			
		||||
absl::Status ReconcileMetadataImages(tensorflow::SequenceExample* sequence) {
 | 
			
		||||
  RET_CHECK_OK(ReconcileMetadataImages("", sequence));
 | 
			
		||||
  for (const auto& key_value : sequence->feature_lists().feature_list()) {
 | 
			
		||||
    const auto& key = key_value.first;
 | 
			
		||||
    if (::absl::StrContains(key, kImageTimestampKey)) {
 | 
			
		||||
      std::string prefix = "";
 | 
			
		||||
      if (key != kImageTimestampKey) {
 | 
			
		||||
        prefix = key.substr(0, key.size() - sizeof(kImageTimestampKey));
 | 
			
		||||
      }
 | 
			
		||||
      RET_CHECK_OK(ReconcileMetadataImages(prefix, sequence));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return absl::OkStatus();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Sets the values of "feature/${TAG}/dimensions", and
 | 
			
		||||
// "feature/${TAG}/frame_rate" for each float list feature TAG. If the
 | 
			
		||||
// dimensions are already present as a context feature, this method verifies
 | 
			
		||||
| 
						 | 
				
			
			@ -545,10 +561,7 @@ absl::Status ReconcileMetadata(bool reconcile_bbox_annotations,
 | 
			
		|||
                               bool reconcile_region_annotations,
 | 
			
		||||
                               tensorflow::SequenceExample* sequence) {
 | 
			
		||||
  RET_CHECK_OK(ReconcileAnnotationIndicesByImageTimestamps(sequence));
 | 
			
		||||
  RET_CHECK_OK(ReconcileMetadataImages("", sequence));
 | 
			
		||||
  RET_CHECK_OK(ReconcileMetadataImages(kForwardFlowPrefix, sequence));
 | 
			
		||||
  RET_CHECK_OK(ReconcileMetadataImages(kClassSegmentationPrefix, sequence));
 | 
			
		||||
  RET_CHECK_OK(ReconcileMetadataImages(kInstanceSegmentationPrefix, sequence));
 | 
			
		||||
  RET_CHECK_OK(ReconcileMetadataImages(sequence));
 | 
			
		||||
  RET_CHECK_OK(ReconcileMetadataFeatureFloats(sequence));
 | 
			
		||||
  if (reconcile_bbox_annotations) {
 | 
			
		||||
    RET_CHECK_OK(ReconcileMetadataBoxAnnotations("", sequence));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										8
									
								
								third_party/flatbuffers/workspace.bzl
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								third_party/flatbuffers/workspace.bzl
									
									
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
 | 
			
		|||
def repo():
 | 
			
		||||
    third_party_http_archive(
 | 
			
		||||
        name = "flatbuffers",
 | 
			
		||||
        strip_prefix = "flatbuffers-23.5.8",
 | 
			
		||||
        sha256 = "55b75dfa5b6f6173e4abf9c35284a10482ba65db886b39db511eba6c244f1e88",
 | 
			
		||||
        strip_prefix = "flatbuffers-23.5.26",
 | 
			
		||||
        sha256 = "1cce06b17cddd896b6d73cc047e36a254fb8df4d7ea18a46acf16c4c0cd3f3f3",
 | 
			
		||||
        urls = [
 | 
			
		||||
            "https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
 | 
			
		||||
            "https://github.com/google/flatbuffers/archive/v23.5.8.tar.gz",
 | 
			
		||||
            "https://github.com/google/flatbuffers/archive/v23.5.26.tar.gz",
 | 
			
		||||
            "https://github.com/google/flatbuffers/archive/v23.5.26.tar.gz",
 | 
			
		||||
        ],
 | 
			
		||||
        build_file = "//third_party/flatbuffers:BUILD.bazel",
 | 
			
		||||
        delete = ["build_defs.bzl", "BUILD.bazel"],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user