Internal change
PiperOrigin-RevId: 477585110
This commit is contained in:
parent
8c8a9cda5a
commit
dcc5587483
74
mediapipe/calculators/tensor/text_to_tensor_calculator.cc
Normal file
74
mediapipe/calculators/tensor/text_to_tensor_calculator.cc
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/api2/port.h"
|
||||||
|
#include "mediapipe/framework/calculator_context.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace api2 {
|
||||||
|
|
||||||
|
// Trivially converts an input string into a Tensor that stores a copy of
|
||||||
|
// the string.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// TEXT - std::string
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// TENSORS - std::vector<Tensor>
|
||||||
|
// Vector containing a single Tensor storing a copy of the input string.
|
||||||
|
// Note that the underlying buffer of the Tensor is not necessarily
|
||||||
|
// null-terminated. It is the graph writer's responsibility to copy the
|
||||||
|
// correct number of characters when copying from this Tensor's buffer.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// node {
|
||||||
|
// calculator: "TextToTensorCalculator"
|
||||||
|
// input_stream: "TEXT:text"
|
||||||
|
// output_stream: "TENSORS:tensors"
|
||||||
|
// }
|
||||||
|
class TextToTensorCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<std::string> kTextIn{"TEXT"};
|
||||||
|
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut);
|
||||||
|
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) {
|
||||||
|
absl::string_view text = kTextIn(cc).Get();
|
||||||
|
int input_len = static_cast<int>(text.length());
|
||||||
|
|
||||||
|
std::vector<Tensor> result;
|
||||||
|
result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len})});
|
||||||
|
std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(),
|
||||||
|
input_len * sizeof(char));
|
||||||
|
kTensorsOut(cc).Send(std::move(result));
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator);
|
||||||
|
|
||||||
|
} // namespace api2
|
||||||
|
} // namespace mediapipe
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2022 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/status/statusor.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_graph.h"
|
||||||
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/framework/tool/options_map.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::StrEq;
|
||||||
|
|
||||||
|
absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
|
||||||
|
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
|
R"pb(
|
||||||
|
input_stream: "text"
|
||||||
|
output_stream: "tensors"
|
||||||
|
node {
|
||||||
|
calculator: "TextToTensorCalculator"
|
||||||
|
input_stream: "TEXT:text"
|
||||||
|
output_stream: "TENSORS:tensors"
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
std::vector<Packet> output_packets;
|
||||||
|
tool::AddVectorSink("tensors", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
// Run the graph.
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
|
||||||
|
MP_RETURN_IF_ERROR(graph.StartRun({}));
|
||||||
|
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||||
|
"text", MakePacket<std::string>(text).At(Timestamp(0))));
|
||||||
|
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
|
||||||
|
|
||||||
|
if (output_packets.size() != 1) {
|
||||||
|
return absl::InvalidArgumentError(absl::Substitute(
|
||||||
|
"output_packets has size $0, expected 1", output_packets.size()));
|
||||||
|
}
|
||||||
|
const std::vector<Tensor>& tensor_vec =
|
||||||
|
output_packets[0].Get<std::vector<Tensor>>();
|
||||||
|
if (tensor_vec.size() != 1) {
|
||||||
|
return absl::InvalidArgumentError(absl::Substitute(
|
||||||
|
"tensor_vec has size $0, expected 1", tensor_vec.size()));
|
||||||
|
}
|
||||||
|
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
|
||||||
|
return absl::InvalidArgumentError(absl::Substitute(
|
||||||
|
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
|
||||||
|
Tensor::ElementType::kChar));
|
||||||
|
}
|
||||||
|
const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
|
||||||
|
return std::string(buffer, text.length());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TextToTensorCalculatorTest, FooBarBaz) {
|
||||||
|
EXPECT_THAT(RunTextToTensorCalculator("Foo. Bar? Baz!"),
|
||||||
|
IsOkAndHolds(StrEq("Foo. Bar? Baz!")));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TextToTensorCalculatorTest, Empty) {
|
||||||
|
EXPECT_THAT(RunTextToTensorCalculator(""), IsOkAndHolds(StrEq("")));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -90,7 +90,16 @@ class Tensor {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// No resources are allocated here.
|
// No resources are allocated here.
|
||||||
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8, kInt32 };
|
enum class ElementType {
|
||||||
|
kNone,
|
||||||
|
kFloat16,
|
||||||
|
kFloat32,
|
||||||
|
kUInt8,
|
||||||
|
kInt8,
|
||||||
|
kInt32,
|
||||||
|
// TODO: Update the inference runner to handle kTfLiteString.
|
||||||
|
kChar
|
||||||
|
};
|
||||||
struct Shape {
|
struct Shape {
|
||||||
Shape() = default;
|
Shape() = default;
|
||||||
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
||||||
|
@ -319,6 +328,8 @@ class Tensor {
|
||||||
return 1;
|
return 1;
|
||||||
case ElementType::kInt32:
|
case ElementType::kInt32:
|
||||||
return sizeof(int32_t);
|
return sizeof(int32_t);
|
||||||
|
case ElementType::kChar:
|
||||||
|
return sizeof(char);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int bytes() const { return shape_.num_elements() * element_size(); }
|
int bytes() const { return shape_.num_elements() * element_size(); }
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#if !MEDIAPIPE_DISABLE_GPU
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
|
@ -23,6 +26,9 @@ TEST(General, TestDataTypes) {
|
||||||
|
|
||||||
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
|
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
|
||||||
EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
|
EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
|
||||||
|
|
||||||
|
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
|
||||||
|
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Cpu, TestMemoryAllocation) {
|
TEST(Cpu, TestMemoryAllocation) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user