# NOTE: "Spaces: Sleeping Sleeping" — HuggingFace Spaces page status text
# captured during extraction; not part of the module source.
from transformers import AutoTokenizer, EncoderDecoderModel
from transformers import pipeline as hf_pipeline
from pathlib import Path
import spaces
import re
from .app_logger import get_logger
class NpcBertGPT2:
    """Report-conclusion generator backed by a fine-tuned BERT-GPT2 encoder-decoder.

    Weights are loaded lazily via :meth:`load`; afterwards the instance is
    callable with the report text and returns the generated conclusion string.
    """

    # Shared class-level logger from the project helper.
    logger = get_logger()

    def __init__(self):
        # All three are populated by load(); None until then.
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        # Path is resolved relative to app.py (the process working directory).
        self.pretrained_model = "./models/npc-bert-gpt2-best"
        self.logger.info(f"Created {type(self).__name__} instance.")

    def load(self):
        """Loads the fine-tuned EncoderDecoder model and related components.

        This method initializes the model, tokenizer, and pipeline for the
        report conclusion generation task using the pre-trained weights from
        the specified directory.

        Raises:
            FileNotFoundError: If the pretrained model directory is not found.
        """
        if not Path(self.pretrained_model).is_dir():
            # Fixed grammar of the original message ("Cannot found ...").
            raise FileNotFoundError(
                f"Cannot find pretrained model at: {self.pretrained_model}")
        self.model = EncoderDecoderModel.from_pretrained(self.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
        # Sampling + beam-search decoding; device_map='auto' lets HF place the
        # model on GPU when one is available.
        self.pipeline = hf_pipeline("text2text-generation",
                                    model=self.model,
                                    tokenizer=self.tokenizer,
                                    device_map='auto',
                                    num_beams=4,
                                    do_sample=True,
                                    top_k=5,
                                    temperature=.95,
                                    early_stopping=True,
                                    no_repeat_ngram_size=5,
                                    max_new_tokens=60)

    def __call__(self, *args):
        """Generates a report conclusion for the given input text.

        This method should be called only after the `load` method has been
        executed to ensure that the model and pipeline are properly
        initialized. Arguments are forwarded to the Hugging Face
        text2text-generation pipeline.

        Args:
            *args: Variable length argument list to pass to the pipeline
                (typically the report text to summarize).

        Returns:
            str: The generated conclusion, truncated at known degenerate
            continuations.

        Raises:
            BrokenPipeError: If the model has not been loaded before calling
                this method.
        """
        if self.pipeline is None:
            msg = "Model was not initialized, have you run load()?"
            raise BrokenPipeError(msg)
        self.logger.info(f"Model: {self.pipeline.model.device = }")
        # Pipeline returns a one-element list of dicts; unpack the single item.
        pipe_out, = self.pipeline(*args)
        pipe_out = pipe_out['generated_text']
        self.logger.info(f"Generated text: {pipe_out}")
        # Hard-coded cleanup: these tokens mark known repeated/degenerate
        # continuations — cut the text just after the preceding sentence end.
        # Raw string fixes the invalid "\." escape (SyntaxWarning on Py3.12+).
        mo = re.search(r"\. (questionable|anterio|zius)", pipe_out)
        if mo is not None:
            end_sig = mo.start()
            pipe_out = pipe_out[:end_sig + 1]  # keep the trailing period
        self.logger.info(f"Displayed text: {pipe_out}")
        return pipe_out