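# Gradio Space: streaming abstractive summarizer for Greek news articles,
# built on the dascim/greekbart-news24-abstract seq2seq model.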
import gradio as gr
import spaces
import torch
from threading import Thread
from typing import Iterator
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TextIteratorStreamer,
)

# Load the GreekBART summarization checkpoint and move it to the GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("dascim/greekbart-news24-abstract")
model = AutoModelForSeq2SeqLM.from_pretrained("dascim/greekbart-news24-abstract")
model.to(device)
model.eval()
@spaces.GPU(duration=90)
def get_input(text: str) -> Iterator[str]:
    """Stream a summary of the given Greek news article, token by token."""
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_special_tokens=True)
    input_ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt").to(device)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=120,
        do_sample=False,
        num_beams=1,
    )
    # Run generation in a background thread so partial output can be yielded as it arrives.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)
iface = gr.Interface(
    fn=get_input,
    inputs="text",
    outputs="text",
    title="Greek News Summarizer",
    description="Enter a news article (up to 512 tokens) to get a summary.",
)
iface.launch()