Add class weights to core hyperparameters and classifier library.

PiperOrigin-RevId: 550962843
Author: MediaPipe Team, 2023-07-25 12:27:03 -07:00 (committed by Copybara-Service)
Parent: 62538a9496
Commit: 85c3fed70a
2 changed files with 7 additions and 2 deletions


@@ -15,7 +15,7 @@
 import dataclasses
 import tempfile
-from typing import Optional
+from typing import Mapping, Optional
 import tensorflow as tf
@@ -36,6 +36,8 @@ class BaseHParams:
     steps_per_epoch: An optional integer indicating the number of training steps
       per epoch. If not set, the training pipeline calculates the default steps
       per epoch as the training dataset size divided by batch size.
+    class_weights: An optional mapping of indices to weights for weighting the
+      loss function during training.
     shuffle: True if the dataset is shuffled before training.
     export_dir: The location of the model checkpoint files.
     distribution_strategy: A string specifying which Distribution Strategy to
@@ -57,6 +59,7 @@ class BaseHParams:
   batch_size: int
   epochs: int
   steps_per_epoch: Optional[int] = None
+  class_weights: Optional[Mapping[int, float]] = None
   # Dataset-related parameters
   shuffle: bool = False
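
For illustration, the new field can be set directly when constructing the dataclass. A minimal sketch, assuming the module is importable as `hyperparameters` and that `learning_rate`, `batch_size`, and `epochs` are the required fields (not all of them appear in this diff):

    from mediapipe.model_maker.python.core import hyperparameters

    # Hypothetical example: up-weight the under-represented class index 2.
    hparams = hyperparameters.BaseHParams(
        learning_rate=1e-3,
        batch_size=32,
        epochs=10,
        class_weights={0: 1.0, 1: 1.0, 2: 3.0},
    )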


@@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel):
         # dataset is exhausted even if there are epochs remaining.
         steps_per_epoch=None,
         validation_data=validation_dataset,
-        callbacks=self._callbacks)
+        callbacks=self._callbacks,
+        class_weight=self._hparams.class_weights,
+    )
 
   def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
     """Evaluates the classifier with the provided evaluation dataset.