Spaces:
Runtime error
Runtime error
File size: 2,255 Bytes
6275495 9113cb8 6275495 9113cb8 db606bb 9113cb8 6275495 b10ba12 80ceb8c 6275495 9c601ea 9d9c29a db606bb af5c917 6275495 db606bb c263659 db606bb d334b30 db606bb 5e09a54 24fb973 c263659 9d9c29a 6275495 5e09a54 6275495 8ed0339 6275495 5e09a54 6275495 9113cb8 6275495 8ed0339 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" # отключаем нестабильную загрузку
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
context = (
"Университет Иннополис был основан в 2012 году. "
"Это современный вуз в России, специализирующийся на IT и робототехнике, "
"расположенный в городе Иннополис, Татарстан.\n"
)
def respond(message: str) -> str:
prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=100,
temperature=0.8,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
if "Ответ:" in full_output:
answer = full_output.split("Ответ:")[-1].strip()
else:
answer = full_output[len(prompt):].strip()
return answer
app = FastAPI(title="Иннополис бот API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # можно указать конкретный домен
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class QuestionRequest(BaseModel):
question: str
class AnswerResponse(BaseModel):
answer: str
@app.post("/api/ask", response_model=AnswerResponse)
def ask_question(request: QuestionRequest):
answer = respond(request.question)
return {"answer": answer}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
|