Internal MediaPipe Tasks change.

PiperOrigin-RevId: 513897822
This commit is contained in:
MediaPipe Team 2023-03-03 12:44:55 -08:00 committed by Copybara-Service
parent fe92d2e781
commit 2963739086
4 changed files with 401 additions and 0 deletions

View File

@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "kmeans_embedding_lookup",
srcs = ["kmeans_embedding_lookup.cc"],
hdrs = ["kmeans_embedding_lookup.h"],
copts = tflite_copts(),
deps = [
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
],
)
cc_test(
name = "kmeans_embedding_lookup_test",
size = "small",
srcs = ["kmeans_embedding_lookup_test.cc"],
deps = [
":kmeans_embedding_lookup",
"//mediapipe/framework/port:gtest_main",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
],
)

View File

@ -0,0 +1,145 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite::ops::custom {
namespace kmeans_embedding_lookup_op {
namespace {
constexpr int kInputMessage = 0;
constexpr int kEncodingTable = 1;
constexpr int kCodebook = 2;
constexpr int kOutputLabel = 0;
} // namespace
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = 1;
const TfLiteTensor* input = GetInput(context, node, kInputMessage);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable);
TF_LITE_ENSURE(context, encoding_table != nullptr);
const TfLiteTensor* codebook = GetInput(context, node, kCodebook);
TF_LITE_ENSURE(context, codebook != nullptr);
const int encoding_size = encoding_table->dims->data[1];
const int block_size = codebook->dims->data[1];
output_size->data[1] = encoding_size * block_size;
// Check if the inputs and output are typed correctly.
if (input->type != kTfLiteInt32) {
context->ReportError(context, "Input type must be Int32.");
return kTfLiteError;
}
if (encoding_table->type != kTfLiteUInt8) {
context->ReportError(context, "Encoding Table type must be UInt8.");
return kTfLiteError;
}
if (codebook->type != kTfLiteFloat32) {
context->ReportError(context, "Codebook type must be Float32.");
return kTfLiteError;
}
if (output->type != kTfLiteFloat32) {
context->ReportError(context, "Output type must be Float32.");
return kTfLiteError;
}
return context->ResizeTensor(context, output, output_size);
}
// This is the core method that generates the aggregated embedding from the
// given input, encoding table and codebook tensors.
void GetEmbedding(const TfLiteTensor* input, const TfLiteTensor* encoding_table,
const TfLiteTensor* codebook, float* data) {
const int input_encoding_size = encoding_table->dims->data[1];
const int block_size = codebook->dims->data[1];
const int num_tokens = input->dims->data[1];
const int output_embedding_size = input_encoding_size * block_size;
int num_embeddings = 0;
std::vector<float> final_embedding(output_embedding_size, 0.0);
for (int token_idx = 0; token_idx < num_tokens; token_idx++) {
const int32_t token = GetTensorData<int32_t>(input)[token_idx];
if (token == 0) {
break;
}
++num_embeddings;
for (int encoding_dim_idx = 0; encoding_dim_idx < input_encoding_size;
encoding_dim_idx++) {
int codebook_idx = GetTensorData<uint8_t>(
encoding_table)[token * input_encoding_size + encoding_dim_idx];
for (int block_offset = 0; block_offset < block_size; block_offset++) {
final_embedding[encoding_dim_idx * block_size + block_offset] +=
GetTensorData<float>(
codebook)[codebook_idx * block_size + block_offset];
}
}
}
// Compute the mean of the embeddings.
for (int embed_dim_idx = 0; embed_dim_idx < output_embedding_size;
embed_dim_idx++) {
data[embed_dim_idx] =
final_embedding[embed_dim_idx] / (std::max(num_embeddings, 1));
}
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputMessage);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable);
TF_LITE_ENSURE(context, encoding_table != nullptr);
const TfLiteTensor* codebook = GetInput(context, node, kCodebook);
TF_LITE_ENSURE(context, codebook != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
// Sanity checks on the input.
const int batch_size = input->dims->data[0];
if (batch_size != 1) {
context->ReportError(context, "`batch_size` must be == 1.");
return kTfLiteError;
}
// Compute the output embedding.
GetEmbedding(input, encoding_table, codebook, GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace kmeans_embedding_lookup_op
TfLiteRegistration* Register_KmeansEmbeddingLookup() {
static TfLiteRegistration r = {nullptr, nullptr,
kmeans_embedding_lookup_op::Prepare,
kmeans_embedding_lookup_op::Eval};
return &r;
}
} // namespace tflite::ops::custom

View File

@ -0,0 +1,36 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This op was originally written by the Learn2Compress team.
// It takes in a list of indices, an encoding table which consists of
// integer indices into a codebook with floating point vectors.
// For each index, it looks up the corresponding row in the encoding table and
// for each entry in the row of the encoding table, it looks up the
// corresponding row in the codebook and populates it in an output embedding.
// The average of the output embeddings for each of the input indices is the
// output of this op.
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
#include "tensorflow/lite/kernels/register.h"
namespace tflite::ops::custom {
TfLiteRegistration* Register_KmeansEmbeddingLookup();
} // namespace tflite::ops::custom
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_

View File

@ -0,0 +1,176 @@
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <string>
#include <vector>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace tflite::ops::custom {
namespace {
using ::testing::ElementsAreArray;
using ::tflite::ArrayFloatNear;
// Helper class for testing the op.
class KmeansEmbeddingLookupModel : public SingleOpModel {
public:
explicit KmeansEmbeddingLookupModel(
std::initializer_list<int> input_shape,
std::initializer_list<int> encoding_table_shape,
std::initializer_list<int> codebook_shape,
std::initializer_list<int> output_shape) {
// Setup the model inputs and the interpreter.
output_ = AddOutput({TensorType_FLOAT32, output_shape});
SetCustomOp("KmeansEmbeddingLookup", std::vector<uint8_t>(),
Register_KmeansEmbeddingLookup);
BuildInterpreter({input_shape, encoding_table_shape, codebook_shape});
}
TfLiteStatus SetUpInputTensor(const std::vector<int>& input,
const std::vector<uint8_t>& encoding_table,
const std::vector<float>& codebook) {
PopulateTensor<int>(input_, {input});
PopulateTensor<uint8_t>(encoding_table_, {encoding_table});
PopulateTensor<float>(codebook_, {codebook});
return interpreter_->AllocateTensors();
}
void Invoke(const std::vector<int>& input,
const std::vector<uint8_t>& encoding_table,
const std::vector<float>& codebook) {
CHECK_EQ(SetUpInputTensor(input, encoding_table, codebook), kTfLiteOk);
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
}
TfLiteStatus InvokeUnchecked(const std::vector<int>& input,
const std::vector<uint8_t>& encoding_table,
const std::vector<float>& codebook) {
TfLiteStatus allocation_status =
SetUpInputTensor(input, encoding_table, codebook);
if (allocation_status != kTfLiteOk) {
return allocation_status;
}
return SingleOpModel::Invoke();
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_ = AddInput(TensorType_INT32);
int encoding_table_ = AddInput(TensorType_UINT8);
int codebook_ = AddInput(TensorType_FLOAT32);
int output_;
};
template <typename T>
std::vector<T> FlattenVector2D(std::vector<std::vector<T>> input_vec) {
std::vector<T> output_vec(input_vec.size() * input_vec[0].size());
for (int i = 0, k = 0; i < input_vec.size(); i++) {
for (int j = 0; j < input_vec[i].size(); j++, k++) {
output_vec[k] = input_vec[i][j];
}
}
return output_vec;
}
class KmeansEmbeddingLookupTestWithSampleInputs : public ::testing::Test {
public:
KmeansEmbeddingLookupTestWithSampleInputs() {
input_ = std::vector<int>({1, 2, 3, 0, 0});
encoding_table_ = std::vector<std::vector<uint8_t>>(
{{0, 0}, {1, 1}, {1, 2}, {1, 0}, {1, 0}, {2, 0}});
codebook_ =
std::vector<std::vector<float>>({{0.0, 0.0}, {7.0, 7.0}, {7.0, 0.0}});
expected_output_ = std::vector<float>(
// The output is the average of the embeddings at the three indices
// (1, 2, 3).
{7.0, 7.0, 4.66667, 2.33333});
}
protected:
std::vector<int> input_;
std::vector<std::vector<uint8_t>> encoding_table_;
std::vector<std::vector<float>> codebook_;
std::vector<float> expected_output_;
};
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, ReturnsCorrectly) {
// Check if the expected output is returned
KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 5},
/*encoding_table_shape=*/{6, 2},
/*codebook_shape=*/{3, 2},
/*output_shape=*/{1, 4});
m.Invoke(input_, FlattenVector2D<uint8_t>(encoding_table_),
FlattenVector2D<float>(codebook_));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(expected_output_, 1e-5)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4}));
}
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs,
HandlesNegativeValuesInCodebook) {
KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4},
/*encoding_table_shape=*/{4, 2},
/*codebook_shape=*/{4, 3},
/*output_shape=*/{1, 6});
std::vector<int> input = std::vector<int>({2, 2, 1, 3});
std::vector<std::vector<uint8_t>> encoding_table =
std::vector<std::vector<uint8_t>>({{0, 0}, {1, 2}, {3, 0}, {2, 3}});
std::vector<std::vector<float>> codebook = std::vector<std::vector<float>>(
{{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}});
m.Invoke(input, FlattenVector2D<uint8_t>(encoding_table),
FlattenVector2D<float>(codebook));
std::vector<float> expected_output =
std::vector<float>({2.55, 0.1, 2.4, 2.925, 1.1, 2.65});
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(expected_output, 1e-5)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6}));
}
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, IgnoresIndicesAfterZero) {
KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4},
/*encoding_table_shape=*/{4, 2},
/*codebook_shape=*/{4, 3},
/*output_shape=*/{1, 6});
std::vector<int> input = std::vector<int>({2, 2, 0, 3});
std::vector<std::vector<uint8_t>> encoding_table =
std::vector<std::vector<uint8_t>>({{0, 0}, {1, 2}, {3, 0}, {2, 3}});
std::vector<std::vector<float>> codebook = std::vector<std::vector<float>>(
{{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}});
m.Invoke(input, FlattenVector2D<uint8_t>(encoding_table),
FlattenVector2D<float>(codebook));
std::vector<float> expected_output =
std::vector<float>({0.5, -2.0, 1.0, 5.0, 2.0, 3.0});
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(expected_output, 1e-5)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6}));
}
TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) {
// Check that the op errors out when the batch size is greater than 1.
KmeansEmbeddingLookupModel m(/*input_shape=*/{2, 1},
/*encoding_table_shape=*/{1, 1},
/*codebook shape=*/{1, 2},
/*output_shape=*/{2, 2});
EXPECT_EQ(m.InvokeUnchecked(/*input=*/{1, 1},
/*encoding_table=*/{0},
/*codebook=*/{2.3, 4.5}),
kTfLiteError);
}
} // namespace
} // namespace tflite::ops::custom