import argparse
import ast
import re
from typing import Dict, List, Set, Union

# Keywords describing temporal change (growth, resolution, stability, ...)
# used for entity extraction from radiology-style report text.
KEYWORDS = {
    "bigger", "change", "cleared", "constant", "decrease", "decreased",
    "decreasing", "elevated", "elevation", "enlarged", "enlargement",
    "enlarging", "expanded", "greater", "growing", "improved", "improvement",
    "improving", "increase", "increased", "increasing", "larger", "new",
    "persistence", "persistent", "persisting", "progression", "progressive",
    "reduced", "removal", "resolution", "resolved", "resolving", "smaller",
    "stability", "stable", "stably", "unchanged", "unfolded", "worse",
    "worsen", "worsened", "worsening", "unaltered",
}


def clean_text(text: str) -> str:
    """
    Clean the input text by removing special characters and redundant
    spaces or newlines.

    Args:
        text (str): Input text.

    Returns:
        str: Cleaned text.
    """
    # Collapse newlines first so the placeholder patterns below can match
    # even when they were split across lines.
    text = re.sub(r'\n+', ' ', text)
    # Remove redacted/placeholder patterns BEFORE replacing underscores and
    # dashes.  (Bug fix: the original replaced [_-]+ with spaces first, which
    # destroyed the underscores/dashes these patterns look for, so the three
    # removals below could never match.)
    text = re.sub(r'\(___, __, __\)', '', text)   # underscore placeholder
    text = re.sub(r'---, ---, ---', '', text)     # dashed placeholder
    text = re.sub(r'\(__, __, ___\)', '', text)   # underscore placeholder
    # Replace remaining underscores and dashes with spaces.
    text = re.sub(r'[_-]+', ' ', text)
    # Drop everything except word characters, whitespace and common punctuation.
    text = re.sub(r'[^\w\s.,:;()-]', '', text)
    # Collapse runs of whitespace and trim the ends.
    return re.sub(r'\s{2,}', ' ', text).strip()


def extract_entities(text: str, keywords: Set[str]) -> Set[str]:
    """
    Extract entities from the given text based on the provided keywords.

    Args:
        text (str): Input text.
        keywords (Set[str]): Set of keywords to extract entities.

    Returns:
        Set[str]: Set of matched keywords found in the text (lower-cased).
    """
    # Clean the text before extracting entities.
    text = clean_text(text)
    # Whole-word alternation over all keywords; matching is done on the
    # lower-cased text, so the result set is lower-cased as well.
    pattern = r'\b(' + '|'.join(re.escape(word) for word in keywords) + r')\b'
    return {match.group().lower() for match in re.finditer(pattern, text.lower())}


def calculate_tem_score(
    prediction_text: str,
    reference_text: Union[str, List[str]],
    epsilon: float = 1e-10,
) -> Dict[str, Union[float, Set[str]]]:
    """
    Calculate the Temporal Entity Matching (TEM) score (similar to F1-score).

    Args:
        prediction_text (str): Prediction text.
        reference_text (Union[str, List[str]]): Reference text or a list of
            reference texts (entity sets of a list are unioned).
        epsilon (float): Small value to avoid division by zero.

    Returns:
        dict: {"f1": float, "prediction_entities": set,
               "reference_entities": set}.
    """
    if isinstance(reference_text, list):
        reference_entities: Set[str] = set()
        for ref in reference_text:
            reference_entities.update(extract_entities(ref, KEYWORDS))
    else:
        reference_entities = extract_entities(reference_text, KEYWORDS)

    prediction_entities = extract_entities(prediction_text, KEYWORDS)

    if not reference_entities:
        # Perfect match when both are empty; minimal score when the reference
        # is empty but the prediction is not.
        score = 1.0 if not prediction_entities else epsilon
        return {
            "f1": score,
            "prediction_entities": prediction_entities,
            "reference_entities": reference_entities,
        }

    # Intersection of entities.
    true_positives = len(prediction_entities & reference_entities)

    # Precision and recall, smoothed with epsilon to avoid division by zero.
    precision = (true_positives + epsilon) / (len(prediction_entities) + epsilon)
    recall = (true_positives + epsilon) / (len(reference_entities) + epsilon)

    # TEM score (harmonic mean of precision and recall).
    tem_score = (2 * precision * recall) / (precision + recall + epsilon)

    return {
        "f1": tem_score,
        "prediction_entities": prediction_entities,
        "reference_entities": reference_entities,
    }


def temporal_f1_score(
    predictions: List[str],
    references: List[Union[str, List[str]]],
    epsilon: float = 1e-10,
) -> Dict[str, Union[float, List[Set[str]]]]:
    """
    Calculate the average TEM score over a list of reference and prediction
    texts.

    Args:
        predictions (List[str]): List of prediction texts.
        references (List[Union[str, List[str]]]): List of reference texts or
            lists of reference texts.
        epsilon (float): Small value to avoid division by zero.

    Returns:
        dict: {"f1": average TEM score,
               "prediction_entities": per-example prediction entity sets,
               "reference_entities": per-example reference entity sets}.

    Raises:
        ValueError: If the two lists differ in length or are empty.
    """
    # ValueError instead of assert: asserts are stripped under `python -O`.
    if len(references) != len(predictions):
        raise ValueError("Reference and prediction lists must have the same length.")
    if not predictions:
        # Guard against ZeroDivisionError on empty input.
        raise ValueError("At least one prediction/reference pair is required.")

    tem_scores = []
    prediction_entities = []
    reference_entities = []

    for pred, ref in zip(predictions, references):
        result = calculate_tem_score(pred, ref, epsilon)
        tem_scores.append(result["f1"])
        prediction_entities.append(result["prediction_entities"])
        reference_entities.append(result["reference_entities"])

    average_f1 = sum(tem_scores) / len(tem_scores)

    return {
        "f1": average_f1,
        "prediction_entities": prediction_entities,
        "reference_entities": reference_entities,
    }


# Command-line interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Calculate the average TEM score for reference and prediction texts."
    )
    parser.add_argument("--predictions", nargs='+', required=True,
                        help="List of prediction texts.")
    parser.add_argument("--reference", nargs='+', required=True,
                        help="List of reference texts or lists of reference texts.")
    args = parser.parse_args()

    # Parse bracketed arguments as Python list literals.  ast.literal_eval
    # replaces eval(): it only accepts literals, so untrusted CLI input
    # cannot execute arbitrary code.
    reference_list = [
        ast.literal_eval(ref) if ref.startswith('[') else ref
        for ref in args.reference
    ]

    # Calculate and report the average TEM score (the original computed the
    # score but never printed anything).
    result = temporal_f1_score(predictions=args.predictions, references=reference_list)
    print(result["f1"])