from threading import Thread
from typing import Generator, Tuple

import streamlit as st
import torch
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.quantization import quantize_dynamic
from transformers import T5ForConditionalGeneration, T5Tokenizer, TextIteratorStreamer


@st.cache_resource(show_spinner=False)
def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer, TextIteratorStreamer]:
    """Load the simplification model, its tokenizer, and a token streamer.

    Cached with st.cache_resource so the model is only loaded once per Streamlit session.
    """
    tokenizer = T5Tokenizer.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023", use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023")
    # Swap in BetterTransformer kernels for faster inference
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Dynamic quantization is not supported on CUDA
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    model.eval()
    # TextIteratorStreamer forwards extra keyword arguments to tokenizer.decode,
    # so decode options are passed directly rather than via a decode_kwargs dict
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    return model, tokenizer, streamer


def simplify(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    streamer: TextIteratorStreamer,
) -> Generator[str, None, None]:
    """Simplify the given text, yielding the full text generated so far after every new token."""
    # Prepend the "[NLG]" task prefix that the model expects for simplification
    text = "[NLG] " + text
    encoded = tokenizer(text, return_tensors="pt")
    encoded = {k: v.to(model.device) for k, v in encoded.items()}
    gen_kwargs = {
        **encoded,
        "max_new_tokens": 128,
        "streamer": streamer,
    }

    with torch.no_grad():
        # Run generation in a background thread so we can consume the
        # streamer's output here as soon as new tokens arrive
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
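

# A minimal usage sketch wiring the two helpers into a Streamlit page, assuming
# the file is launched with `streamlit run`. The widget labels and layout below
# are illustrative assumptions, not the original app's UI.
if __name__ == "__main__":
    st.title("Dutch text simplification")
    model, tokenizer, streamer = get_resources()

    input_text = st.text_area("Text to simplify")
    if st.button("Simplify") and input_text.strip():
        placeholder = st.empty()
        # simplify() yields the full text generated so far on every new token,
        # so we simply overwrite the placeholder with each update
        for partial_text in simplify(input_text, model, tokenizer, streamer):
            placeholder.markdown(partial_text)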