Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import re | |
model_id = "google/flan-t5-base" | |
tokenizer = AutoTokenizer.from_pretrained(model_id) | |
model = AutoModelForSeq2SeqLM.from_pretrained(model_id) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model.to(device) | |
context = ( | |
"Университет Иннополис был основан в 2012 году. " | |
"Это современный вуз в России, специализирующийся на IT и робототехнике, " | |
"расположенный в городе Иннополис, Татарстан." | |
) | |
def clean_answer(answer, prompt): | |
# Убираем prompt из начала, если остался | |
answer = answer[len(prompt):].strip() if answer.lower().startswith(prompt.lower()) else answer.strip() | |
# Оставляем только кириллицу, пробелы и знаки препинания | |
answer = re.sub(r"[^а-яА-ЯёЁ ,.\-:;?!]", "", answer) | |
# Дополнительно можно убрать повторяющиеся символы | |
answer = re.sub(r"(.)\1{2,}", r"\1", answer) | |
return answer | |
def respond(message, history=None): | |
if history is None: | |
history = [] | |
prompt = ( | |
"Используя следующий контекст, ответь на вопрос четко и кратко.\n" | |
f"Контекст: {context}\n" | |
f"Вопрос: {message}\n" | |
"Ответ:" | |
) | |
inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=100, | |
do_sample=False, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
raw_answer = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
answer = clean_answer(raw_answer, prompt) | |
history.append((message, answer)) | |
return history | |
iface = gr.ChatInterface(fn=respond, title="Innopolis Q&A") | |
iface.launch() | |