Source code for stringalign.statistics

import warnings
from collections import Counter, defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from numbers import Number
from typing import Self, cast

import numpy as np

import stringalign
from stringalign.align import AlignmentOperation, Kept, Replaced, align_strings
from stringalign.tokenize import Tokenizer

# Names exported by ``from stringalign.statistics import *``.
__all__ = ["CombinedAlignmentWarning", "StringConfusionMatrix"]


def sort_by_values(d: dict[str, float], reverse: bool = False) -> dict[str, float]:
    """Return a copy of ``d`` whose insertion order follows its values.

    Parameters
    ----------
    d
        Mapping from keys to numeric values. Values may be NaN.
    reverse
        If ``True``, the comparable values are ordered from largest to smallest.

    Returns
    -------
    dict[str, float]
        A new dictionary. Entries with comparable values come first, sorted by value (ties keep their
        original relative order). Entries whose value is NaN are placed last, in their original relative
        order, regardless of ``reverse``.
    """
    comparable: list[tuple[str, float]] = []
    undefined: list[tuple[str, float]] = []
    for item in d.items():
        # NaN is the only value that is not equal to itself. It must be kept out of the sort, because
        # every comparison with NaN is False, which breaks the ordering ``sort`` relies on and can leave
        # even the non-NaN entries unsorted.
        if item[1] != item[1]:
            undefined.append(item)
        else:
            comparable.append(item)

    comparable.sort(key=lambda item: item[1], reverse=reverse)
    return dict(comparable + undefined)


def _is_combined_alignment(alignment: Iterable[AlignmentOperation], tokenizer: Tokenizer) -> bool:
    """Return ``True`` if some operation's text is not exactly its own first token.

    For a :class:`Kept` operation the substring is tokenized and its first token must equal the whole
    substring; a substring that yields no tokens at all is reported as combined. Every other operation is
    generalized first, and both its ``reference`` and ``predicted`` texts are checked the same way, except
    that a text yielding no tokens is accepted when that text is the empty string.

    Note that ``alignment`` is iterated by this function, so a one-shot iterator is consumed.
    """

    def first_token(text, fallback):
        # Only the first token is needed, so the tokenizer output is never fully materialised.
        return next(iter(tokenizer(text)), fallback)

    for operation in alignment:
        if isinstance(operation, Kept):
            if first_token(operation.substring, None) != operation.substring:
                return True
            continue

        generalized = operation.generalize()
        if first_token(generalized.reference, "") != generalized.reference:
            return True
        if first_token(generalized.predicted, "") != generalized.predicted:
            return True

    return False


def _compute_f1_from_tpr_and_ppv(tpr: float, ppv: float) -> float:
    # If either tpr or ppv is 0, then the F1-score is zero.
    # However, the ppv or tpr can be NAN if any of the computations would involve dividing by zero.
    # Therefore, we need to explicitly return zero in this case to avoid return NAN (as 0 * NAN / (0 * NAN) = NAN).
    #
    # This can, e.g. happen if the model never predicts a given token. The positive predicted value, or precision
    # is defined as TP / (TP + FP), however, if the model e.g. never predicted ``ø``, then TP=0 AND FP=0, which results
    # in PPV = 0/0 = NAN
    # However, the F1 score for ``ø`` should be 0.
    if tpr == 0 or ppv == 0:
        return 0.0
    return (tpr * ppv) / (0.5 * (tpr + ppv))


class CombinedAlignmentWarning(UserWarning):
    """Warning category used when alignments with potentially combined operations reach the confusion matrix."""
@dataclass(eq=True)
class StringConfusionMatrix:
    """A confusion-matrix like object that counts edit operations for aligned strings.

    The string confusion matrix counts the number of true positives, false positives and false negatives
    for two aligned strings. However, the number of true negatives does not make sense in the context of
    string alignment, as it would correspond to the number of times a token occurs in neither strings.

    We use the following definitions of true positives, false positives and false negatives:

    * **True positives:** The number of times a token occurs in the "same place" in both strings, i.e. the
      number of :class:`stringalign.align.Kept` operations with the given token as the substring.
    * **False positives:** The number of times a token occurs in the predicted string but not the
      reference, i.e. number of :class:`stringalign.align.Inserted` operations with the given token as the
      substring plus the number of :class:`stringalign.align.Replaced` operations with the given token as
      the predicted token.
    * **False negatives:** The number of times a token occurs in the reference string but not the
      predicted, i.e. number of :class:`stringalign.align.Deleted` operations with the given token as the
      substring plus the number of :class:`stringalign.align.Replaced` operations with the given token as
      the reference token.
    * **Edit count:** The number of edit operations (:class:`stringalign.align.Inserted`,
      :class:`stringalign.align.Deleted` or :class:`stringalign.align.Replaced`)

    In general, you should not initialize this class with the default constructor, but rather use some of
    the utility constructors:

    * :meth:`from_strings_and_alignment`
    * :meth:`from_strings`
    * :meth:`from_string_collections`
    * :meth:`get_empty`
    """

    true_positives: Counter[str]
    false_positives: Counter[str]  # Added tokens
    false_negatives: Counter[str]  # Removed/missed tokens
    edit_counts: Counter[AlignmentOperation]  # Count of each operation type
    # There is no true negatives when we compare strings.
    # Either, a character is in the string or it is not.

    @staticmethod
    def _materialize_tokens(aggregate_over: Iterable[str] | None) -> list[str] | None:
        """Turn ``aggregate_over`` into a list so it can be truth-tested and iterated more than once.

        ``None`` is passed through unchanged. Without this step a one-shot iterator (e.g. a generator)
        would always be truthy and would be exhausted after the first sum over it.
        """
        if aggregate_over is None:
            return None
        return list(aggregate_over)

    @classmethod
    def from_strings_and_alignment(
        cls,
        reference: str,
        predicted: str,
        alignment: Iterable[AlignmentOperation],
        tokenizer: Tokenizer | None = None,
    ) -> Self:
        """Create confusion matrix based on a reference string, a predicted string and their alignment.

        .. important::

            The string metrics are not well defined if we include combined alignments. This is because the
            true positive count etc. is not well defined for multi-token strings. For example, how many
            true positives is it for the ``'ll'`` substring in the string ``'llllll'``. The answer is most
            likely either 3 or 5 depending on whether we count overlapping substrings.

        Parameters
        ----------
        reference
            The reference string, also known as gold standard or ground truth.
        predicted
            The string to align with the reference.
        alignment
            An optimal alignment for these strings. Any iterable is accepted, including one-shot
            iterators; it is materialised once internally.
        tokenizer : optional
            A tokenizer that turns a string into an iterable of tokens. For this function, it is sufficient
            that it is a callable that turns a string into an iterable of tokens. If not provided, then
            ``stringalign.tokenize.DEFAULT_TOKENIZER`` is used instead, which by default is a grapheme
            cluster (character) tokenizer.

        Returns
        -------
        confusion_matrix : StringConfusionMatrix
            The confusion matrix.

        Warns
        -----
        CombinedAlignmentWarning
            If some operation in ``alignment`` does not consist of a single token.
        """
        if tokenizer is None:
            tokenizer = stringalign.tokenize.DEFAULT_TOKENIZER

        # The alignment is walked twice below (combined-operation check, then counting), so a one-shot
        # iterator has to be materialised first; otherwise the counting loop would see nothing.
        alignment = tuple(alignment)

        ref_iter = iter(tokenizer(reference))
        pred_iter = iter(tokenizer(predicted))

        true_positives: Counter[str] = Counter()
        false_positives: Counter[str] = Counter()
        false_negatives: Counter[str] = Counter()
        edit_counts: Counter[AlignmentOperation] = Counter()

        if _is_combined_alignment(alignment, tokenizer):
            warnings.warn(
                "The substrings of the alignment operation do not contain single tokens. This indicates that the"
                " alignments have either been combined, which means that the string confusion matrix is ill defined,"
                " and some metrics might be confusing or wrong. Alternatively, the your tokenizer might not provide"
                " atomic tokens, (i.e. the tokens can be tokenized again:"
                " `tokenize(tokenize(s)[0])[0] != tokenize(s)[0]`). If that is the case, then you may ignore this"
                " warning.",
                CombinedAlignmentWarning,
                stacklevel=2,
            )

        for op in alignment:
            if isinstance(op, Kept):
                # One step per token of the kept substring. The counted token is the one read from the
                # tokenized reference, and the predicted iterator is advanced in lock-step.
                for _ in tokenizer(op.substring):
                    true_positives[next(ref_iter)] += 1
                    next(pred_iter)
                continue

            edit_counts[op] += 1
            op = cast(Replaced, op.generalize())
            for char in tokenizer(op.predicted):
                false_positives[char] += 1
                next(pred_iter)
            for char in tokenizer(op.reference):
                false_negatives[char] += 1
                next(ref_iter)

        return cls(
            true_positives=true_positives,
            false_positives=false_positives,
            false_negatives=false_negatives,
            edit_counts=edit_counts,
        )

    @classmethod
    def from_strings(
        cls,
        reference: str,
        predicted: str,
        tokenizer: Tokenizer | None = None,
        randomize_alignment: bool = False,
        random_state: np.random.Generator | int | None = None,
    ) -> Self:
        """Create confusion matrix based on a reference string and a predicted string.

        .. note::

            This method will first align the strings and then create the confusion matrix. If you already
            have computed the alignment, you can use
            :meth:`StringConfusionMatrix.from_strings_and_alignment` instead.

        Parameters
        ----------
        reference
            The reference string, also known as gold standard or ground truth.
        predicted
            The string to align with the reference.
        tokenizer : optional
            A tokenizer that turns a string into an iterable of tokens. For this function, it is sufficient
            that it is a callable that turns a string into an iterable of tokens. If not provided, then
            ``stringalign.tokenize.DEFAULT_TOKENIZER`` is used instead, which by default is a grapheme
            cluster (character) tokenizer.
        randomize_alignment
            If ``True``, then a random optimal alignment is chosen (slightly slower if enabled)
        random_state
            The NumPy RNG or a seed to create a NumPy RNG used for picking the optimal alignment. If
            ``None``, then the default RNG will be used instead.

        Returns
        -------
        confusion_matrix : StringConfusionMatrix
            The confusion matrix.
        """
        if tokenizer is None:
            tokenizer = stringalign.tokenize.DEFAULT_TOKENIZER

        # ``align_strings`` returns a sequence whose first element is the alignment itself.
        alignment = align_strings(
            reference,
            predicted,
            tokenizer=tokenizer,
            randomize_alignment=randomize_alignment,
            random_state=random_state,
        )[0]
        return cls.from_strings_and_alignment(reference, predicted, alignment, tokenizer=tokenizer)

    @classmethod
    def from_string_collections(
        cls,
        references: Iterable[str],
        predictions: Iterable[str],
        tokenizer: Tokenizer | None = None,
    ) -> Self:
        """Create confusion matrix for many strings, summing statistics across pairs of references and predictions.

        Parameters
        ----------
        references
            Iterable containing the reference strings.
        predictions
            Iterable containing the strings to align with the references.
        tokenizer : optional
            A tokenizer that turns a string into an iterable of tokens. For this function, it is sufficient
            that it is a callable that turns a string into an iterable of tokens. If not provided, then
            ``stringalign.tokenize.DEFAULT_TOKENIZER`` is used instead, which by default is a grapheme
            cluster (character) tokenizer.

        Returns
        -------
        confusion_matrix : StringConfusionMatrix
            The confusion matrix.

        Raises
        ------
        ValueError
            If ``references`` and ``predictions`` have different lengths (``zip(..., strict=True)``).
        """
        if tokenizer is None:
            tokenizer = stringalign.tokenize.DEFAULT_TOKENIZER

        confusion_matrices = (
            cls.from_strings(reference, predicted, tokenizer=tokenizer)
            for reference, predicted in zip(references, predictions, strict=True)
        )
        return sum(confusion_matrices, start=cls.get_empty())

    @classmethod
    def get_empty(cls) -> Self:
        """Make an empty confusion matrix (equivalent to that of two empty strings).

        This can be used as a starting point for summing multiple confusion matrices when computing
        micro-averaged metrics over multiple string pairs.

        Returns
        -------
        confusion_matrix : StringConfusionMatrix
            An empty confusion matrix.
        """
        return cls(
            true_positives=Counter(),
            false_positives=Counter(),
            false_negatives=Counter(),
            edit_counts=Counter(),
        )

    def compute_true_positive_rate(self, aggregate_over: Iterable[str] | None = None) -> dict[str, float] | float:
        """Compute the true positive rate, also known as sensitivity or recall.

        The true positive rate is given by the number of true positives divided by the total number of
        positives.

        Parameters
        ----------
        aggregate_over : optional
            If provided, this function returns only a single number, which is the true positive rate for
            the tokens in the `aggregate_over` iterable. This is useful e.g. if you want to compute the
            true positive rate for a set of special characters. An empty iterable is treated like ``None``.

        Returns
        -------
        true_positive_rate : dict[str, float] | float
            Either a dictionary that maps tokens to their true positive rate, or, if `aggregate_over` is
            provided, a single float that represent the true positive rate aggregated for the specified
            tokens. The rate is NaN where the denominator would be zero.

        Examples
        --------
        If we compute the true positive rate without aggregating over tokens, we get a dict of true
        positive rates

        >>> cm = StringConfusionMatrix.from_strings("ostehøvel", "ostehovl")
        >>> expected_tp = {'o': 1.0, 's': 1.0, 't': 1.0, 'h': 1.0, 'v': 1.0, 'l': 1.0, 'e': 0.5, 'ø': 0.0}
        >>> tp = cm.compute_true_positive_rate()
        >>> expected_tp == tp
        True

        If we specify an iterable of tokens to aggregate over, we get the total true positive rate for
        those tokens. In this case, we aggregate over ``["æ", "ø", "å"]``, and the prediction did not find
        any of those tokens, so the true positive rate is zero.

        >>> cm.compute_true_positive_rate(aggregate_over=["æ", "ø", "å"])
        0.0

        The aggregated statistics is micro averaged, so the function counts the number of true positives
        and false negatives for all tokens, sums them and then computes the true positive rate.

        >>> cm = StringConfusionMatrix.from_strings("blåbær- og bringebærsyltetøy", "blabaer- og bringebærsyltetoy")
        >>> cm.compute_true_positive_rate(aggregate_over=["æ", "ø", "å"])
        0.25
        """
        tokens = self._materialize_tokens(aggregate_over)
        if tokens:
            tp = sum(self.true_positives[c] for c in tokens)
            fn = sum(self.false_negatives[c] for c in tokens)
            if tp + fn == 0:
                return float("nan")
            return tp / (tp + fn)

        char_count = self.true_positives + self.false_negatives
        # Tokens that only occur as false positives have no reference count; they get a NaN denominator.
        all_tokens = set(char_count) | set(self.false_positives)
        return sort_by_values(
            {key: self.true_positives[key] / char_count.get(key, float("nan")) for key in all_tokens},
            reverse=True,
        )

    compute_recall = compute_true_positive_rate
    compute_sensitivity = compute_true_positive_rate

    def compute_positive_predictive_value(
        self, aggregate_over: Iterable[str] | None = None
    ) -> dict[str, float] | float:
        """Compute the positive predicted value, also known as precision.

        The positive predicted value is given by the number of true positives divided by the total number
        of predicted positives.

        Parameters
        ----------
        aggregate_over : optional
            If provided, this function returns only a single number, which is the positive predicted value
            for the tokens in the `aggregate_over` iterable. This is useful e.g. if you want to compute the
            positive predicted value for a set of special characters. An empty iterable is treated like
            ``None``. See :meth:`StringConfusionMatrix.compute_true_positive_rate` for examples of how this
            argument works.

        Returns
        -------
        positive_predictive_value : dict[str, float] | float
            Either a dictionary that maps tokens to their positive predicted value, or, if `aggregate_over`
            is provided, a single float that represent the positive predicted value aggregated for the
            specified tokens. The value is NaN where the denominator would be zero.
        """
        tokens = self._materialize_tokens(aggregate_over)
        if tokens:
            tp = sum(self.true_positives[c] for c in tokens)
            fp = sum(self.false_positives[c] for c in tokens)
            if tp + fp == 0:
                return float("nan")
            return tp / (tp + fp)

        predicted_positive = self.true_positives + self.false_positives
        # Tokens that only occur as false negatives were never predicted; they get a NaN denominator.
        all_tokens = set(predicted_positive) | set(self.false_negatives)
        return sort_by_values(
            {key: self.true_positives[key] / predicted_positive.get(key, float("nan")) for key in all_tokens},
            reverse=True,
        )

    compute_precision = compute_positive_predictive_value

    def compute_false_discovery_rate(self, aggregate_over: Iterable[str] | None = None) -> dict[str, float] | float:
        """Compute the false discovery rate.

        The false discovery rate is given by the number of false positives divided by the total number of
        predicted positives.

        Parameters
        ----------
        aggregate_over : optional
            If provided, this function returns only a single number, which is the false discovery rate for
            the tokens in the `aggregate_over` iterable. This is useful e.g. if you want to compute the
            false discovery rate for a set of special characters. An empty iterable is treated like
            ``None``. See :meth:`StringConfusionMatrix.compute_true_positive_rate` for examples of how this
            argument works.

        Returns
        -------
        false_discovery_rate : dict[str, float] | float
            Either a dictionary that maps tokens to their false discovery rate, or, if `aggregate_over` is
            provided, a single float that represent the false discovery rate aggregated for the specified
            tokens. The rate is NaN where the denominator would be zero.
        """
        tokens = self._materialize_tokens(aggregate_over)
        if tokens:
            tp = sum(self.true_positives[c] for c in tokens)
            fp = sum(self.false_positives[c] for c in tokens)
            if tp + fp == 0:
                return float("nan")
            return fp / (tp + fp)

        predicted_positive = self.true_positives + self.false_positives
        all_tokens = set(predicted_positive) | set(self.false_negatives)
        return sort_by_values(
            {key: self.false_positives[key] / predicted_positive.get(key, float("nan")) for key in all_tokens},
            reverse=True,
        )

    def compute_f1_score(self, aggregate_over: Iterable[str] | None = None) -> dict[str, float] | float:
        """Compute the F1 score, also known as the Dice score.

        The F1 score is given by the harmonic mean of the true positive rate and positive predictive value.
        Alternatively, you can interpret it as the number of true positives divided by the average number
        of predicted positives and the number of positives in the reference.

        Parameters
        ----------
        aggregate_over : optional
            If provided, this function returns only a single number, which is the f1 score aggregated for
            the tokens in the `aggregate_over` iterable. This is useful e.g. if you want to compute the
            f1 score for a set of special characters. An empty iterable is treated like ``None``. See
            :meth:`StringConfusionMatrix.compute_true_positive_rate` for examples of how this argument
            works.

        Returns
        -------
        f1_score : dict[str, float] | float
            Either a dictionary that maps tokens to their f1 score, or, if `aggregate_over` is provided, a
            single float that represent the f1 score aggregated for the specified tokens.
        """
        # Materialise once: the same tokens are handed to two methods and truth-tested here, which a
        # one-shot iterator would not survive.
        tokens = self._materialize_tokens(aggregate_over)
        tpr = self.compute_true_positive_rate(aggregate_over=tokens)
        ppv = self.compute_positive_predictive_value(aggregate_over=tokens)

        if tokens:
            assert isinstance(tpr, Number) and isinstance(ppv, Number)
            return _compute_f1_from_tpr_and_ppv(tpr, ppv)

        assert isinstance(tpr, dict) and isinstance(ppv, dict)
        all_chars = set(self.true_positives) | set(self.false_positives) | set(self.false_negatives)
        tpr, ppv = defaultdict(int, tpr), defaultdict(int, ppv)
        return sort_by_values(
            {c: _compute_f1_from_tpr_and_ppv(tpr[c], ppv[c]) for c in all_chars},
            reverse=True,
        )

    compute_dice = compute_f1_score

    def compute_token_error_rate(self) -> float:
        """Compute the token error rate (a generalisation of CER and WER).

        The token error rate is the number of token edits divided by the total number of tokens in the
        reference. If the tokenizer tokenizes the string into characters, this is equivalent to the
        character error rate (CER).

        Returns
        -------
        token_error_rate : float
            The token error rate. It is ``0.0`` when there are neither edits nor reference tokens, and
            ``inf`` when there are edits but no reference tokens.
        """
        total_tokens = sum(self.true_positives.values()) + sum(self.false_negatives.values())
        total_edit_counts = sum(self.edit_counts.values())

        if total_edit_counts == 0 and total_tokens == 0:
            return 0.0
        elif total_tokens == 0:
            return float("inf")
        return total_edit_counts / total_tokens

    def __add__(self, other: Self) -> Self:
        """Create a new confusion matrix, adding the true positives, false positives, false negatives, and edit counts.

        We use this to compute micro-averaged metrics over multiple string pairs.

        Parameters
        ----------
        other
            The other confusion matrix.

        Returns
        -------
        confusion_matrix : StringConfusionMatrix
            A new confusion matrix where the number of true positive, false positive, false negatives and
            the edit counts are given by the sum of the corresponding attributes of the two added matrices.
            ``NotImplemented`` is returned if ``other`` is not an instance of this class.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented

        return self.__class__(
            true_positives=self.true_positives + other.true_positives,
            false_positives=self.false_positives + other.false_positives,
            false_negatives=self.false_negatives + other.false_negatives,
            edit_counts=self.edit_counts + other.edit_counts,
        )

    __radd__ = __add__