File size: 2,451 Bytes
6275495
 
 
 
 
9113cb8
6275495
9113cb8
6275495
db606bb
9113cb8
6275495
b10ba12
80ceb8c
6275495
9c601ea
 
9d9c29a
db606bb
 
 
 
 
af5c917
6275495
 
db606bb
c263659
 
db606bb
 
d334b30
db606bb
5e09a54
 
24fb973
c263659
9d9c29a
6275495
 
 
 
5e09a54
6275495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e09a54
6275495
 
 
 
 
 
9113cb8
6275495
 
48ffe3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"  # отключаем нестабильную загрузку

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

context = (
    "Университет Иннополис был основан в 2012 году. "
    "Это современный вуз в России, специализирующийся на IT и робототехнике, "
    "расположенный в городе Иннополис, Татарстан.\n"
)

def respond(message: str) -> str:
    prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=100,
            temperature=0.8,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if "Ответ:" in full_output:
        answer = full_output.split("Ответ:")[-1].strip()
    else:
        answer = full_output[len(prompt):].strip()

    return answer

# FastAPI app
app = FastAPI(title="Иннополис бот API")

# Чтобы Unity или браузеры могли обращаться, разрешим CORS (подстрой по своему домену)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # или укажи нужный адрес, например ["http://localhost:3000"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class QuestionRequest(BaseModel):
    question: str

class AnswerResponse(BaseModel):
    answer: str

@app.post("/api/ask", response_model=AnswerResponse)
def ask_question(request: QuestionRequest):
    answer = respond(request.question)
    return {"answer": answer}

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=8000)