Internal change
PiperOrigin-RevId: 477585110
This commit is contained in:
		
							parent
							
								
									8c8a9cda5a
								
							
						
					
					
						commit
						dcc5587483
					
				
							
								
								
									
										74
									
								
								mediapipe/calculators/tensor/text_to_tensor_calculator.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								mediapipe/calculators/tensor/text_to_tensor_calculator.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,74 @@
 | 
			
		|||
// Copyright 2022 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "mediapipe/framework/api2/node.h"
 | 
			
		||||
#include "mediapipe/framework/api2/port.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_context.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/formats/tensor.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace api2 {
 | 
			
		||||
 | 
			
		||||
// Trivially converts an input string into a Tensor that stores a copy of
 | 
			
		||||
// the string.
 | 
			
		||||
//
 | 
			
		||||
// Inputs:
 | 
			
		||||
//   TEXT - std::string
 | 
			
		||||
//
 | 
			
		||||
// Outputs:
 | 
			
		||||
//   TENSORS - std::vector<Tensor>
 | 
			
		||||
//     Vector containing a single Tensor storing a copy of the input string.
 | 
			
		||||
//     Note that the underlying buffer of the Tensor is not necessarily
 | 
			
		||||
//     null-terminated. It is the graph writer's responsibility to copy the
 | 
			
		||||
//     correct number of characters when copying from this Tensor's buffer.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//   node {
 | 
			
		||||
//     calculator: "TextToTensorCalculator"
 | 
			
		||||
//     input_stream: "TEXT:text"
 | 
			
		||||
//     output_stream: "TENSORS:tensors"
 | 
			
		||||
//   }
 | 
			
		||||
class TextToTensorCalculator : public Node {
 | 
			
		||||
 public:
 | 
			
		||||
  static constexpr Input<std::string> kTextIn{"TEXT"};
 | 
			
		||||
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
 | 
			
		||||
 | 
			
		||||
  MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut);
 | 
			
		||||
 | 
			
		||||
  absl::Status Process(CalculatorContext* cc) override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) {
 | 
			
		||||
  absl::string_view text = kTextIn(cc).Get();
 | 
			
		||||
  int input_len = static_cast<int>(text.length());
 | 
			
		||||
 | 
			
		||||
  std::vector<Tensor> result;
 | 
			
		||||
  result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len})});
 | 
			
		||||
  std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(),
 | 
			
		||||
              input_len * sizeof(char));
 | 
			
		||||
  kTensorsOut(cc).Send(std::move(result));
 | 
			
		||||
  return absl::OkStatus();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator);
 | 
			
		||||
 | 
			
		||||
}  // namespace api2
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,88 @@
 | 
			
		|||
// Copyright 2022 The MediaPipe Authors.
 | 
			
		||||
//
 | 
			
		||||
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
// you may not use this file except in compliance with the License.
 | 
			
		||||
// You may obtain a copy of the License at
 | 
			
		||||
//
 | 
			
		||||
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
//
 | 
			
		||||
// Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
// See the License for the specific language governing permissions and
 | 
			
		||||
// limitations under the License.
 | 
			
		||||
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "absl/strings/substitute.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_framework.h"
 | 
			
		||||
#include "mediapipe/framework/calculator_graph.h"
 | 
			
		||||
#include "mediapipe/framework/formats/tensor.h"
 | 
			
		||||
#include "mediapipe/framework/packet.h"
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_matchers.h"
 | 
			
		||||
#include "mediapipe/framework/tool/options_map.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::testing::StrEq;
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
 | 
			
		||||
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
 | 
			
		||||
      R"pb(
 | 
			
		||||
        input_stream: "text"
 | 
			
		||||
        output_stream: "tensors"
 | 
			
		||||
        node {
 | 
			
		||||
          calculator: "TextToTensorCalculator"
 | 
			
		||||
          input_stream: "TEXT:text"
 | 
			
		||||
          output_stream: "TENSORS:tensors"
 | 
			
		||||
        }
 | 
			
		||||
      )pb");
 | 
			
		||||
  std::vector<Packet> output_packets;
 | 
			
		||||
  tool::AddVectorSink("tensors", &graph_config, &output_packets);
 | 
			
		||||
 | 
			
		||||
  // Run the graph.
 | 
			
		||||
  CalculatorGraph graph;
 | 
			
		||||
  MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
 | 
			
		||||
  MP_RETURN_IF_ERROR(graph.StartRun({}));
 | 
			
		||||
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
 | 
			
		||||
      "text", MakePacket<std::string>(text).At(Timestamp(0))));
 | 
			
		||||
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
 | 
			
		||||
 | 
			
		||||
  if (output_packets.size() != 1) {
 | 
			
		||||
    return absl::InvalidArgumentError(absl::Substitute(
 | 
			
		||||
        "output_packets has size $0, expected 1", output_packets.size()));
 | 
			
		||||
  }
 | 
			
		||||
  const std::vector<Tensor>& tensor_vec =
 | 
			
		||||
      output_packets[0].Get<std::vector<Tensor>>();
 | 
			
		||||
  if (tensor_vec.size() != 1) {
 | 
			
		||||
    return absl::InvalidArgumentError(absl::Substitute(
 | 
			
		||||
        "tensor_vec has size $0, expected 1", tensor_vec.size()));
 | 
			
		||||
  }
 | 
			
		||||
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
 | 
			
		||||
    return absl::InvalidArgumentError(absl::Substitute(
 | 
			
		||||
        "tensor has element type $0, expected $1", tensor_vec[0].element_type(),
 | 
			
		||||
        Tensor::ElementType::kChar));
 | 
			
		||||
  }
 | 
			
		||||
  const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
 | 
			
		||||
  return std::string(buffer, text.length());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TextToTensorCalculatorTest, FooBarBaz) {
 | 
			
		||||
  EXPECT_THAT(RunTextToTensorCalculator("Foo. Bar? Baz!"),
 | 
			
		||||
              IsOkAndHolds(StrEq("Foo. Bar? Baz!")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(TextToTensorCalculatorTest, Empty) {
 | 
			
		||||
  EXPECT_THAT(RunTextToTensorCalculator(""), IsOkAndHolds(StrEq("")));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace mediapipe
 | 
			
		||||
| 
						 | 
				
			
			@ -90,7 +90,16 @@ class Tensor {
 | 
			
		|||
 | 
			
		||||
 public:
 | 
			
		||||
  // No resources are allocated here.
 | 
			
		||||
  enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8, kInt32 };
 | 
			
		||||
  enum class ElementType {
 | 
			
		||||
    kNone,
 | 
			
		||||
    kFloat16,
 | 
			
		||||
    kFloat32,
 | 
			
		||||
    kUInt8,
 | 
			
		||||
    kInt8,
 | 
			
		||||
    kInt32,
 | 
			
		||||
    // TODO: Update the inference runner to handle kTfLiteString.
 | 
			
		||||
    kChar
 | 
			
		||||
  };
 | 
			
		||||
  struct Shape {
 | 
			
		||||
    Shape() = default;
 | 
			
		||||
    Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
 | 
			
		||||
| 
						 | 
				
			
			@ -319,6 +328,8 @@ class Tensor {
 | 
			
		|||
        return 1;
 | 
			
		||||
      case ElementType::kInt32:
 | 
			
		||||
        return sizeof(int32_t);
 | 
			
		||||
      case ElementType::kChar:
 | 
			
		||||
        return sizeof(char);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  int bytes() const { return shape_.num_elements() * element_size(); }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,8 @@
 | 
			
		|||
#include "mediapipe/framework/formats/tensor.h"
 | 
			
		||||
 | 
			
		||||
#include <cstring>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#if !MEDIAPIPE_DISABLE_GPU
 | 
			
		||||
| 
						 | 
				
			
			@ -23,6 +26,9 @@ TEST(General, TestDataTypes) {
 | 
			
		|||
 | 
			
		||||
  Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
 | 
			
		||||
  EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
 | 
			
		||||
 | 
			
		||||
  Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
 | 
			
		||||
  EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(Cpu, TestMemoryAllocation) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user