Spaces:
Running
Running
File size: 2,451 Bytes
6275495 9113cb8 6275495 9113cb8 6275495 db606bb 9113cb8 6275495 b10ba12 80ceb8c 6275495 9c601ea 9d9c29a db606bb af5c917 6275495 db606bb c263659 db606bb d334b30 db606bb 5e09a54 24fb973 c263659 9d9c29a 6275495 5e09a54 6275495 5e09a54 6275495 9113cb8 6275495 48ffe3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" # отключаем нестабильную загрузку
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
context = (
"Университет Иннополис был основан в 2012 году. "
"Это современный вуз в России, специализирующийся на IT и робототехнике, "
"расположенный в городе Иннополис, Татарстан.\n"
)
def respond(message: str) -> str:
prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=100,
temperature=0.8,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
if "Ответ:" in full_output:
answer = full_output.split("Ответ:")[-1].strip()
else:
answer = full_output[len(prompt):].strip()
return answer
# FastAPI app
app = FastAPI(title="Иннополис бот API")
# Чтобы Unity или браузеры могли обращаться, разрешим CORS (подстрой по своему домену)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # или укажи нужный адрес, например ["http://localhost:3000"]
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class QuestionRequest(BaseModel):
question: str
class AnswerResponse(BaseModel):
answer: str
@app.post("/api/ask", response_model=AnswerResponse)
def ask_question(request: QuestionRequest):
answer = respond(request.question)
return {"answer": answer}
if __name__ == "__main__":
iface.launch(server_name="0.0.0.0", server_port=8000)
|