Internal MediaPipe Tasks change.
PiperOrigin-RevId: 513897822
This commit is contained in:
		
							parent
							
								
									fe92d2e781
								
							
						
					
					
						commit
						2963739086
					
				
							
								
								
									
										44
									
								
								mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,44 @@
 | 
			
		|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
load("@org_tensorflow//tensorflow/lite:build_def.bzl", "tflite_copts")
 | 
			
		||||
 | 
			
		||||
package(default_visibility = ["//mediapipe/tasks:internal"])
 | 
			
		||||
 | 
			
		||||
licenses(["notice"])
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "kmeans_embedding_lookup",
 | 
			
		||||
    srcs = ["kmeans_embedding_lookup.cc"],
 | 
			
		||||
    hdrs = ["kmeans_embedding_lookup.h"],
 | 
			
		||||
    copts = tflite_copts(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_test(
 | 
			
		||||
    name = "kmeans_embedding_lookup_test",
 | 
			
		||||
    size = "small",
 | 
			
		||||
    srcs = ["kmeans_embedding_lookup_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":kmeans_embedding_lookup",
 | 
			
		||||
        "//mediapipe/framework/port:gtest_main",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite:framework",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/c:common",
 | 
			
		||||
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,145 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 | 
			
		||||
#include "tensorflow/lite/kernels/kernel_util.h"
 | 
			
		||||
 | 
			
		||||
namespace tflite::ops::custom {
 | 
			
		||||
namespace kmeans_embedding_lookup_op {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
constexpr int kInputMessage = 0;
 | 
			
		||||
constexpr int kEncodingTable = 1;
 | 
			
		||||
constexpr int kCodebook = 2;
 | 
			
		||||
constexpr int kOutputLabel = 0;
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
 | 
			
		||||
  TF_LITE_ENSURE(context, output != nullptr);
 | 
			
		||||
 | 
			
		||||
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
 | 
			
		||||
  output_size->data[0] = 1;
 | 
			
		||||
  const TfLiteTensor* input = GetInput(context, node, kInputMessage);
 | 
			
		||||
  TF_LITE_ENSURE(context, input != nullptr);
 | 
			
		||||
  const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable);
 | 
			
		||||
  TF_LITE_ENSURE(context, encoding_table != nullptr);
 | 
			
		||||
  const TfLiteTensor* codebook = GetInput(context, node, kCodebook);
 | 
			
		||||
  TF_LITE_ENSURE(context, codebook != nullptr);
 | 
			
		||||
  const int encoding_size = encoding_table->dims->data[1];
 | 
			
		||||
  const int block_size = codebook->dims->data[1];
 | 
			
		||||
  output_size->data[1] = encoding_size * block_size;
 | 
			
		||||
 | 
			
		||||
  // Check if the inputs and output are typed correctly.
 | 
			
		||||
  if (input->type != kTfLiteInt32) {
 | 
			
		||||
    context->ReportError(context, "Input type must be Int32.");
 | 
			
		||||
    return kTfLiteError;
 | 
			
		||||
  }
 | 
			
		||||
  if (encoding_table->type != kTfLiteUInt8) {
 | 
			
		||||
    context->ReportError(context, "Encoding Table type must be UInt8.");
 | 
			
		||||
    return kTfLiteError;
 | 
			
		||||
  }
 | 
			
		||||
  if (codebook->type != kTfLiteFloat32) {
 | 
			
		||||
    context->ReportError(context, "Codebook type must be Float32.");
 | 
			
		||||
    return kTfLiteError;
 | 
			
		||||
  }
 | 
			
		||||
  if (output->type != kTfLiteFloat32) {
 | 
			
		||||
    context->ReportError(context, "Output type must be Float32.");
 | 
			
		||||
    return kTfLiteError;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return context->ResizeTensor(context, output, output_size);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// This is the core method that generates the aggregated embedding from the
 | 
			
		||||
// given input, encoding table and codebook tensors.
 | 
			
		||||
void GetEmbedding(const TfLiteTensor* input, const TfLiteTensor* encoding_table,
 | 
			
		||||
                  const TfLiteTensor* codebook, float* data) {
 | 
			
		||||
  const int input_encoding_size = encoding_table->dims->data[1];
 | 
			
		||||
  const int block_size = codebook->dims->data[1];
 | 
			
		||||
  const int num_tokens = input->dims->data[1];
 | 
			
		||||
  const int output_embedding_size = input_encoding_size * block_size;
 | 
			
		||||
 | 
			
		||||
  int num_embeddings = 0;
 | 
			
		||||
  std::vector<float> final_embedding(output_embedding_size, 0.0);
 | 
			
		||||
  for (int token_idx = 0; token_idx < num_tokens; token_idx++) {
 | 
			
		||||
    const int32_t token = GetTensorData<int32_t>(input)[token_idx];
 | 
			
		||||
    if (token == 0) {
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    ++num_embeddings;
 | 
			
		||||
 | 
			
		||||
    for (int encoding_dim_idx = 0; encoding_dim_idx < input_encoding_size;
 | 
			
		||||
         encoding_dim_idx++) {
 | 
			
		||||
      int codebook_idx = GetTensorData<uint8_t>(
 | 
			
		||||
          encoding_table)[token * input_encoding_size + encoding_dim_idx];
 | 
			
		||||
      for (int block_offset = 0; block_offset < block_size; block_offset++) {
 | 
			
		||||
        final_embedding[encoding_dim_idx * block_size + block_offset] +=
 | 
			
		||||
            GetTensorData<float>(
 | 
			
		||||
                codebook)[codebook_idx * block_size + block_offset];
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Compute the mean of the embeddings.
 | 
			
		||||
  for (int embed_dim_idx = 0; embed_dim_idx < output_embedding_size;
 | 
			
		||||
       embed_dim_idx++) {
 | 
			
		||||
    data[embed_dim_idx] =
 | 
			
		||||
        final_embedding[embed_dim_idx] / (std::max(num_embeddings, 1));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 | 
			
		||||
  const TfLiteTensor* input = GetInput(context, node, kInputMessage);
 | 
			
		||||
  TF_LITE_ENSURE(context, input != nullptr);
 | 
			
		||||
  const TfLiteTensor* encoding_table = GetInput(context, node, kEncodingTable);
 | 
			
		||||
  TF_LITE_ENSURE(context, encoding_table != nullptr);
 | 
			
		||||
  const TfLiteTensor* codebook = GetInput(context, node, kCodebook);
 | 
			
		||||
  TF_LITE_ENSURE(context, codebook != nullptr);
 | 
			
		||||
  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
 | 
			
		||||
  TF_LITE_ENSURE(context, output != nullptr);
 | 
			
		||||
 | 
			
		||||
  // Sanity checks on the input.
 | 
			
		||||
  const int batch_size = input->dims->data[0];
 | 
			
		||||
  if (batch_size != 1) {
 | 
			
		||||
    context->ReportError(context, "`batch_size` must be == 1.");
 | 
			
		||||
    return kTfLiteError;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Compute the output embedding.
 | 
			
		||||
  GetEmbedding(input, encoding_table, codebook, GetTensorData<float>(output));
 | 
			
		||||
 | 
			
		||||
  return kTfLiteOk;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace kmeans_embedding_lookup_op
 | 
			
		||||
 | 
			
		||||
TfLiteRegistration* Register_KmeansEmbeddingLookup() {
 | 
			
		||||
  static TfLiteRegistration r = {nullptr, nullptr,
 | 
			
		||||
                                 kmeans_embedding_lookup_op::Prepare,
 | 
			
		||||
                                 kmeans_embedding_lookup_op::Eval};
 | 
			
		||||
  return &r;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace tflite::ops::custom
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,36 @@
 | 
			
		|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
// This op was originally written by the Learn2Compress team.
 | 
			
		||||
// It takes in a list of indices, an encoding table which consists of
 | 
			
		||||
// integer indices into a codebook with floating point vectors.
 | 
			
		||||
// For each index, it looks up the corresponding row in the encoding table and
 | 
			
		||||
// for each entry in the row of the encoding table, it looks up the
 | 
			
		||||
// corresponding row in the codebook and populates it in an output embedding.
 | 
			
		||||
// The average of the output embeddings for each of the input indices is the
 | 
			
		||||
// output of this op.
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/lite/kernels/register.h"
 | 
			
		||||
 | 
			
		||||
namespace tflite::ops::custom {
 | 
			
		||||
 | 
			
		||||
TfLiteRegistration* Register_KmeansEmbeddingLookup();
 | 
			
		||||
 | 
			
		||||
}  // namespace tflite::ops::custom
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_KMEANS_EMBEDDING_LOOKUP_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,176 @@
 | 
			
		|||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
 | 
			
		||||
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <functional>
 | 
			
		||||
#include <initializer_list>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/port/gmock.h"
 | 
			
		||||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "tensorflow/lite/c/common.h"
 | 
			
		||||
#include "tensorflow/lite/interpreter.h"
 | 
			
		||||
#include "tensorflow/lite/kernels/test_util.h"
 | 
			
		||||
 | 
			
		||||
namespace tflite::ops::custom {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::testing::ElementsAreArray;
 | 
			
		||||
using ::tflite::ArrayFloatNear;
 | 
			
		||||
 | 
			
		||||
// Helper class for testing the op.
 | 
			
		||||
class KmeansEmbeddingLookupModel : public SingleOpModel {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit KmeansEmbeddingLookupModel(
 | 
			
		||||
      std::initializer_list<int> input_shape,
 | 
			
		||||
      std::initializer_list<int> encoding_table_shape,
 | 
			
		||||
      std::initializer_list<int> codebook_shape,
 | 
			
		||||
      std::initializer_list<int> output_shape) {
 | 
			
		||||
    // Setup the model inputs and the interpreter.
 | 
			
		||||
    output_ = AddOutput({TensorType_FLOAT32, output_shape});
 | 
			
		||||
    SetCustomOp("KmeansEmbeddingLookup", std::vector<uint8_t>(),
 | 
			
		||||
                Register_KmeansEmbeddingLookup);
 | 
			
		||||
    BuildInterpreter({input_shape, encoding_table_shape, codebook_shape});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TfLiteStatus SetUpInputTensor(const std::vector<int>& input,
 | 
			
		||||
                                const std::vector<uint8_t>& encoding_table,
 | 
			
		||||
                                const std::vector<float>& codebook) {
 | 
			
		||||
    PopulateTensor<int>(input_, {input});
 | 
			
		||||
    PopulateTensor<uint8_t>(encoding_table_, {encoding_table});
 | 
			
		||||
    PopulateTensor<float>(codebook_, {codebook});
 | 
			
		||||
    return interpreter_->AllocateTensors();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void Invoke(const std::vector<int>& input,
 | 
			
		||||
              const std::vector<uint8_t>& encoding_table,
 | 
			
		||||
              const std::vector<float>& codebook) {
 | 
			
		||||
    CHECK_EQ(SetUpInputTensor(input, encoding_table, codebook), kTfLiteOk);
 | 
			
		||||
    CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  TfLiteStatus InvokeUnchecked(const std::vector<int>& input,
 | 
			
		||||
                               const std::vector<uint8_t>& encoding_table,
 | 
			
		||||
                               const std::vector<float>& codebook) {
 | 
			
		||||
    TfLiteStatus allocation_status =
 | 
			
		||||
        SetUpInputTensor(input, encoding_table, codebook);
 | 
			
		||||
    if (allocation_status != kTfLiteOk) {
 | 
			
		||||
      return allocation_status;
 | 
			
		||||
    }
 | 
			
		||||
    return SingleOpModel::Invoke();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  std::vector<T> GetOutput() {
 | 
			
		||||
    return ExtractVector<T>(output_);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  int input_ = AddInput(TensorType_INT32);
 | 
			
		||||
  int encoding_table_ = AddInput(TensorType_UINT8);
 | 
			
		||||
  int codebook_ = AddInput(TensorType_FLOAT32);
 | 
			
		||||
  int output_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
std::vector<T> FlattenVector2D(std::vector<std::vector<T>> input_vec) {
 | 
			
		||||
  std::vector<T> output_vec(input_vec.size() * input_vec[0].size());
 | 
			
		||||
  for (int i = 0, k = 0; i < input_vec.size(); i++) {
 | 
			
		||||
    for (int j = 0; j < input_vec[i].size(); j++, k++) {
 | 
			
		||||
      output_vec[k] = input_vec[i][j];
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return output_vec;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class KmeansEmbeddingLookupTestWithSampleInputs : public ::testing::Test {
 | 
			
		||||
 public:
 | 
			
		||||
  KmeansEmbeddingLookupTestWithSampleInputs() {
 | 
			
		||||
    input_ = std::vector<int>({1, 2, 3, 0, 0});
 | 
			
		||||
    encoding_table_ = std::vector<std::vector<uint8_t>>(
 | 
			
		||||
        {{0, 0}, {1, 1}, {1, 2}, {1, 0}, {1, 0}, {2, 0}});
 | 
			
		||||
    codebook_ =
 | 
			
		||||
        std::vector<std::vector<float>>({{0.0, 0.0}, {7.0, 7.0}, {7.0, 0.0}});
 | 
			
		||||
    expected_output_ = std::vector<float>(
 | 
			
		||||
        // The output is the average of the embeddings at the three indices
 | 
			
		||||
        // (1, 2, 3).
 | 
			
		||||
        {7.0, 7.0, 4.66667, 2.33333});
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  std::vector<int> input_;
 | 
			
		||||
  std::vector<std::vector<uint8_t>> encoding_table_;
 | 
			
		||||
  std::vector<std::vector<float>> codebook_;
 | 
			
		||||
  std::vector<float> expected_output_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, ReturnsCorrectly) {
 | 
			
		||||
  // Check if the expected output is returned
 | 
			
		||||
  KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 5},
 | 
			
		||||
                               /*encoding_table_shape=*/{6, 2},
 | 
			
		||||
                               /*codebook_shape=*/{3, 2},
 | 
			
		||||
                               /*output_shape=*/{1, 4});
 | 
			
		||||
 | 
			
		||||
  m.Invoke(input_, FlattenVector2D<uint8_t>(encoding_table_),
 | 
			
		||||
           FlattenVector2D<float>(codebook_));
 | 
			
		||||
  EXPECT_THAT(m.GetOutput<float>(),
 | 
			
		||||
              ElementsAreArray(ArrayFloatNear(expected_output_, 1e-5)));
 | 
			
		||||
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs,
 | 
			
		||||
       HandlesNegativeValuesInCodebook) {
 | 
			
		||||
  KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4},
 | 
			
		||||
                               /*encoding_table_shape=*/{4, 2},
 | 
			
		||||
                               /*codebook_shape=*/{4, 3},
 | 
			
		||||
                               /*output_shape=*/{1, 6});
 | 
			
		||||
  std::vector<int> input = std::vector<int>({2, 2, 1, 3});
 | 
			
		||||
  std::vector<std::vector<uint8_t>> encoding_table =
 | 
			
		||||
      std::vector<std::vector<uint8_t>>({{0, 0}, {1, 2}, {3, 0}, {2, 3}});
 | 
			
		||||
  std::vector<std::vector<float>> codebook = std::vector<std::vector<float>>(
 | 
			
		||||
      {{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}});
 | 
			
		||||
  m.Invoke(input, FlattenVector2D<uint8_t>(encoding_table),
 | 
			
		||||
           FlattenVector2D<float>(codebook));
 | 
			
		||||
  std::vector<float> expected_output =
 | 
			
		||||
      std::vector<float>({2.55, 0.1, 2.4, 2.925, 1.1, 2.65});
 | 
			
		||||
  EXPECT_THAT(m.GetOutput<float>(),
 | 
			
		||||
              ElementsAreArray(ArrayFloatNear(expected_output, 1e-5)));
 | 
			
		||||
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(KmeansEmbeddingLookupTestWithSampleInputs, IgnoresIndicesAfterZero) {
 | 
			
		||||
  KmeansEmbeddingLookupModel m(/*input_shape=*/{1, 4},
 | 
			
		||||
                               /*encoding_table_shape=*/{4, 2},
 | 
			
		||||
                               /*codebook_shape=*/{4, 3},
 | 
			
		||||
                               /*output_shape=*/{1, 6});
 | 
			
		||||
  std::vector<int> input = std::vector<int>({2, 2, 0, 3});
 | 
			
		||||
  std::vector<std::vector<uint8_t>> encoding_table =
 | 
			
		||||
      std::vector<std::vector<uint8_t>>({{0, 0}, {1, 2}, {3, 0}, {2, 3}});
 | 
			
		||||
  std::vector<std::vector<float>> codebook = std::vector<std::vector<float>>(
 | 
			
		||||
      {{5.0, 2.0, 3.0}, {8.0, 2.0, 4.0}, {1.2, 2.4, 3.6}, {0.5, -2.0, 1.0}});
 | 
			
		||||
  m.Invoke(input, FlattenVector2D<uint8_t>(encoding_table),
 | 
			
		||||
           FlattenVector2D<float>(codebook));
 | 
			
		||||
  std::vector<float> expected_output =
 | 
			
		||||
      std::vector<float>({0.5, -2.0, 1.0, 5.0, 2.0, 3.0});
 | 
			
		||||
  EXPECT_THAT(m.GetOutput<float>(),
 | 
			
		||||
              ElementsAreArray(ArrayFloatNear(expected_output, 1e-5)));
 | 
			
		||||
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 6}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(KmeansEmbeddingLookupTest, ThrowsErrorWhenGivenInvalidInputBatchSize) {
 | 
			
		||||
  // Check that the op errors out when the batch size is greater than 1.
 | 
			
		||||
  KmeansEmbeddingLookupModel m(/*input_shape=*/{2, 1},
 | 
			
		||||
                               /*encoding_table_shape=*/{1, 1},
 | 
			
		||||
                               /*codebook shape=*/{1, 2},
 | 
			
		||||
                               /*output_shape=*/{2, 2});
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(m.InvokeUnchecked(/*input=*/{1, 1},
 | 
			
		||||
                              /*encoding_table=*/{0},
 | 
			
		||||
                              /*codebook=*/{2.3, 4.5}),
 | 
			
		||||
            kTfLiteError);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace tflite::ops::custom
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user