Internal MediaPipe Tasks change.
PiperOrigin-RevId: 513897822
This commit is contained in:
parent
fe92d2e781
commit
2963739086
44
mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD
Normal file
44
mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts")
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "kmeans_embedding_lookup",
|
||||||
|
srcs = ["kmeans_embedding_lookup.cc"],
|
||||||
|
hdrs = ["kmeans_embedding_lookup.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "kmeans_embedding_lookup_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["kmeans_embedding_lookup_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":kmeans_embedding_lookup",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,145 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
namespace kmeans_embedding_lookup_op {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int kInputMessage = 0;
|
||||||
|
constexpr int kEncodingTable = 1;
|
||||||
|
constexpr int kCodebook = 2;
|
||||||
|
constexpr int kOutputLabel = 0;
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
|
||||||
|
output_size->data[0] = 1;
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kInputMessage);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable);
|
||||||
|
TF_LITE_ENSURE(context, encoding_table != nullptr);
|
||||||
|
const TfLiteTensor* codebook = GetInput(context, node, kCodebook);
|
||||||
|
TF_LITE_ENSURE(context, codebook != nullptr);
|
||||||
|
const int encoding_size = encoding_table->dims->data[1];
|
||||||
|
const int block_size = codebook->dims->data[1];
|
||||||
|
output_size->data[1] = encoding_size * block_size;
|
||||||
|
|
||||||
|
// Check if the inputs and output are typed correctly.
|
||||||
|
if (input->type != kTfLiteInt32) {
|
||||||
|
context->ReportError(context, "Input type must be Int32.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (encoding_table->type != kTfLiteUInt8) {
|
||||||
|
context->ReportError(context, "Encoding Table type must be UInt8.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (codebook->type != kTfLiteFloat32) {
|
||||||
|
context->ReportError(context, "Codebook type must be Float32.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
if (output->type != kTfLiteFloat32) {
|
||||||
|
context->ReportError(context, "Output type must be Float32.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
return context->ResizeTensor(context, output, output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is the core method that generates the aggregated embedding from the
|
||||||
|
// given input, encoding table and codebook tensors.
|
||||||
|
void GetEmbedding(const TfLiteTensor* input, const TfLiteTensor* encoding_table,
|
||||||
|
const TfLiteTensor* codebook, float* data) {
|
||||||
|
const int input_encoding_size = encoding_table->dims->data[1];
|
||||||
|
const int block_size = codebook->dims->data[1];
|
||||||
|
const int num_tokens = input->dims->data[1];
|
||||||
|
const int output_embedding_size = input_encoding_size * block_size;
|
||||||
|
|
||||||
|
int num_embeddings = 0;
|
||||||
|
std::vector<float> final_embedding(output_embedding_size, 0.0);
|
||||||
|
for (int token_idx = 0; token_idx < num_tokens; token_idx++) {
|
||||||
|
const int32_t token = GetTensorData<int32_t>(input)[token_idx];
|
||||||
|
if (token == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
++num_embeddings;
|
||||||
|
|
||||||
|
for (int encoding_dim_idx = 0; encoding_dim_idx < input_encoding_size;
|
||||||
|
encoding_dim_idx++) {
|
||||||
|
int codebook_idx = GetTensorData<uint8_t>(
|
||||||
|
encoding_table)[token * input_encoding_size + encoding_dim_idx];
|
||||||
|
for (int block_offset = 0; block_offset < block_size; block_offset++) {
|
||||||
|
final_embedding[encoding_dim_idx * block_size + block_offset] +=
|
||||||
|
GetTensorData<float>(
|
||||||
|
codebook)[codebook_idx * block_size + block_offset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the mean of the embeddings.
|
||||||
|
for (int embed_dim_idx = 0; embed_dim_idx < output_embedding_size;
|
||||||
|
embed_dim_idx++) {
|
||||||
|
data[embed_dim_idx] =
|
||||||
|
final_embedding[embed_dim_idx] / (std::max(num_embeddings, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, kInputMessage);
|
||||||
|
TF_LITE_ENSURE(context, input != nullptr);
|
||||||
|
const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable);
|
||||||
|
TF_LITE_ENSURE(context, encoding_table != nullptr);
|
||||||
|
const TfLiteTensor* codebook = GetInput(context, node, kCodebook);
|
||||||
|
TF_LITE_ENSURE(context, codebook != nullptr);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
|
||||||
|
// Sanity checks on the input.
|
||||||
|
const int batch_size = input->dims->data[0];
|
||||||
|
if (batch_size != 1) {
|
||||||
|
context->ReportError(context, "`batch_size` must be == 1.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the output embedding.
|
||||||
|
GetEmbedding(input, encoding_table, codebook, GetTensorData<float>(output));
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace kmeans_embedding_lookup_op
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_KmeansEmbeddingLookup() {
|
||||||
|
static TfLiteRegistration r = {nullptr, nullptr,
|
||||||
|
kmeans_embedding_lookup_op::Prepare,
|
||||||
|
kmeans_embedding_lookup_op::Eval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
|
@ -0,0 +1,36 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This op was originally written by the Learn2Compress team.
|
||||||
|
// It takes in a list of indices, an encoding table which consists of
|
||||||
|
// integer indices into a codebook with floating point vectors.
|
||||||
|
// For each index, it looks up the corresponding row in the encoding table and
|
||||||
|
// for each entry in the row of the encoding table, it looks up the
|
||||||
|
// corresponding row in the codebook and populates it in an output embedding.
|
||||||
|
// The average of the output embeddings for each of the input indices is the
|
||||||
|
// output of this op.
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_KmeansEmbeddingLookup();
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
|
|
@ -0,0 +1,176 @@
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
using ::tflite::ArrayFloatNear;
|
||||||
|
|
||||||
|
// Helper class for testing the op.
|
||||||
|
class KmeansEmbeddingLookupModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
explicit KmeansEmbeddingLookupModel(
|
||||||
|
std::initializer_list<int> input_shape,
|
||||||
|
std::initializer_list<int> encoding_table_shape,
|
||||||
|
std::initializer_list<int> codebook_shape,
|
||||||
|
std::initializer_list<int> output_shape) {
|
||||||
|
// Setup the model inputs and the interpreter.
|
||||||
|
output_ = AddOutput({TensorType_FLOAT32, output_shape});
|
||||||
|
SetCustomOp("KmeansEmbeddingLookup", std::vector<uint8_t>(),
|
||||||
|
Register_KmeansEmbeddingLookup);
|
||||||
|
BuildInterpreter({input_shape, encoding_table_shape, codebook_shape});
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus SetUpInputTensor(const std::vector<int>& input,
|
||||||
|
const std::vector<uint8_t>& encoding_table,
|
||||||
|
const std::vector<float>& codebook) {
|
||||||
|
PopulateTensor<int>(input_, {input});
|
||||||
|
PopulateTensor<uint8_t>(encoding_table_, {encoding_table});
|
||||||
|
PopulateTensor<float>(codebook_, {codebook});
|
||||||
|
return interpreter_->AllocateTensors();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Invoke(const std::vector<int>& input,
|
||||||
|
const std::vector<uint8_t>& encoding_table,
|
||||||
|
const std::vector<float>& codebook) {
|
||||||
|
CHECK_EQ(SetUpInputTensor(input, encoding_table, codebook), kTfLiteOk);
|
||||||
|
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus InvokeUnchecked(const std::vector<int>& input,
|
||||||
|
const std::vector<uint8_t>& encoding_table,
|
||||||
|
const std::vector<float>& codebook) {
|
||||||
|
TfLiteStatus allocation_status =
|
||||||
|
SetUpInputTensor(input, encoding_table, codebook);
|
||||||
|
if (allocation_status != kTfLiteOk) {
|
||||||
|
return allocation_status;
|
||||||
|
}
|
||||||
|
return SingleOpModel::Invoke();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> GetOutput() {
|
||||||
|
return ExtractVector<T>(output_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int input_ = AddInput(TensorType_INT32);
|
||||||
|
int encoding_table_ = AddInput(TensorType_UINT8);
|
||||||
|
int codebook_ = AddInput(TensorType_FLOAT32);
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> FlattenVector2D(std::vector<std::vector<T>> input_vec) {
|
||||||
|
std::vector<T> output_vec(input_vec.size() * input_vec[0].size());
|
||||||
|
for (int i = 0, k = 0; i < input_vec.size(); i++) {
|
||||||
|
for (int j = 0; j < input_vec[i].size(); j++, k++) {
|
||||||
|
output_vec[k] = input_vec[i][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return output_vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
class KmeansEmbeddingLookupTestWithSampleInputs : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
KmeansEmbeddingLookupTestWithSampleInputs() {
|
||||||
|
input_ = std::vector<int>({1, 2, 3, 0, 0});
|
||||||
|
encoding_table_ = std::vector<std::vector<uint8_t>>(
|
||||||
|
{{0, 0}, {1, 1}, {1, 2}, {1, 0}, {1, 0}, {2, 0}});
|
||||||
|
codebook_ =
|
||||||
|
std::vector<std::vector<float>>({{0.0, 0.0}, {7.0, 7.0}, {7.0, 0.0}});
|
||||||
|
expected_output_ = std::vector<float>(
|
||||||
|
// The output is the average of the embeddings at the three indices
|
||||||
|
// (1, 2, 3).
|
||||||
|
{7.0, 7.0, 4.66667, 2.33333});
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<int> input_;
|
||||||
|
std::vector<std::vector<uint8_t>> encoding_table_;
|
||||||
|
std::vector<std::vector<float>> codebook_;
|
||||||
|
std::vector<float> expected_output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, ReturnsCorrectly) {
|
||||||
|
// Check if the expected output is returned
|
||||||
|
KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 5},
|
||||||
|
/*encoding_table_shape=*/{6, 2},
|
||||||
|
/*codebook_shape=*/{3, 2},
|
||||||
|
/*output_shape=*/{1, 4});
|
||||||
|
|
||||||
|
m.Invoke(input_, FlattenVector2D<uint8_t>(encoding_table_),
|
||||||
|
FlattenVector2D<float>(codebook_));
|
||||||
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(expected_output_, 1e-5)));
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs,
|
||||||
|
HandlesNegativeValuesInCodebook) {
|
||||||
|
KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4},
|
||||||
|
/*encoding_table_shape=*/{4, 2},
|
||||||
|
/*codebook_shape=*/{4, 3},
|
||||||
|
/*output_shape=*/{1, 6});
|
||||||
|
std::vector<int> input = std::vector<int>({2, 2, 1, 3});
|
||||||
|
std::vector<std::vector<uint8_t>> encoding_table =
|
||||||
|
std::vector<std::vector<uint8_t>>({{0, 0}, {1, 2}, {3, 0}, {2, 3}});
|
||||||
|
std::vector<std::vector<float>> codebook = std::vector<std::vector<float>>(
|
||||||
|
{{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}});
|
||||||
|
m.Invoke(input, FlattenVector2D<uint8_t>(encoding_table),
|
||||||
|
FlattenVector2D<float>(codebook));
|
||||||
|
std::vector<float> expected_output =
|
||||||
|
std::vector<float>({2.55, 0.1, 2.4, 2.925, 1.1, 2.65});
|
||||||
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(expected_output, 1e-5)));
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, IgnoresIndicesAfterZero) {
|
||||||
|
KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4},
|
||||||
|
/*encoding_table_shape=*/{4, 2},
|
||||||
|
/*codebook_shape=*/{4, 3},
|
||||||
|
/*output_shape=*/{1, 6});
|
||||||
|
std::vector<int> input = std::vector<int>({2, 2, 0, 3});
|
||||||
|
std::vector<std::vector<uint8_t>> encoding_table =
|
||||||
|
std::vector<std::vector<uint8_t>>({{0, 0}, {1, 2}, {3, 0}, {2, 3}});
|
||||||
|
std::vector<std::vector<float>> codebook = std::vector<std::vector<float>>(
|
||||||
|
{{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}});
|
||||||
|
m.Invoke(input, FlattenVector2D<uint8_t>(encoding_table),
|
||||||
|
FlattenVector2D<float>(codebook));
|
||||||
|
std::vector<float> expected_output =
|
||||||
|
std::vector<float>({0.5, -2.0, 1.0, 5.0, 2.0, 3.0});
|
||||||
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear(expected_output, 1e-5)));
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) {
|
||||||
|
// Check that the op errors out when the batch size is greater than 1.
|
||||||
|
KmeansEmbeddingLookupModel m(/*input_shape=*/{2, 1},
|
||||||
|
/*encoding_table_shape=*/{1, 1},
|
||||||
|
/*codebook shape=*/{1, 2},
|
||||||
|
/*output_shape=*/{2, 2});
|
||||||
|
|
||||||
|
EXPECT_EQ(m.InvokeUnchecked(/*input=*/{1, 1},
|
||||||
|
/*encoding_table=*/{0},
|
||||||
|
/*codebook=*/{2.3, 4.5}),
|
||||||
|
kTfLiteError);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite::ops::custom
|
Loading…
Reference in New Issue
Block a user