from typing import Dict, List, Any
import os
from flair.data import Sentence
from flair.models import SequenceTagger

class EndpointHandler:
    """Inference handler that runs a flair ``SequenceTagger`` over input text.

    The tagger is loaded once at construction time. Each call tags a single
    input and returns the predicted spans as a list of plain dicts.
    """

    def __init__(self, path: str = ""):
        """Load the tagger from ``<path>/pytorch_model.bin``.

        Args:
            path: Directory containing ``pytorch_model.bin``. With the default
                ``""`` the file is resolved relative to the current working
                directory.
        """
        model_file = os.path.join(path, "pytorch_model.bin")
        self.tagger = SequenceTagger.load(model_file)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Tag the text carried by ``data`` and return the predicted spans.

        Args:
            data: Request payload. The text to tag is read from the
                ``"inputs"`` key; if that key is absent, the whole payload is
                passed to ``Sentence`` instead. The payload is not modified.

        Returns:
            One dict per predicted span, with keys ``entity_group`` (the
            span's tag), ``word`` (the span text), ``start`` (the first
            token's ``start_position``), ``end`` (the last token's
            ``end_position``) and ``score`` (the span's score).
        """
        # ``get`` rather than ``pop`` so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        sentence: Sentence = Sentence(inputs)

        # Predictions are stored under the "predicted" label type so the loop
        # below can fetch exactly those spans by name.
        self.tagger.predict(sentence, label_name="predicted")

        entities: List[Dict[str, Any]] = []
        for span in sentence.get_spans("predicted"):
            # Defensive guard: a span with no tokens has no offsets to report.
            if not span.tokens:
                continue
            entities.append(
                {
                    "entity_group": span.tag,
                    "word": span.text,
                    "start": span.tokens[0].start_position,
                    "end": span.tokens[-1].end_position,
                    "score": span.score,
                }
            )

        return entities