|
from threading import Thread |
|
from typing import Tuple, Generator |
|
|
|
from optimum.bettertransformer import BetterTransformer |
|
import streamlit as st |
|
import torch |
|
from torch.quantization import quantize_dynamic |
|
from torch import nn, qint8 |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer, TextStreamer, TextIteratorStreamer |
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def get_resources(
    quantize: bool = True, no_cuda: bool = False
) -> Tuple[T5ForConditionalGeneration, T5Tokenizer, TextIteratorStreamer]:
    """Load the Dutch simplification model, its tokenizer and a text streamer.

    The function is decorated with ``st.cache_resource`` so that Streamlit can
    reuse the loaded objects instead of building them again.

    :param quantize: if the model is not moved to CUDA, apply dynamic int8
        quantization to it
    :param no_cuda: keep the model off the GPU even when CUDA is available
    :return: a tuple of the model (in eval mode), the tokenizer and a
        ``TextIteratorStreamer`` bound to that tokenizer
    """
    model_name = "BramVanroy/ul2-base-dutch-simplification-mai-2023"
    tokenizer = T5Tokenizer.from_pretrained(model_name, use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained(model_name)

    model = BetterTransformer.transform(model, keep_original_model=False)
    # Make sure the embedding matrix matches the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:
        # Quantization is only attempted when the model is not moved to CUDA.
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    model.eval()

    # Decoding options have to be given as plain keyword arguments: the
    # streamer collects ``**decode_kwargs`` and forwards them to
    # ``tokenizer.decode``. Passing a dict under the name ``decode_kwargs``
    # nests it one level too deep, so the options would not be applied.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    return model, tokenizer, streamer
|
|
|
|
|
def simplify(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    streamer: TextIteratorStreamer
) -> Generator[str, None, None]:
    """Simplify ``text`` and stream the output while it is being generated.

    Generation runs in a background thread that feeds ``streamer``; this
    generator consumes the streamer and yields the text produced so far.

    :param text: the input text; the ``"[NLG] "`` prefix is added here
    :param model: the model whose ``generate`` method is called
    :param tokenizer: tokenizer used to encode ``text``
    :param streamer: streamer handed to ``generate`` and iterated for output
    :return: a generator yielding the cumulative generated text, growing with
        every chunk the streamer delivers
    """
    text = "[NLG] " + text

    encoded = tokenizer(text, return_tensors="pt")
    # Inputs must live on the same device as the model.
    encoded = {k: v.to(model.device) for k, v in encoded.items()}
    gen_kwargs = {
        **encoded,
        "max_new_tokens": 128,
        "streamer": streamer,
    }

    def _generate() -> None:
        # Grad mode is thread-local in PyTorch, so gradients have to be
        # disabled inside the worker thread; wrapping only the thread start
        # in the calling thread has no effect on the generate call.
        with torch.no_grad():
            model.generate(**gen_kwargs)

    thread = Thread(target=_generate)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text

    # The streamer is exhausted, so the worker is done (or about to be);
    # wait for it so that no generation thread is left dangling.
    thread.join()
|
|