import gradio as gr
import spaces
from threading import Thread
from typing import Iterator

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TextIteratorStreamer,
)
# Run on the GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the GreekBART news-summarization checkpoint and move it to the target device.
tokenizer = AutoTokenizer.from_pretrained("dascim/greekbart-news24-abstract")
model = AutoModelForSeq2SeqLM.from_pretrained("dascim/greekbart-news24-abstract")
model.to(device)
model.eval()
@spaces.GPU(duration=90)
def get_input(text: str) -> Iterator[str]:
    """Stream a summary of a Greek news article, yielding the text as it is generated."""
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_special_tokens=True)
    input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors='pt').to(device)
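    # Optional sketch (assumption, not part of the original app): the UI description
    # mentions a 512-token limit, which could be enforced at encode time:
    #   input_ids = tokenizer.encode(
    #       text, add_special_tokens=True, truncation=True,
    #       max_length=512, return_tensors='pt',
    #   ).to(device)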
    # Run generation in a background thread so tokens can be streamed as they are produced.
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=120,
        do_sample=False,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate the streamed chunks and yield the growing summary for live display.
    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
iface = gr.Interface(fn=get_input, inputs="text", outputs="text", title="Greek News Summarizer",
                     description="Enter a Greek news article (up to 512 tokens) to get a summary.")
iface.launch()
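# A quick local test sketch (assumption: calling the generator directly, outside the
# Gradio UI; the article text below is a hypothetical placeholder):
#   for partial_summary in get_input("Το πλήρες κείμενο ενός ελληνικού άρθρου ειδήσεων..."):
#       print(partial_summary)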