from threading import Thread
from typing import Generator, Tuple

import streamlit as st
import torch
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.quantization import quantize_dynamic
from transformers import T5ForConditionalGeneration, T5Tokenizer, TextIteratorStreamer


@st.cache_resource(show_spinner=False)
def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer, TextIteratorStreamer]:
"""
"""
tokenizer = T5Tokenizer.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023", use_fast=False)
model = T5ForConditionalGeneration.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023")
model = BetterTransformer.transform(model, keep_original_model=False)
model.resize_token_embeddings(len(tokenizer))
if torch.cuda.is_available() and not no_cuda:
model = model.to("cuda")
elif quantize: # Quantization not supported on CUDA
model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)
model.eval()
streamer = TextIteratorStreamer(tokenizer, decode_kwargs={"skip_special_tokens": True, "clean_up_tokenization_spaces": True})
return model, tokenizer, streamer
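
# For contrast with the streaming generator below, a synchronous variant
# (a sketch, not part of the original file): generate in one call and decode
# the full output only at the end.
def simplify_sync(text: str, model: T5ForConditionalGeneration, tokenizer: T5Tokenizer) -> str:
    encoded = tokenizer("[NLG] " + text, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(**encoded, max_new_tokens=128)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

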
def simplify(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    streamer: TextIteratorStreamer,
) -> Generator[str, None, None]:
    """Simplify a Dutch text, yielding increasingly complete output.

    Generation runs in a background thread so that decoded tokens can be
    consumed from the streamer as soon as they are produced.

    :param text: the input text to simplify
    :param model: the simplification model
    :param tokenizer: the model's tokenizer
    :param streamer: the streamer from which to read decoded tokens
    :return: a generator of the cumulative generated text
    """
    text = "[NLG] " + text  # Task prefix expected by the model
    encoded = tokenizer(text, return_tensors="pt")
    encoded = {k: v.to(model.device) for k, v in encoded.items()}

    gen_kwargs = {
        **encoded,
        "max_new_tokens": 128,
        "streamer": streamer,
    }

    # Note: grad mode is thread-local, so wrapping thread.start() in
    # torch.no_grad() would have no effect on the worker thread;
    # model.generate already disables gradients internally.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text
    thread.join()
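

# Minimal front-end sketch (an illustration, not taken from the actual Space):
# one plausible way to wire the helpers above into a Streamlit page. The
# widget labels and layout are assumptions.
model, tokenizer, streamer = get_resources()

input_text = st.text_area("Dutch text to simplify")
if st.button("Simplify") and input_text.strip():
    placeholder = st.empty()
    # Re-render the placeholder with the cumulative text as tokens stream in
    for partial_text in simplify(input_text, model, tokenizer, streamer):
        placeholder.markdown(partial_text)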