File size: 2,536 Bytes
de5dcb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from dataclasses import dataclass
import statistics
import sys
from typing import List, Union

from numpy.typing import NDArray


# Sentence counts: either a flat list of ints or a list of lists of ints.
NumSentencesType = Union[
    List[int],
    List[List[int]],
]
# Embedding slices, nested the same way as the NumSentencesType they came from.
EmbeddingSlicesType = Union[
    List[NDArray],
    List[List[NDArray]],
]


def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Split a stacked embedding array into consecutive slices along its first axis.

    Rows of ``embeddings`` are consumed in order: each count in ``num_sentences``
    takes the next ``count`` rows. A flat list of counts yields a flat list of
    slices. A list of lists of counts yields a list of lists of slices with the
    same nesting; row consumption continues across the inner lists.

    :param embeddings: Array sliced along its first axis (presumably one row per
        sentence -- confirm against callers).
    :param num_sentences: Row counts, either ``List[int]`` or ``List[List[int]]``.
        Any integral value is accepted, including NumPy integer scalars.
    :return: Slices of ``embeddings`` mirroring the nesting of ``num_sentences``.
    :raises TypeError: If ``num_sentences`` is neither a list of integers nor a
        list of lists of integers.
    :raises ValueError: If a count is negative, or if the counts need more rows
        than ``embeddings`` has.
    """
    import numbers

    total_rows = len(embeddings)

    def _is_count(item) -> bool:
        # numbers.Integral covers int (and bool, as before) plus NumPy integer
        # scalars, which a plain isinstance(item, int) check rejects.
        return isinstance(item, numbers.Integral)

    def _slice_embeddings(s_idx: int, n_sentences: List[int]):
        # Slice consecutive runs starting at row s_idx; return the slices and
        # the index of the first row not yet consumed.
        _result = []
        for count in n_sentences:
            count = int(count)
            if count < 0:
                raise ValueError(f"Sentence counts must be non-negative, got {count}")
            end_idx = s_idx + count
            # Slicing past the end would silently return a short slice and
            # misalign every following item, so fail loudly instead.
            if end_idx > total_rows:
                raise ValueError(
                    f"Sentence counts require at least {end_idx} embedding rows, "
                    f"but only {total_rows} are available"
                )
            _result.append(embeddings[s_idx:end_idx])
            s_idx = end_idx
        return _result, s_idx

    if isinstance(num_sentences, list) and all(_is_count(item) for item in num_sentences):
        result, _ = _slice_embeddings(0, num_sentences)
        return result
    elif isinstance(num_sentences, list) and all(
        isinstance(sublist, list) and all(_is_count(item) for item in sublist)
        for sublist in num_sentences
    ):
        nested_result = []
        start_idx = 0
        for nested_num_sentences in num_sentences:
            # start_idx carries over so each group continues where the last ended.
            embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
            nested_result.append(embedding_slice)

        return nested_result
    else:
        raise TypeError(f"Incorrect Type for {num_sentences=}")


def is_list_of_strings_at_depth(obj, depth: int) -> bool:
    """
    Check whether ``obj`` is a list nested exactly ``depth`` levels deep with strings at the bottom.

    Parameters:
        obj: The value to inspect.
        depth (int): Number of list levels expected above the strings; 0 means
            ``obj`` itself must be a ``str``.

    Returns:
        bool: True if every leaf reached after ``depth`` list levels is a ``str``.
        Empty lists at any level count as matching.

    Raises:
        ValueError: If ``depth`` is negative.
    """
    if depth > 0:
        # Every level above the leaves must be an actual list.
        if not isinstance(obj, list):
            return False
        return all(is_list_of_strings_at_depth(element, depth - 1) for element in obj)
    if depth == 0:
        return isinstance(obj, str)
    raise ValueError("Depth can't be negative")


def flatten_list(nested_list: list) -> list:
    """
    Flattens a nested list of any depth.

    Nested lists are walked with an explicit stack rather than recursion, so
    very deeply nested input does not hit Python's recursion limit. Only
    ``list`` instances are descended into; tuples, strings, arrays and other
    values are kept as single elements.

    Parameters:
        nested_list (list): The nested list to flatten.

    Returns:
        list: A flat list containing all the non-list elements of the nested
        list, in left-to-right (depth-first) order.

    Raises:
        ValueError: If a list directly or indirectly contains itself.
    """
    flat_list = []
    # Each stack entry pairs a list's id with an iterator over its remaining
    # items. on_path holds the ids of lists currently being walked, so a cycle
    # is reported instead of being descended into forever. A list that merely
    # appears twice (not on the same path) is flattened both times.
    stack = [(id(nested_list), iter(nested_list))]
    on_path = {id(nested_list)}
    while stack:
        _, items = stack[-1]
        for item in items:
            if isinstance(item, list):
                item_id = id(item)
                if item_id in on_path:
                    raise ValueError("flatten_list: cyclic list reference detected")
                on_path.add(item_id)
                stack.append((item_id, iter(item)))
                break  # descend into the sub-list before continuing this one
            flat_list.append(item)
        else:
            # Current list exhausted: return to its parent.
            finished_id, _ = stack.pop()
            on_path.discard(finished_id)
    return flat_list


def compute_f1(p: float, r: float, eps=sys.float_info.epsilon) -> float:
    """
    Return the F1 score, the harmonic mean of precision and recall.

    :param p: Precision value
    :param r: Recall value
    :param eps: Small constant added to the denominator so that p + r == 0
        yields 0.0 instead of raising ZeroDivisionError
    :return: 2 * p * r / (p + r + eps)
    """
    numerator = 2 * p * r
    denominator = p + r + eps
    return numerator / denominator


@dataclass
class Scores:
    """
    A precision value and a list of recall values, plus a derived F1 score.

    ``f1`` is computed once in ``__post_init__`` from ``precision`` and the
    arithmetic mean of ``recall``. It is a plain instance attribute, not a
    dataclass field. It is therefore not an ``__init__`` argument and does not
    appear in ``repr``, equality or ``dataclasses.asdict``. It is not
    recomputed if ``precision`` or ``recall`` are changed afterwards.
    """

    # Single precision value.
    precision: float
    # Recall values, averaged with statistics.fmean; an empty list makes
    # construction raise statistics.StatisticsError.
    recall: List[float]

    def __post_init__(self):
        mean_recall = statistics.fmean(self.recall)
        self.f1: float = compute_f1(p=self.precision, r=mean_recall)