from aisquared.base import BaseObject
class BinaryClassification(BaseObject):
    """
    Postprocessing configuration object for binary classification

    Example usage

    >>> import aisquared
    >>> my_obj = aisquared.config.postprocessing.BinaryClassification(
    ...     ['class1', 'class2']
    ... )
    >>> my_obj.to_dict()
    {'className': 'BinaryClassification',
    'params': {'labelMap': ['class1', 'class2'], 'threshold': 0.5}}
    """

    def __init__(
            self,
            label_map: list,
            threshold: float = 0.5
    ):
        """
        Parameters
        ----------
        label_map : list
            List of two values to be mapped to the model outputs
        threshold : float (default 0.5)
            The threshold for the second value of the label map to be the
            one chosen

        Raises
        ------
        TypeError
            If `label_map` is not a list or `threshold` is not a float
        ValueError
            If `label_map` does not hold exactly two values or `threshold`
            is not within the closed interval [0, 1]
        """
        super().__init__()
        # Both assignments go through the validating property setters below
        self.label_map = label_map
        self.threshold = threshold

    @property
    def label_map(self) -> list:
        """
        The two-element list of labels
        """
        return self._label_map

    @label_map.setter
    def label_map(self, value: list) -> None:
        if not isinstance(value, list):
            raise TypeError('label_map must be a list')
        if len(value) != 2:
            raise ValueError(
                'label_map must have exactly two values for binary classification')
        self._label_map = value

    @property
    def threshold(self) -> float:
        """
        The decision threshold, a float within [0, 1]
        """
        return self._threshold

    @threshold.setter
    def threshold(self, value: float) -> None:
        if not isinstance(value, float):
            raise TypeError('threshold must be float-valued')
        # Expressed as a positive range check on purpose: every ordering
        # comparison against NaN is False, so the former
        # `value < 0 or value > 1` test let NaN through unnoticed.
        if not 0 <= value <= 1:
            raise ValueError('threshold value must be between 0 and 1')
        self._threshold = value

    def to_dict(self) -> dict:
        """
        Get the configuration object as a dictionary

        Returns
        -------
        dict
            Dictionary with the keys 'className' and 'params', where
            'params' holds the 'labelMap' and 'threshold' values
        """
        return {
            'className': 'BinaryClassification',
            'params': {
                'labelMap': self.label_map,
                'threshold': self.threshold
            }
        }