SemF1 / utils.py
nbansal's picture
Refactored the code and made it faster
de5dcb7
raw
history blame
2.54 kB
from dataclasses import dataclass
import statistics
import sys
from typing import List, Union
from numpy.typing import NDArray
# Number of sentence embeddings per item: either a flat list of counts, or a
# nested list holding one list of counts per group.
NumSentencesType = Union[
    List[int],
    List[List[int]],
]
# Embedding slices laid out in the same (flat or nested) shape as the counts.
EmbeddingSlicesType = Union[
    List[NDArray],
    List[List[NDArray]],
]
def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Split a stacked embedding array into consecutive per-item slices.

    ``embeddings`` is consumed front to back: each count in ``num_sentences``
    takes the next ``count`` rows (first axis) as one slice.

    :param embeddings: Array whose first axis enumerates sentences, stacked in
        the same order as ``num_sentences``.
    :param num_sentences: Either a flat list of counts (one per item) or a list
        of lists of counts (one inner list per group). Counts may be Python
        ints or NumPy integer scalars (e.g. ``np.int64``).
    :return: A flat list of slices for flat input, or a list of lists of slices
        mirroring nested input.
    :raises TypeError: If ``num_sentences`` is neither a list of integers nor a
        list of lists of integers.

    Note: the counts are not checked against ``len(embeddings)``. With standard
    slice semantics, counts running past the end give shorter (or empty)
    slices, and surplus trailing rows are simply ignored.
    """
    # numbers.Integral also matches NumPy integer scalars, which plain `int` does not.
    from numbers import Integral

    def _is_count_list(obj) -> bool:
        # A list whose elements are all integers (vacuously true for []).
        return isinstance(obj, list) and all(isinstance(item, Integral) for item in obj)

    def _slice_embeddings(s_idx, n_sentences):
        # Cut one slice per count starting at s_idx, and also return the index
        # just past the last slice so the caller can continue from there.
        _result = []
        for count in n_sentences:
            _result.append(embeddings[s_idx:s_idx + count])
            s_idx += count
        return _result, s_idx

    if _is_count_list(num_sentences):
        result, _ = _slice_embeddings(0, num_sentences)
        return result
    if isinstance(num_sentences, list) and all(_is_count_list(sublist) for sublist in num_sentences):
        # Groups are stored back to back, so the start index carries across groups.
        nested_result = []
        start_idx = 0
        for nested_num_sentences in num_sentences:
            embedding_slice, start_idx = _slice_embeddings(start_idx, nested_num_sentences)
            nested_result.append(embedding_slice)
        return nested_result
    raise TypeError(f"Incorrect Type for {num_sentences=}")
def is_list_of_strings_at_depth(obj, depth: int) -> bool:
    """
    Check whether ``obj`` holds strings nested exactly ``depth`` lists deep.

    Depth 0 means ``obj`` itself must be a ``str``; depth 1 a list of strings;
    depth 2 a list of lists of strings, and so on. An empty list satisfies any
    positive depth.

    :param obj: Object to inspect.
    :param depth: Required nesting depth; must not be negative.
    :return: True if ``obj`` has exactly the requested list-of-strings shape.
    :raises ValueError: If ``depth`` is negative.
    """
    if depth == 0:
        return isinstance(obj, str)
    if depth > 0:
        if not isinstance(obj, list):
            return False
        # Each element must itself be a string list one level shallower.
        return all(is_list_of_strings_at_depth(element, depth - 1) for element in obj)
    raise ValueError("Depth can't be negative")
def flatten_list(nested_list: list) -> list:
    """
    Flatten arbitrarily nested lists into a single list.

    Elements are emitted in depth-first, left-to-right order. Only ``list``
    instances (including subclasses) are expanded; every other value, such as
    a tuple or string, is kept as a single element.

    Parameters:
    nested_list (list): The (possibly nested) list to flatten.

    Returns:
    list: A new flat list with all non-list leaves of ``nested_list``.
    """
    def _leaves(items):
        for element in items:
            if isinstance(element, list):
                yield from _leaves(element)
            else:
                yield element

    return list(_leaves(nested_list))
def compute_f1(p: float, r: float, eps=sys.float_info.epsilon) -> float:
    """
    Compute the F1 score, the harmonic mean of precision and recall.

    ``eps`` is added to the denominator so that, with the default (positive)
    value, ``p == r == 0`` gives 0.0 instead of raising ZeroDivisionError.

    :param p: Precision value
    :param r: Recall value
    :param eps: Small constant added to the denominator
    :return: ``2 * p * r / (p + r + eps)``
    """
    numerator = 2 * p * r
    denominator = p + r + eps
    return numerator / denominator
@dataclass
class Scores:
    """
    Precision/recall container that derives an F1 score on construction.

    ``f1`` is computed once in ``__post_init__`` from ``precision`` and the
    arithmetic mean of ``recall``. It is a plain instance attribute, not a
    dataclass field, so it is not an ``__init__`` argument and is not part of
    the generated ``repr``/``==``. It is also not refreshed if ``precision`` or
    ``recall`` are changed afterwards.
    """

    precision: float
    # Averaged with statistics.fmean, which raises StatisticsError when empty.
    recall: List[float]

    def __post_init__(self):
        mean_recall = statistics.fmean(self.recall)
        self.f1: float = compute_f1(self.precision, mean_recall)