File size: 1,741 Bytes
2e03cda
9c601ea
c440868
e474e6b
d87c9d2
9c601ea
b10ba12
 
9c601ea
 
9d9c29a
 
af5c917
c440868
b10ba12
 
c440868
 
 
 
12dd231
9c601ea
c263659
 
 
 
c440868
9d9c29a
 
c440868
c263659
9d9c29a
c440868
 
 
b10ba12
 
d87c9d2
 
 
c440868
 
 
d50c842
c440868
d87c9d2
c440868
d87c9d2
 
c440868
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

model_id = "cody82/unitrip"

# Load the tokenizer and the causal-LM weights for this checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Use the GPU when torch reports CUDA as available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# System prompt placed at the top of every transcript
# ("You are a smart assistant for Innopolis University.").
system_message = "Ты — умный помощник по Университету Иннополис."

def respond(user_message, history):
    """Generate the assistant's reply to ``user_message`` and extend the chat.

    The conversation is serialised into a plain-text transcript::

        <system_message>
        User: ...
        Assistant: ...
        User: <user_message>
        Assistant:

    and the model greedily completes the final ``Assistant:`` turn.

    Args:
        user_message: Text the user just submitted.
        history: Previous turns as ``(user_text, bot_text)`` pairs, or
            ``None`` for a fresh conversation. The passed list is not
            modified; a new list is returned.

    Returns:
        ``(history, history)``: the updated list of turns, once for the
        chat display and once for the state carried into the next call.
        Blank or whitespace-only messages leave the history unchanged.
    """
    # Copy instead of appending to the caller's list in place.
    history = list(history) if history is not None else []

    # Don't spend a model generation on an empty submission.
    if not user_message or not user_message.strip():
        return history, history

    # Same transcript as before: system line, prior turns, then the open turn.
    # NOTE(review): the prompt grows with every turn and is never trimmed;
    # long sessions may exceed the model's context length. Consider dropping
    # the oldest turns.
    turns = [system_message]
    turns.extend(
        f"User: {user_text}\nAssistant: {bot_text}" for user_text, bot_text in history
    )
    turns.append(f"User: {user_message}\nAssistant:")
    prompt = "\n".join(turns)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # The transcript format has no hard stop between turns, so the model often
    # keeps going and writes the next "User:"/"Assistant:" line itself. Keep
    # only the text before the first such marker, so fabricated turns are
    # neither displayed nor fed back into later prompts.
    for marker in ("\nUser:", "\nAssistant:"):
        generated_text = generated_text.split(marker, 1)[0]
    generated_text = generated_text.strip()

    history.append((user_message, generated_text))
    return history, history

def clear_textbox():
    """Return the value that resets the input textbox: an empty string."""
    blank = ""
    return blank

# Chat UI: the transcript on top, the question box below it, and a gr.State
# holding the list of (user, bot) turns (initially empty).
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    message = gr.Textbox(placeholder="Введите вопрос...")
    state = gr.State([])

    # Submitting the textbox runs `respond`. Its returned turn list refreshes
    # both the visible chat and the stored history.
    message.submit(
        fn=respond,
        inputs=[message, state],
        outputs=[chatbot, state],
    )
    # A second listener on the same trigger empties the textbox.
    message.submit(fn=clear_textbox, inputs=[], outputs=[message])

demo.launch(share=True)