Internal change

PiperOrigin-RevId: 477585110
This commit is contained in:
MediaPipe Team 2022-09-28 16:46:09 -07:00 committed by Copybara-Service
parent 8c8a9cda5a
commit dcc5587483
4 changed files with 180 additions and 1 deletions

View File

@ -0,0 +1,74 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
namespace api2 {
// Trivially converts an input string into a Tensor that stores a copy of
// the string.
//
// Inputs:
// TEXT - std::string
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor storing a copy of the input string.
// Note that the underlying buffer of the Tensor is not necessarily
// null-terminated. It is the graph writer's responsibility to copy the
// correct number of characters when copying from this Tensor's buffer.
//
// Example:
// node {
// calculator: "TextToTensorCalculator"
// input_stream: "TEXT:text"
// output_stream: "TENSORS:tensors"
// }
class TextToTensorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut);
absl::Status Process(CalculatorContext* cc) override;
};
absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) {
absl::string_view text = kTextIn(cc).Get();
int input_len = static_cast<int>(text.length());
std::vector<Tensor> result;
result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len})});
std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(),
input_len * sizeof(char));
kTensorsOut(cc).Send(std::move(result));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,88 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
namespace mediapipe {
namespace {
using ::testing::StrEq;
absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "TextToTensorCalculator"
input_stream: "TEXT:text"
output_stream: "TENSORS:tensors"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(graph_config));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected 1", tensor_vec.size()));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
Tensor::ElementType::kChar));
}
const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
return std::string(buffer, text.length());
}
TEST(TextToTensorCalculatorTest, FooBarBaz) {
EXPECT_THAT(RunTextToTensorCalculator("Foo. Bar? Baz!"),
IsOkAndHolds(StrEq("Foo. Bar? Baz!")));
}
TEST(TextToTensorCalculatorTest, Empty) {
EXPECT_THAT(RunTextToTensorCalculator(""), IsOkAndHolds(StrEq("")));
}
} // namespace
} // namespace mediapipe

View File

@ -90,7 +90,16 @@ class Tensor {
public:
// No resources are allocated here.
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8, kInt32 };
enum class ElementType {
kNone,
kFloat16,
kFloat32,
kUInt8,
kInt8,
kInt32,
// TODO: Update the inference runner to handle kTfLiteString.
kChar
};
struct Shape {
Shape() = default;
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
@ -319,6 +328,8 @@ class Tensor {
return 1;
case ElementType::kInt32:
return sizeof(int32_t);
case ElementType::kChar:
return sizeof(char);
}
}
int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -1,5 +1,8 @@
#include "mediapipe/framework/formats/tensor.h"
#include <cstring>
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#if !MEDIAPIPE_DISABLE_GPU
@ -23,6 +26,9 @@ TEST(General, TestDataTypes) {
Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3});
EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2);
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
}
TEST(Cpu, TestMemoryAllocation) {