Internal change
PiperOrigin-RevId: 477585110
This commit is contained in:
parent
8c8a9cda5a
commit
dcc5587483
74
mediapipe/calculators/tensor/text_to_tensor_calculator.cc
Normal file
74
mediapipe/calculators/tensor/text_to_tensor_calculator.cc
Normal file
|
@ -0,0 +1,74 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Trivially converts an input string into a Tensor that stores a copy of
|
||||
// the string.
|
||||
//
|
||||
// Inputs:
|
||||
// TEXT - std::string
|
||||
//
|
||||
// Outputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// Vector containing a single Tensor storing a copy of the input string.
|
||||
// Note that the underlying buffer of the Tensor is not necessarily
|
||||
// null-terminated. It is the graph writer's responsibility to copy the
|
||||
// correct number of characters when copying from this Tensor's buffer.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "TextToTensorCalculator"
|
||||
// input_stream: "TEXT:text"
|
||||
// output_stream: "TENSORS:tensors"
|
||||
// }
|
||||
class TextToTensorCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::string> kTextIn{"TEXT"};
|
||||
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
};
|
||||
|
||||
absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) {
|
||||
absl::string_view text = kTextIn(cc).Get();
|
||||
int input_len = static_cast<int>(text.length());
|
||||
|
||||
std::vector<Tensor> result;
|
||||
result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len})});
|
||||
std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(),
|
||||
input_len * sizeof(char));
|
||||
kTensorsOut(cc).Send(std::move(result));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,88 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_graph.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/options_map.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::testing::StrEq;
|
||||
|
||||
absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
|
||||
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
R"pb(
|
||||
input_stream: "text"
|
||||
output_stream: "tensors"
|
||||
node {
|
||||
calculator: "TextToTensorCalculator"
|
||||
input_stream: "TEXT:text"
|
||||
output_stream: "TENSORS:tensors"
|
||||
}
|
||||
)pb");
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensors", &graph_config, &output_packets);
|
||||
|
||||
// Run the graph.
|
||||
CalculatorGraph graph;
|
||||
MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
|
||||
MP_RETURN_IF_ERROR(graph.StartRun({}));
|
||||
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
|
||||
"text", MakePacket<std::string>(text).At(Timestamp(0))));
|
||||
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
|
||||
|
||||
if (output_packets.size() != 1) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"output_packets has size $0, expected 1", output_packets.size()));
|
||||
}
|
||||
const std::vector<Tensor>& tensor_vec =
|
||||
output_packets[0].Get<std::vector<Tensor>>();
|
||||
if (tensor_vec.size() != 1) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"tensor_vec has size $0, expected 1", tensor_vec.size()));
|
||||
}
|
||||
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
|
||||
return absl::InvalidArgumentError(absl::Substitute(
|
||||
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
|
||||
Tensor::ElementType::kChar));
|
||||
}
|
||||
const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
|
||||
return std::string(buffer, text.length());
|
||||
}
|
||||
|
||||
TEST(TextToTensorCalculatorTest, FooBarBaz) {
|
||||
EXPECT_THAT(RunTextToTensorCalculator("Foo. Bar? Baz!"),
|
||||
IsOkAndHolds(StrEq("Foo. Bar? Baz!")));
|
||||
}
|
||||
|
||||
TEST(TextToTensorCalculatorTest, Empty) {
|
||||
EXPECT_THAT(RunTextToTensorCalculator(""), IsOkAndHolds(StrEq("")));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -90,7 +90,16 @@ class Tensor {
|
|||
|
||||
public:
|
||||
// No resources are allocated here.
|
||||
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8, kInt32 };
|
||||
enum class ElementType {
|
||||
kNone,
|
||||
kFloat16,
|
||||
kFloat32,
|
||||
kUInt8,
|
||||
kInt8,
|
||||
kInt32,
|
||||
// TODO: Update the inference runner to handle kTfLiteString.
|
||||
kChar
|
||||
};
|
||||
struct Shape {
|
||||
Shape() = default;
|
||||
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
||||
|
@ -319,6 +328,8 @@ class Tensor {
|
|||
return 1;
|
||||
case ElementType::kInt32:
|
||||
return sizeof(int32_t);
|
||||
case ElementType::kChar:
|
||||
return sizeof(char);
|
||||
}
|
||||
}
|
||||
int bytes() const { return shape_.num_elements() * element_size(); }
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
@ -23,6 +26,9 @@ TEST(General, TestDataTypes) {
|
|||
|
||||
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
|
||||
EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
|
||||
|
||||
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
|
||||
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
|
||||
}
|
||||
|
||||
TEST(Cpu, TestMemoryAllocation) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user