Source code for stringalign.evaluate

import enum
import string
from collections import Counter, defaultdict, deque
from collections.abc import Generator, Hashable, Iterator, Mapping
from copy import deepcopy
from dataclasses import dataclass
from functools import cached_property
from inspect import cleandoc
from itertools import chain
from typing import Any, Iterable, Literal, Self, TypeVar

import numpy as np

import stringalign
from stringalign.align import (
    AlignmentOperation,
    AlignmentTuple,
    Kept,
    Replaced,
    align_strings,
    combine_alignment_ops,
)
from stringalign.error_classification.case_error import count_case_errors
from stringalign.error_classification.confusable_error import count_confusable_errors
from stringalign.error_classification.diacritic_error import count_diacritic_errors
from stringalign.error_classification.duplication_error import check_ngram_duplication_errors
from stringalign.normalize import StringNormalizer
from stringalign.statistics import StringConfusionMatrix
from stringalign.tokenize import Tokenizer
from stringalign.utils import _indent
from stringalign.visualize import HtmlString

T = TypeVar("T")


[docs] def join_windows(center_string: str, previous_operation: Kept | None, next_operation: Kept | None) -> str: """Join the text in from the center alignment operation with the previous and next operation (if possible). Paramters --------- center_string The text in the center string, must come from an edit operation. previous_operation The previous alignment operation. Since the error classification algorithms use combined alignments, this is guaranteed to be a :class:`stringalign.align.Kept`-operation or ``None``. next_operation The next alignment operation. Since the error classification algorithms use combined alignments, this is guaranteed to be a :class:`stringalign.align.Kept`-operation or ``None``. """ window_text = "" if previous_operation is not None: window_text += previous_operation.substring window_text += center_string if next_operation is not None: window_text += next_operation.substring return window_text
[docs] def check_operation_for_case_error( previous_operation: AlignmentOperation | None, current_operation: AlignmentOperation, next_operation: AlignmentOperation | None, ) -> int: """Check if this alignment operation is an edit due to mistaken casing. This function resolves resolves case errors by casefolding. This means that certain characters are changed even if they are lowercase already (like ``'ß'`` being changed into ``'ss'``). .. note:: Error classification should be performed on combined alignment operations so edit operations and kept operations alternate. Parameters ---------- previous_operation The previous alignment operation. If ``current_operation`` is the first alignment operation in an alignment, this is ``None``. current_operation The alignment operation to check for case errors. next_operation The next alignment operation. If ``current_operation`` is the last alignment operation in an alignment, this is ``None``. tokenizer The tokenizer used for the original alignment. Returns ------- int: The number of edits that are due to mistaken diacritics. See also -------- :func:`stringalign.error_classification.case_error.count_case_errors` """ if not isinstance(current_operation, Replaced): return False current_operation = current_operation.generalize() assert isinstance(current_operation, Replaced) return count_case_errors(current_operation.reference, current_operation.predicted)
[docs] def check_operation_for_diacritic_error( previous_operation: AlignmentOperation | None, current_operation: AlignmentOperation, next_operation: AlignmentOperation | None, ) -> int: """Check if this alignment operation is an edit due to mistaken diacritics. This function resolves confusables with the ``"confusables"``-list as well (otherwise it would not be possible to remove the diacritics). .. note:: Error classification should be performed on combined alignment operations so edit operations and kept operations alternate. Parameters ---------- previous_operation The previous alignment operation. If ``current_operation`` is the first alignment operation in an alignment, this is ``None`` current_operation The alignment operation to check for diacritic errors. next_operation The next alignment operation. If ``current_operation`` is the last alignment operation in an alignment, this is ``None`` tokenizer The tokenizer used for the original alignment. Returns ------- int: The number of edits that are due to mistaken diacritics. See also -------- :func:`stringalign.error_classification.diacritic_error.count_diacritic_errors` """ current_operation = current_operation.generalize() if isinstance(current_operation, Kept): return False return count_diacritic_errors(current_operation.reference, current_operation.predicted)
[docs] def check_operation_for_confusable_error( previous_operation: AlignmentOperation | None, current_operation: AlignmentOperation, next_operation: AlignmentOperation | None, *, tokenizer: Tokenizer, ) -> int: """Check if this alignment operation is an edit due to confusable characters. This function uses the ``"confusables"``-list. If you want to check with a different set of confusables, then you should use :func:`stringalign.error_classification.confusable_error.count_confusable_errors` directly. .. note:: Error classification should be performed on combined alignment operations so edit operations and kept operations alternate. Parameters ---------- previous_operation The previous alignment operation. If ``current_operation`` is the first alignment operation in an alignment, this is ``None`` current_operation The alignment operation to check for confusable errors. next_operation The next alignment operation. If ``current_operation`` is the last alignment operation in an alignment, this is ``None`` tokenizer The tokenizer used for the original alignment. Returns ------- int: The number of edits that are due to confusable characters. See also -------- :func:`stringalign.error_classification.confusable_error.count_confusable_errors` """ current_operation = current_operation.generalize() if isinstance(current_operation, Kept): return False return count_confusable_errors( current_operation.reference, current_operation.predicted, tokenizer=tokenizer, consider_confusables="confusables", )
[docs] def check_operation_for_horizontal_segmentation_error( previous_operation: AlignmentOperation | None, current_operation: AlignmentOperation, next_operation: AlignmentOperation | None, ) -> bool: """Check if the alignment error is likely due to a horisontal segmentation error. This is checked by seeing if the alignment operation is an edit at the start or end of the string. .. note:: Error classification should be performed on combined alignment operations so edit operations and kept operations alternate. Parameters ---------- previous_operation The previous alignment operation. If ``current_operation`` is the first alignment operation in an alignment, this is ``None`` current_operation The alignment operation to check for horisontal segmentation errors. next_operation The next alignment operation. If ``current_operation`` is the last alignment operation in an alignment, this is ``None`` Returns ------- bool: True if the alignment error is likely due to a horisontal segmentation error. Else false. """ is_boundary = (previous_operation is None) or (next_operation is None) return is_boundary and not isinstance(current_operation, Kept)
[docs] def check_operation_for_ngram_duplication_error( previous_operation: AlignmentOperation | None, current_operation: AlignmentOperation, next_operation: AlignmentOperation | None, *, n: int, error_type: Literal["inserted", "deleted"] = "inserted", tokenizer: Tokenizer, ) -> bool: """Check if this alignment operation is an n-gram duplication error or missing duplicate n-gram error. This function checks if the only reason for the alignment operation is due to an n-gram duplication or missing duplicate n-gram. .. note:: Error classification should be performed on combined alignment operations so edit operations and kept operations alternate. Parameters ---------- previous_operation The previous alignment operation. If ``current_operation`` is the first alignment operation in an alignment, this is ``None``. current_operation The alignment operation to check for n-gram duplication errors. next_operation The next alignment operation. If ``current_operation`` is the last alignment operation in an alignment, this is ``None``. n The number of tokens in the n-grams we evaluate. For single token duplication errors, this should be 1. error_type ``"inserted"`` if we are checking for inserted duplicates and ``"deleted"`` if we are checking for deleted duplicates. tokenizer The tokenizer used for the original alignment. Returns ------- bool: True if the only reason for the alignment operation is due to an n-gram duplication or missing duplicate n-gram. Else false. See also -------- :func:`stringalign.error_classification.duplication_error.check_ngram_duplication_errors` """ if isinstance(current_operation, Kept): return False current_operation = current_operation.generalize() assert isinstance(current_operation, Replaced) assert isinstance(next_operation, (Kept, type(None))) assert isinstance(previous_operation, (Kept, type(None))) window_text_reference = join_windows(current_operation.reference, previous_operation, next_operation) window_text_prediction = join_windows(current_operation.predicted, previous_operation, next_operation) return check_ngram_duplication_errors( window_text_reference, window_text_prediction, n=n, error_type=error_type, tokenizer=tokenizer )
def _safe_hash(value: Any) -> int: try: return hash(value) except TypeError: import pickle return hash(pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL))
[docs] class FrozenDict(Mapping[Hashable, Any]): """An immutable and hashable dictionary. Pickle is used to create hashes for non-hashable values. """ def __init__(self, data: Mapping[Hashable, Any] | None = None): if not data: data = {} self._data = deepcopy(data) self._hash: int | None = None def __getitem__(self, key: Hashable) -> Any: return self._data[key] def __iter__(self) -> Iterator[Hashable]: return iter(self._data) def __contains__(self, value: Any) -> bool: return value in self._data def __len__(self) -> int: return len(self._data) def __hash__(self) -> int: if self._hash is not None: return self._hash keys = tuple(self.keys()) values = tuple(_safe_hash(v) for v in self.values()) self._hash = hash((keys, values)) return self._hash def __repr__(self): return f"{type(self).__name__}({self._data!r})"
[docs] class ErrorType(enum.StrEnum): """Enum representing different edit types.""" HORISONTAL_SEGMENTATION_ERROR = enum.auto() TOKEN_DUPLICATION_ERROR = enum.auto() REMOVED_DUPLICATE_TOKEN_ERROR = enum.auto() DIACRITIC_ERROR = enum.auto() CONFUSABLE_ERROR = enum.auto() CASE_ERROR = enum.auto()
[docs] @dataclass(frozen=True, slots=False) class AlignmentAnalyzer: """Utility data class that represents the errors for a single sample (reference/predicted pair) Parameters ---------- reference The reference string, also known as gold standard and ground truth predicted The predicted string combined_alignment The combined alignment for the reference and predicted string raw_alignment The uncombined alignment for the reference and predicted string unique_alignment Boolean flag which is true if the alignment is unique. horisontal_segmentation_errors The alignment operations that likely are wrong due to segmentation errors. Corresponds to edits in the start or end of the string. token_duplication_errors Alignment operations that correspond to tokens that were repeated more times in the prediction than in the reference. removed_duplicate_token_errors Alignment operations that correspond to tokens that were repeated fewer times in the prediction than in the reference. diacritic_errors Alignment operations that correspond to diacritics being added or removed (e.g. ``"ë" -> "e"``). confusable_errors Alignment operations that correspond to confusable tplems being predicted. case_errors Alignment operations that correspond to case errors (i.e. errors that are resolved by casefolding the strings). metadata Optional metadata to include with the line error, useful if you e.g. want to include a text line ID. tokenizer The tokenizer used prior to alignment. Included for reproducibility purposes. """ reference: str predicted: str combined_alignment: AlignmentTuple raw_alignment: AlignmentTuple unique_alignment: bool heuristic_edit_classifications: FrozenDict metadata: FrozenDict | None tokenizer: Tokenizer
[docs] def summarise(self) -> dict[Hashable, Hashable]: """Convert this utility class to a dictionary, where the error classifications are converted to booleans. This is useful if we, for example, want to know the number of many samples with at least one suspected diacritic error. However, it removes information about what the errors might be. Returns ------- summary : dict[Hashable, Hashable] """ metadata = self.metadata if metadata is None: metadata = FrozenDict() return { "reference": self.reference, "predicted": self.predicted, "horisontal_segmentation_error": bool( self.heuristic_edit_classifications[ErrorType.HORISONTAL_SEGMENTATION_ERROR] ), "token_duplication_error": bool(self.heuristic_edit_classifications[ErrorType.TOKEN_DUPLICATION_ERROR]), "removed_duplicate_token_error": bool( self.heuristic_edit_classifications[ErrorType.REMOVED_DUPLICATE_TOKEN_ERROR] ), "diacritic_error": bool(self.heuristic_edit_classifications[ErrorType.DIACRITIC_ERROR]), "confusable_error": bool(self.heuristic_edit_classifications[ErrorType.CONFUSABLE_ERROR]), "case_error": bool(self.heuristic_edit_classifications[ErrorType.CASE_ERROR]), **metadata, }
[docs] @cached_property def confusion_matrix(self) -> StringConfusionMatrix: """The string confusion matrix for this string pair. Returns ------- string_confusion_matrix : StringConfusionMatrix """ return StringConfusionMatrix.from_strings_and_alignment( reference=self.reference, predicted=self.predicted, alignment=self.raw_alignment, tokenizer=self.tokenizer )
[docs] @classmethod def from_strings( cls, reference: str, predicted: str, tokenizer: Tokenizer | None, metadata: Mapping[Hashable, Hashable] | None = None, randomize_alignment: bool = False, random_state: np.random.Generator | int | None = None, ) -> Self: """ Create a AlignmentAnalyzer based on a reference string and a predicted string given a tokenizer. Parameters ---------- reference The reference string, also known as gold standard or ground truth. predicted The string to align with the reference. alignment An optimal alignment for these strings tokenizer : optional A tokenizer that turns a string into an iterable of tokens. For this function, it is sufficient that it is a callable that turns a string into an iterable of tokens. If not provided, then ``stringalign.tokenize.DEFAULT_TOKENIZER`` is used instead, which by default is a grapheme cluster (character) tokenizer. metadata Additional metadata about the sample, e.g. sample id. randomize_alignment If ``True``, then a random optimal alignment is chosen (slightly slower if enabled) random_state The NumPy RNG or a seed to create a NumPy RNG used for picking the optimal alignment. If ``None``, then the default RNG will be used instead. Returns ------- alignment_analyzer : AlignmentAnalyzer The AlignmentAnalyzer object. """ if tokenizer is None: tokenizer = stringalign.tokenize.DEFAULT_TOKENIZER raw_alignment, unique_alignment = align_strings( reference, predicted, tokenizer=tokenizer, randomize_alignment=randomize_alignment, random_state=random_state, ) combined_alignment = tuple(combine_alignment_ops(raw_alignment, tokenizer=tokenizer)) if metadata is not None: frozen_metadata = FrozenDict(metadata) else: frozen_metadata = None if not combined_alignment: return cls( reference=reference, predicted=predicted, combined_alignment=tuple(), raw_alignment=tuple(), unique_alignment=True, heuristic_edit_classifications=FrozenDict({et: tuple() for et in ErrorType}), metadata=frozen_metadata, tokenizer=tokenizer, ) alignment_iterator = iter(combined_alignment) window: deque[AlignmentOperation | None] = deque(maxlen=3) window.append(None) window.append(next(alignment_iterator)) horisontal_segmentation_errors = [] token_duplication_errors = [] removed_duplicate_token_errors = [] diacritic_errors = [] confusable_errors = [] case_errors = [] op: AlignmentOperation | None for op in chain(alignment_iterator, (None, None)): window.append(op) if window[1] is None: break if check_operation_for_horizontal_segmentation_error(window[0], window[1], window[2]): horisontal_segmentation_errors.append(window[1]) if check_operation_for_ngram_duplication_error( window[0], window[1], window[2], n=1, error_type="inserted", tokenizer=tokenizer ): token_duplication_errors.append(window[1]) if check_operation_for_ngram_duplication_error( window[0], window[1], window[2], n=1, error_type="deleted", tokenizer=tokenizer ): removed_duplicate_token_errors.append(window[1]) if check_operation_for_diacritic_error(window[0], window[1], window[2]): diacritic_errors.append(window[1]) if check_operation_for_confusable_error(window[0], window[1], window[2], tokenizer=tokenizer): confusable_errors.append(window[1]) if check_operation_for_case_error(window[0], window[1], window[2]): case_errors.append(window[1]) return cls( reference=reference, predicted=predicted, combined_alignment=combined_alignment, raw_alignment=tuple(raw_alignment), unique_alignment=unique_alignment, heuristic_edit_classifications=FrozenDict( { ErrorType.HORISONTAL_SEGMENTATION_ERROR: tuple(horisontal_segmentation_errors), ErrorType.TOKEN_DUPLICATION_ERROR: tuple(token_duplication_errors), ErrorType.REMOVED_DUPLICATE_TOKEN_ERROR: tuple(removed_duplicate_token_errors), ErrorType.DIACRITIC_ERROR: tuple(diacritic_errors), ErrorType.CONFUSABLE_ERROR: tuple(confusable_errors), ErrorType.CASE_ERROR: tuple(case_errors), } ), metadata=frozen_metadata, tokenizer=tokenizer, )
[docs] def compute_ter(self) -> float: return self.confusion_matrix.compute_token_error_rate()
compute_ter.__doc__ = stringalign.statistics.StringConfusionMatrix.compute_token_error_rate.__doc__
[docs] def visualize(self, which: Literal["raw", "combined"] = "raw", space_alignment_ops: bool = False) -> HtmlString: """Visualize the alignment (for Jupyter Notebooks). This is a simple wrapper around :func:`stringalign.visualize.create_alignment_html`. Use that function if you want to customise the visualisation further. See :ref:`visualize_example` for an example. Parameters ---------- which If ``which="raw"``, then the raw alignment is visualised and if ``which="combined"`` then the combined alignment is visualised. space_alignment_ops If this is True, then there will be a small space between each alignment operation. Returns ------- HtmlString A special string type that is interpreted as HTML by Jupyter. It contains HTML for visualising the specified alignment. """ if which == "raw": alignment = self.raw_alignment else: alignment = self.combined_alignment return stringalign.visualize.create_alignment_html(alignment=alignment, space_alignment_ops=space_alignment_ops)
def __repr__(self) -> str: repr_template = string.Template( cleandoc( """AlignmentAnalyzer( reference=$reference, predicted=$predicted, metadata=$metadata, tokenizer=$tokenizer )""" ) ) return repr_template.substitute( reference=repr(self.reference), predicted=repr(self.predicted), metadata=repr(self.metadata), tokenizer=_indent( repr(self.tokenizer), n_spaces=4, skip=1, ), ) __str__ = __repr__
[docs] @dataclass(frozen=True, slots=False) class MultiAlignmentAnalyzer: """Utility class for evaluating all samples in a dataset. Parameters ---------- references: Reference strings predictions: Strings to align with corresponding references. alignment_analyzers: Alignment errors, one per sample. """ references: tuple[str, ...] predictions: tuple[str, ...] alignment_analyzers: tuple[AlignmentAnalyzer, ...] tokenizer: stringalign.tokenize.Tokenizer
[docs] def dump(self) -> list[dict[Hashable, Hashable]]: """Convert the alignment errors to dictionaries, where the error classifications are converted to booleans. This is useful if we, for example, want to know the number of many samples with at least one suspected diacritic error. However, it removes information about what the errors might be. Returns ------- summary : list[dict[Hashable, Hashable]] """ return [err.summarise() for err in self.alignment_analyzers]
@property def not_unique_alignments(self) -> Generator[AlignmentAnalyzer]: """:class:`AlignmentAnalyzer` instances whose alignments are not unique. This is useful to assess why alignments might not be unique. For example, whether the non-uniqueness stems from duplicated or transposed tokens. Yields ------ AlignmentAnalyzer """ return (err for err in self.alignment_analyzers if not err.unique_alignment)
[docs] @cached_property def alignment_operation_counts(self) -> dict[Literal["raw", "combined"], Counter[AlignmentOperation]]: """Count the number of times each alignment operation occurs. This is useful to identify common mistakes for a transcription model. Returns ------- Counter[AlignmentOperation] The number of times each alignment operation occurs See also -------- edit_counts """ return { "combined": Counter(op for analyzer in self.alignment_analyzers for op in analyzer.combined_alignment), "raw": Counter(op for analyzer in self.alignment_analyzers for op in analyzer.raw_alignment), }
@property def edit_counts(self) -> dict[Literal["raw", "combined"], Counter[AlignmentOperation]]: """Count the number of times each alignment operation representing edits occurs. This is useful to identify common mistakes for a transcription model. Returns ------- Counter[AlignmentOperation] The number of times each edit operation, i.e. alignment operations that represent edit (i.e. :class:`stringalign.align.Deleted`, :class:`stringalign.align.Inserted`, or :class:`stringalign.align.Replaced`, occurs. See also -------- alignment_operation_counts """ def remove_kept_from_counter(cnt: Counter[AlignmentOperation]) -> Counter[AlignmentOperation]: return Counter({k: v for k, v in cnt.items() if not isinstance(k, Kept)}) return { "combined": remove_kept_from_counter(self.alignment_operation_counts["combined"]), "raw": remove_kept_from_counter(self.alignment_operation_counts["raw"]), }
[docs] @cached_property def confusion_matrix(self) -> StringConfusionMatrix: """The micro-averaged confusion matrix for all samples.""" return sum((ae.confusion_matrix for ae in self.alignment_analyzers), start=StringConfusionMatrix.get_empty())
[docs] @cached_property def alignment_operator_index( self, ) -> dict[Literal["raw", "combined"], dict[AlignmentOperation, frozenset[AlignmentAnalyzer]]]: """Mapping from alignment ops. to sets of :class:`AlignmentAnalyzer` with that operation in the combined alignment. This function is used to find all samples that contain specific alignment operations. It can, for example be used to identify all lines that contain a specific error a transcription model makes, which again can be useful for finding mistakes in the references. """ raw_index = defaultdict(set) for alignment_analyzer in self.alignment_analyzers: for alignment_op in alignment_analyzer.raw_alignment: raw_index[alignment_op].add(alignment_analyzer) combined_index = defaultdict(set) for alignment_analyzer in self.alignment_analyzers: for alignment_op in alignment_analyzer.combined_alignment: combined_index[alignment_op].add(alignment_analyzer) return { "raw": {k: frozenset(v) for k, v in raw_index.items()}, "combined": {k: frozenset(v) for k, v in combined_index.items()}, }
[docs] @cached_property def false_positive_index(self) -> dict[str, frozenset[AlignmentAnalyzer]]: """Mapping from tokens to sets of :class:`AlignmentAnalyzer` with that false positive token""" out = defaultdict(set) for alignment_analyzer in self.alignment_analyzers: for token in alignment_analyzer.confusion_matrix.false_positives: out[token].add(alignment_analyzer) return {k: frozenset(v) for k, v in out.items()}
[docs] @cached_property def false_negative_index(self) -> dict[str, frozenset[AlignmentAnalyzer]]: """Mapping from tokens to sets of :class:`AlignmentAnalyzer` with that false negative token""" out = defaultdict(set) for alignment_analyzer in self.alignment_analyzers: for token in alignment_analyzer.confusion_matrix.false_negatives: out[token].add(alignment_analyzer) return {k: frozenset(v) for k, v in out.items()}
@property def error_type_index(self) -> dict[ErrorType, Generator[AlignmentAnalyzer, None, None]]: """ Mapping from error type to generators yielding :class:`AlignmentAnalyzer` with at least one edit of that type. **Horisontal segmentation errors** An alignment is said to contain a horisontal segmentation error if there is an edit at the start or end of the alignment. See :func:`check_operation_for_horizontal_segmentation_error` for more information. **Token duplication errors** An alignment is said to contain a duplication error if at least one token is duplicated in the prediction where it is not duplicated in the reference. For example, transcribing ``"hello"`` as ``"helllo"`` would correspond to a duplication error. See :func:`check_operation_for_ngram_duplication_error` for more information. **Missed duplicated token errors** An alignment is said to contain a removed duplicate token error if at least one token is duplicated in the reference where it is duplicated in the prediction. For example, transcribing ``"hello"`` as ``"helo"`` would correspond to a removed duplicate token error. See :func:`check_operation_for_ngram_duplication_error` and :func:`stringalign.error_classification.duplication_error.check_ngram_duplication_errors` for more information. **Missing diacritic errors** An alignment is said to contain a diacritic error if at least one of the edits would change into a Kept if we remove all diacritics. Note that this function also resolves confusables to be able to correctly remove diacritics. See :func:`check_operation_for_diacritic_error` and :func:`stringalign.error_classification.diacritic_error.count_diacritic_errors` for more information. **Confusable character errors** An alignment is said to contain a confusable error if at least one of the edits would change into a Kept if we resolve confusables. See :func:`check_operation_for_confusable_error` and :func:`stringalign.error_classification.confusable_error.count_confusable_errors` for more information. **Case errors** An alignment is said to contain a case error if at least one of the edits would change into a Kept if we case fold the contents. See :func:`check_operation_for_case_error` and :func:`stringalign.error_classification.case_error.count_case_errors` for more information. Returns ------- dict[ErrorType, Generator[AlignmentAnalyzer, None, None]] """ def make_alignment_analyzer_generator(error_type: ErrorType) -> Generator[AlignmentAnalyzer, None, None]: """We need this function to bind the error type variable in the generator""" return (aa for aa in self.alignment_analyzers if aa.heuristic_edit_classifications[error_type]) return {et: make_alignment_analyzer_generator(et) for et in ErrorType}
[docs] @classmethod def from_strings( cls, references: Iterable[str], predictions: Iterable[str], tokenizer: Tokenizer | None = None, metadata: Iterable[Mapping[Hashable, Hashable] | None] | None = None, randomize_alignment: bool = False, random_state: np.random.Generator | int | None = None, ) -> Self: """Creates a transcription evaluator from iterables containing references and predictions. Parameters ---------- references Iterable containing the reference strings. predictions Iterable containing the strings to align with the references. tokenizer : optional A tokenizer that turns a string into an iterable of tokens. For this function, it is sufficient that it is a callable that turns a string into an iterable of tokens. If not provided, then ``stringalign.tokenize.DEFAULT_TOKENIZER`` is used instead, which by default is a grapheme cluster (character) tokenizer. metadata Additional metadata about the sample, e.g. sample id. randomize_alignment If ``True``, then a random optimal alignment is chosen (slightly slower if enabled) random_state The NumPy RNG or a seed to create a NumPy RNG used for picking the optimal alignment. If ``None``, then the default RNG will be used instead. Returns ------- transcription_evaluator: MultiAlignmentAnalyzer """ references = tuple(references) predictions = tuple(predictions) if metadata is None: metadata = tuple(None for _ in references) alignment_analyzers = tuple( AlignmentAnalyzer.from_strings( reference, prediction, tokenizer, metadata=metadata, randomize_alignment=randomize_alignment, random_state=random_state, ) for reference, prediction, metadata in zip(references, predictions, metadata, strict=True) ) return cls( references=references, predictions=predictions, alignment_analyzers=alignment_analyzers, tokenizer=alignment_analyzers[0].tokenizer, )
def __len__(self) -> int: """The number of samples in the transcription.""" return len(self.alignment_analyzers) def __repr__(self) -> str: repr_template = string.Template( cleandoc( """MultiAlignmentAnalyzer( len=$len, tokenizer=$tokenizer )""" ) ) return repr_template.substitute( len=len(self), tokenizer=_indent( repr(self.tokenizer), n_spaces=4, skip=1, ), ) __str__ = __repr__
[docs] def compute_ter(self) -> float: return self.confusion_matrix.compute_token_error_rate()
[docs] def compute_ter( reference: str, predicted: str, tokenizer: Tokenizer, ) -> tuple[float, AlignmentAnalyzer]: """Compute the token error rate (TER) for two strings. This is just a convenience function that creates an :class:`AlignmentAnalyzer` and computes the TER with the :meth:`stringalign.statistics.StringConfusionMatrix.compute_token_error_rate` method of the :class:`AlignmentAnalyzer`'s :class:`stringalign.statistics.StringConfusionMatrix`. For more information about the TER, see :ref:`token_error_rate`. Parameters ---------- reference The reference string, also known as gold standard and ground truth predicted The predicted string tokenizer Tokenizer to split the string into a iterable of tokens. Returns ------- float The TER AlignmentAnalyzer The alignment analyzer used to compute the TER (token error rate) See also -------- stringalign.evaluate.compute_cer stringalign.evaluate.compute_wer stringalign.evaluate.AlignmentAnalyzer stringalign.statistics.StringConfusionMatrix Examples -------- If we use a :class:`stringalign.tokenize.GraphemeClusterTokenizer`, we compute the character error rate: >>> tokenizer = stringalign.tokenize.GraphemeClusterTokenizer() >>> ter, analyzer = compute_ter("Hi there", "He there", tokenizer=tokenizer) >>> ter 0.125 >>> analyzer.confusion_matrix.compute_token_error_rate() 0.125 >>> cer, _analyzer = compute_cer("Hi there", "He there") >>> cer 0.125 And if we use a :class:`stringalign.tokenize.SplitAtWhitespaceTokenizer`, we compute a word error rate: >>> tokenizer = stringalign.tokenize.SplitAtWhitespaceTokenizer() >>> ter, analyzer = compute_ter("Hi there", "He there", tokenizer=tokenizer) >>> ter 0.5 >>> analyzer.confusion_matrix.compute_token_error_rate() 0.5 >>> wer, wer_analyzer = compute_wer("Hi there", "He there", word_definition="whitespace") >>> wer 0.5 >>> wer_analyzer AlignmentAnalyzer( reference='Hi there', predicted='He there', metadata=None, tokenizer=SplitAtWhitespaceTokenizer( pre_tokenization_normalizer=StringNormalizer( normalization='NFC', case_insensitive=False, normalize_whitespace=False, remove_whitespace=False, remove_non_word_characters=False, resolve_confusables=None, ), post_tokenization_normalizer=StringNormalizer( normalization='NFC', case_insensitive=False, normalize_whitespace=False, remove_whitespace=False, remove_non_word_characters=False, resolve_confusables=None, ) ) ) """ analyzer = AlignmentAnalyzer.from_strings( reference=reference, predicted=predicted, tokenizer=tokenizer, ) return analyzer.confusion_matrix.compute_token_error_rate(), analyzer
[docs] def compute_wer( reference: str, predicted: str, word_definition: Literal["whitespace", "unicode", "unicode_word_boundary"] = "whitespace", ) -> tuple[float, AlignmentAnalyzer]: """Compute the WER for two strings. This is just a convenience function that creates an :class:`AlignmentAnalyzer` with an appropriate tokenizer and computes the WER with the :meth:`stringalign.statistics.StringConfusionMatrix.compute_token_error_rate` method of the :class:`AlignmentAnalyzer`'s :class:`stringalign.statistics.StringConfusionMatrix`. For more information about the WER, see :ref:`token_error_rate`. Parameters ---------- reference The reference string, also known as gold standard and ground truth predicted The predicted string word_definition How words are defined for the WER. Used to select tokenizer: * ``"whitespace"``: :class:`stringalign.tokenize.SplitAtWhitespaceTokenizer` (default) * ``"unicode"``: :class:`stringalign.tokenize.UnicodeWordTokenizer` * ``"unicode_boundary"``: :class:`stringalign.tokenize.SplitAtWordBoundaryTokenizer` Returns ------- float The WER AlignmentAnalyzer The alignment analyzer used to compute the WER (via the token error rate) See also -------- stringalign.evaluate.compute_ter stringalign.evaluate.compute_cer stringalign.evaluate.AlignmentAnalyzer stringalign.statistics.StringConfusionMatrix Examples -------- >>> wer, analyzer = compute_wer("Hello world!", "Hello world") >>> wer 0.5 >>> analyzer.confusion_matrix.compute_token_error_rate() 0.5 >>> analyzer AlignmentAnalyzer( reference='Hello world!', predicted='Hello world', metadata=None, tokenizer=SplitAtWhitespaceTokenizer( pre_tokenization_normalizer=StringNormalizer( normalization='NFC', case_insensitive=False, normalize_whitespace=False, remove_whitespace=False, remove_non_word_characters=False, resolve_confusables=None, ), post_tokenization_normalizer=StringNormalizer( normalization='NFC', case_insensitive=False, normalize_whitespace=False, remove_whitespace=False, remove_non_word_characters=False, resolve_confusables=None, ) ) ) """ tokenizer: stringalign.tokenize.Tokenizer if word_definition == "whitespace": tokenizer = stringalign.tokenize.SplitAtWhitespaceTokenizer() elif word_definition == "unicode": tokenizer = stringalign.tokenize.UnicodeWordTokenizer() elif word_definition == "unicode_boundary": tokenizer = stringalign.tokenize.SplitAtWordBoundaryTokenizer() return compute_ter(reference, predicted, tokenizer)
[docs] def compute_cer( reference: str, predicted: str, ) -> tuple[float, AlignmentAnalyzer]: """Compute the CER for two strings. This is just a convenience function that creates an :class:`AlignmentAnalyzer` with a :class:`stringalign.tokenize.GraphemeClusterTokenizer` and computes the CER with the :meth:`stringalign.statistics.StringConfusionMatrix.compute_token_error_rate` method of the :class:`AlignmentAnalyzer`'s :class:`stringalign.statistics.StringConfusionMatrix`. For more information about the CER, see :ref:`token_error_rate`. Parameters ---------- reference The reference string, also known as gold standard and ground truth predicted The predicted string Returns ------- float The CER AlignmentAnalyzer The alignment analyzer used to compute the CER (via the token error rate) See also -------- stringalign.evaluate.compute_ter stringalign.evaluate.compute_wer stringalign.evaluate.AlignmentAnalyzer stringalign.statistics.StringConfusionMatrix Examples -------- >>> tokenizer = stringalign.tokenize.GraphemeClusterTokenizer() >>> ter, analyzer = compute_cer("Hi there", "He there") >>> ter 0.125 >>> analyzer.confusion_matrix.compute_token_error_rate() 0.125 >>> analyzer AlignmentAnalyzer( reference='Hi there', predicted='He there', metadata=None, tokenizer=GraphemeClusterTokenizer( pre_tokenization_normalizer=StringNormalizer( normalization='NFC', case_insensitive=False, normalize_whitespace=False, remove_whitespace=False, remove_non_word_characters=False, resolve_confusables=None, ), post_tokenization_normalizer=StringNormalizer( normalization='NFC', case_insensitive=False, normalize_whitespace=False, remove_whitespace=False, remove_non_word_characters=False, resolve_confusables=None, ) ) ) """ tokenizer = stringalign.tokenize.GraphemeClusterTokenizer() return compute_ter(reference, predicted, tokenizer)