Add class weights to core hyperparameters and classifier library.

PiperOrigin-RevId: 550962843
This commit is contained in:
MediaPipe Team 2023-07-25 12:27:03 -07:00 committed by Copybara-Service
parent 62538a9496
commit 85c3fed70a
2 changed files with 7 additions and 2 deletions

View File

@ -15,7 +15,7 @@
import dataclasses import dataclasses
import tempfile import tempfile
from typing import Optional from typing import Mapping, Optional
import tensorflow as tf import tensorflow as tf
@ -36,6 +36,8 @@ class BaseHParams:
steps_per_epoch: An optional integer indicate the number of training steps steps_per_epoch: An optional integer indicate the number of training steps
per epoch. If not set, the training pipeline calculates the default steps per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size. per epoch as the training dataset size divided by batch size.
class_weights: An optional mapping of indices to weights for weighting the
loss function during training.
shuffle: True if the dataset is shuffled before training. shuffle: True if the dataset is shuffled before training.
export_dir: The location of the model checkpoint files. export_dir: The location of the model checkpoint files.
distribution_strategy: A string specifying which Distribution Strategy to distribution_strategy: A string specifying which Distribution Strategy to
@ -57,6 +59,7 @@ class BaseHParams:
batch_size: int batch_size: int
epochs: int epochs: int
steps_per_epoch: Optional[int] = None steps_per_epoch: Optional[int] = None
class_weights: Optional[Mapping[int, float]] = None
# Dataset-related parameters # Dataset-related parameters
shuffle: bool = False shuffle: bool = False

View File

@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel):
# dataset is exhausted even if there are epochs remaining. # dataset is exhausted even if there are epochs remaining.
steps_per_epoch=None, steps_per_epoch=None,
validation_data=validation_dataset, validation_data=validation_dataset,
callbacks=self._callbacks) callbacks=self._callbacks,
class_weight=self._hparams.class_weights,
)
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
"""Evaluates the classifier with the provided evaluation dataset. """Evaluates the classifier with the provided evaluation dataset.