Add class weights to core hyperparameters and classifier library.
PiperOrigin-RevId: 550962843
This commit is contained in:
parent
62538a9496
commit
85c3fed70a
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Optional
|
from typing import Mapping, Optional
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
@ -36,6 +36,8 @@ class BaseHParams:
|
||||||
steps_per_epoch: An optional integer indicate the number of training steps
|
steps_per_epoch: An optional integer indicate the number of training steps
|
||||||
per epoch. If not set, the training pipeline calculates the default steps
|
per epoch. If not set, the training pipeline calculates the default steps
|
||||||
per epoch as the training dataset size divided by batch size.
|
per epoch as the training dataset size divided by batch size.
|
||||||
|
class_weights: An optional mapping of indices to weights for weighting the
|
||||||
|
loss function during training.
|
||||||
shuffle: True if the dataset is shuffled before training.
|
shuffle: True if the dataset is shuffled before training.
|
||||||
export_dir: The location of the model checkpoint files.
|
export_dir: The location of the model checkpoint files.
|
||||||
distribution_strategy: A string specifying which Distribution Strategy to
|
distribution_strategy: A string specifying which Distribution Strategy to
|
||||||
|
@ -57,6 +59,7 @@ class BaseHParams:
|
||||||
batch_size: int
|
batch_size: int
|
||||||
epochs: int
|
epochs: int
|
||||||
steps_per_epoch: Optional[int] = None
|
steps_per_epoch: Optional[int] = None
|
||||||
|
class_weights: Optional[Mapping[int, float]] = None
|
||||||
|
|
||||||
# Dataset-related parameters
|
# Dataset-related parameters
|
||||||
shuffle: bool = False
|
shuffle: bool = False
|
||||||
|
|
|
@ -110,7 +110,9 @@ class Classifier(custom_model.CustomModel):
|
||||||
# dataset is exhausted even if there are epochs remaining.
|
# dataset is exhausted even if there are epochs remaining.
|
||||||
steps_per_epoch=None,
|
steps_per_epoch=None,
|
||||||
validation_data=validation_dataset,
|
validation_data=validation_dataset,
|
||||||
callbacks=self._callbacks)
|
callbacks=self._callbacks,
|
||||||
|
class_weight=self._hparams.class_weights,
|
||||||
|
)
|
||||||
|
|
||||||
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
|
||||||
"""Evaluates the classifier with the provided evaluation dataset.
|
"""Evaluates the classifier with the provided evaluation dataset.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user