Internal change
PiperOrigin-RevId: 477585110
This commit is contained in:
		
							parent
							
								
									8c8a9cda5a
								
							
						
					
					
						commit
						dcc5587483
					
				
							
								
								
									
										74
									
								
								mediapipe/calculators/tensor/text_to_tensor_calculator.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								mediapipe/calculators/tensor/text_to_tensor_calculator.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,74 @@
 | 
				
			||||||
 | 
					// Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cstring>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/status/status.h"
 | 
				
			||||||
 | 
					#include "absl/strings/string_view.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/node.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/api2/port.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_context.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/tensor.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					namespace api2 {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Trivially converts an input string into a Tensor that stores a copy of
 | 
				
			||||||
 | 
					// the string.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Inputs:
 | 
				
			||||||
 | 
					//   TEXT - std::string
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Outputs:
 | 
				
			||||||
 | 
					//   TENSORS - std::vector<Tensor>
 | 
				
			||||||
 | 
					//     Vector containing a single Tensor storing a copy of the input string.
 | 
				
			||||||
 | 
					//     Note that the underlying buffer of the Tensor is not necessarily
 | 
				
			||||||
 | 
					//     null-terminated. It is the graph writer's responsibility to copy the
 | 
				
			||||||
 | 
					//     correct number of characters when copying from this Tensor's buffer.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Example:
 | 
				
			||||||
 | 
					//   node {
 | 
				
			||||||
 | 
					//     calculator: "TextToTensorCalculator"
 | 
				
			||||||
 | 
					//     input_stream: "TEXT:text"
 | 
				
			||||||
 | 
					//     output_stream: "TENSORS:tensors"
 | 
				
			||||||
 | 
					//   }
 | 
				
			||||||
 | 
					class TextToTensorCalculator : public Node {
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  static constexpr Input<std::string> kTextIn{"TEXT"};
 | 
				
			||||||
 | 
					  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  absl::Status Process(CalculatorContext* cc) override;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) {
 | 
				
			||||||
 | 
					  absl::string_view text = kTextIn(cc).Get();
 | 
				
			||||||
 | 
					  int input_len = static_cast<int>(text.length());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<Tensor> result;
 | 
				
			||||||
 | 
					  result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len})});
 | 
				
			||||||
 | 
					  std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(),
 | 
				
			||||||
 | 
					              input_len * sizeof(char));
 | 
				
			||||||
 | 
					  kTensorsOut(cc).Send(std::move(result));
 | 
				
			||||||
 | 
					  return absl::OkStatus();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace api2
 | 
				
			||||||
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,88 @@
 | 
				
			||||||
 | 
					// Copyright 2022 The MediaPipe Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					// you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					// You may obtain a copy of the License at
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//      http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					// distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					// See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					// limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cstring>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "absl/status/status.h"
 | 
				
			||||||
 | 
					#include "absl/status/statusor.h"
 | 
				
			||||||
 | 
					#include "absl/strings/string_view.h"
 | 
				
			||||||
 | 
					#include "absl/strings/substitute.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_framework.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/calculator_graph.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/formats/tensor.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/packet.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gmock.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/parse_text_proto.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/port/status_matchers.h"
 | 
				
			||||||
 | 
					#include "mediapipe/framework/tool/options_map.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mediapipe {
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using ::testing::StrEq;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
 | 
				
			||||||
 | 
					  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
 | 
				
			||||||
 | 
					      R"pb(
 | 
				
			||||||
 | 
					        input_stream: "text"
 | 
				
			||||||
 | 
					        output_stream: "tensors"
 | 
				
			||||||
 | 
					        node {
 | 
				
			||||||
 | 
					          calculator: "TextToTensorCalculator"
 | 
				
			||||||
 | 
					          input_stream: "TEXT:text"
 | 
				
			||||||
 | 
					          output_stream: "TENSORS:tensors"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      )pb");
 | 
				
			||||||
 | 
					  std::vector<Packet> output_packets;
 | 
				
			||||||
 | 
					  tool::AddVectorSink("tensors", &graph_config, &output_packets);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Run the graph.
 | 
				
			||||||
 | 
					  CalculatorGraph graph;
 | 
				
			||||||
 | 
					  MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
 | 
				
			||||||
 | 
					  MP_RETURN_IF_ERROR(graph.StartRun({}));
 | 
				
			||||||
 | 
					  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
 | 
				
			||||||
 | 
					      "text", MakePacket<std::string>(text).At(Timestamp(0))));
 | 
				
			||||||
 | 
					  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (output_packets.size() != 1) {
 | 
				
			||||||
 | 
					    return absl::InvalidArgumentError(absl::Substitute(
 | 
				
			||||||
 | 
					        "output_packets has size $0, expected 1", output_packets.size()));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  const std::vector<Tensor>& tensor_vec =
 | 
				
			||||||
 | 
					      output_packets[0].Get<std::vector<Tensor>>();
 | 
				
			||||||
 | 
					  if (tensor_vec.size() != 1) {
 | 
				
			||||||
 | 
					    return absl::InvalidArgumentError(absl::Substitute(
 | 
				
			||||||
 | 
					        "tensor_vec has size $0, expected 1", tensor_vec.size()));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
 | 
				
			||||||
 | 
					    return absl::InvalidArgumentError(absl::Substitute(
 | 
				
			||||||
 | 
					        "tensor has element type $0, expected $1", tensor_vec[0].element_type(),
 | 
				
			||||||
 | 
					        Tensor::ElementType::kChar));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
 | 
				
			||||||
 | 
					  return std::string(buffer, text.length());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(TextToTensorCalculatorTest, FooBarBaz) {
 | 
				
			||||||
 | 
					  EXPECT_THAT(RunTextToTensorCalculator("Foo. Bar? Baz!"),
 | 
				
			||||||
 | 
					              IsOkAndHolds(StrEq("Foo. Bar? Baz!")));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST(TextToTensorCalculatorTest, Empty) {
 | 
				
			||||||
 | 
					  EXPECT_THAT(RunTextToTensorCalculator(""), IsOkAndHolds(StrEq("")));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace
 | 
				
			||||||
 | 
					}  // namespace mediapipe
 | 
				
			||||||
| 
						 | 
					@ -90,7 +90,16 @@ class Tensor {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  // No resources are allocated here.
 | 
					  // No resources are allocated here.
 | 
				
			||||||
  enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8, kInt32 };
 | 
					  enum class ElementType {
 | 
				
			||||||
 | 
					    kNone,
 | 
				
			||||||
 | 
					    kFloat16,
 | 
				
			||||||
 | 
					    kFloat32,
 | 
				
			||||||
 | 
					    kUInt8,
 | 
				
			||||||
 | 
					    kInt8,
 | 
				
			||||||
 | 
					    kInt32,
 | 
				
			||||||
 | 
					    // TODO: Update the inference runner to handle kTfLiteString.
 | 
				
			||||||
 | 
					    kChar
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
  struct Shape {
 | 
					  struct Shape {
 | 
				
			||||||
    Shape() = default;
 | 
					    Shape() = default;
 | 
				
			||||||
    Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
 | 
					    Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
 | 
				
			||||||
| 
						 | 
					@ -319,6 +328,8 @@ class Tensor {
 | 
				
			||||||
        return 1;
 | 
					        return 1;
 | 
				
			||||||
      case ElementType::kInt32:
 | 
					      case ElementType::kInt32:
 | 
				
			||||||
        return sizeof(int32_t);
 | 
					        return sizeof(int32_t);
 | 
				
			||||||
 | 
					      case ElementType::kChar:
 | 
				
			||||||
 | 
					        return sizeof(char);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  int bytes() const { return shape_.num_elements() * element_size(); }
 | 
					  int bytes() const { return shape_.num_elements() * element_size(); }
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,8 @@
 | 
				
			||||||
#include "mediapipe/framework/formats/tensor.h"
 | 
					#include "mediapipe/framework/formats/tensor.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <cstring>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mediapipe/framework/port/gmock.h"
 | 
					#include "mediapipe/framework/port/gmock.h"
 | 
				
			||||||
#include "mediapipe/framework/port/gtest.h"
 | 
					#include "mediapipe/framework/port/gtest.h"
 | 
				
			||||||
#if !MEDIAPIPE_DISABLE_GPU
 | 
					#if !MEDIAPIPE_DISABLE_GPU
 | 
				
			||||||
| 
						 | 
					@ -23,6 +26,9 @@ TEST(General, TestDataTypes) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
 | 
					  Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
 | 
				
			||||||
  EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
 | 
					  EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
 | 
				
			||||||
 | 
					  EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TEST(Cpu, TestMemoryAllocation) {
 | 
					TEST(Cpu, TestMemoryAllocation) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user