cody82 committed on
Commit
6275495
·
verified ·
1 Parent(s): 9113cb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -14
app.py CHANGED
@@ -1,13 +1,18 @@
 
 
 
 
 
1
  from fastapi import FastAPI
 
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import torch
5
-
6
- app = FastAPI()
7
 
8
  model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
 
9
  tokenizer = AutoTokenizer.from_pretrained(model_id)
10
  model = AutoModelForCausalLM.from_pretrained(model_id)
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model.to(device)
13
 
@@ -17,12 +22,8 @@ context = (
17
  "расположенный в городе Иннополис, Татарстан.\n"
18
  )
19
 
20
- class Question(BaseModel):
21
- message: str
22
-
23
- @app.post("/ask")
24
- def ask(q: Question):
25
- prompt = f"{context}\nВопрос: {q.message}\nОтвет:"
26
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
27
 
28
  with torch.no_grad():
@@ -35,10 +36,37 @@ def ask(q: Question):
35
  pad_token_id=tokenizer.eos_token_id
36
  )
37
 
38
- output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
39
- if "Ответ:" in output:
40
- answer = output.split("Ответ:")[-1].strip()
 
41
  else:
42
- answer = output.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
 
 
 
 
 
44
  return {"answer": answer}
 
 
 
 
1
+ import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" # отключаем нестабильную загрузку
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
  from fastapi import FastAPI
7
+ from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel
9
+ import uvicorn
 
 
 
10
 
11
  model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
12
+
13
  tokenizer = AutoTokenizer.from_pretrained(model_id)
14
  model = AutoModelForCausalLM.from_pretrained(model_id)
15
+
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model.to(device)
18
 
 
22
  "расположенный в городе Иннополис, Татарстан.\n"
23
  )
24
 
25
+ def respond(message: str) -> str:
26
+ prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {message}\nОтвет:"
 
 
 
 
27
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
28
 
29
  with torch.no_grad():
 
36
  pad_token_id=tokenizer.eos_token_id
37
  )
38
 
39
+ full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
40
+
41
+ if "Ответ:" in full_output:
42
+ answer = full_output.split("Ответ:")[-1].strip()
43
  else:
44
+ answer = full_output[len(prompt):].strip()
45
+
46
+ return answer
47
+
48
# FastAPI application object that the route decorators below attach to.
app = FastAPI(title="Иннополис бот API")

# Let Unity clients and browsers call the API cross-origin (tighten
# allow_origins to your own domain when it is known).
#
# Credentials are deliberately disabled: the FastAPI/Starlette CORS docs state
# that allow_origins, allow_methods and allow_headers may not be ["*"] when
# allow_credentials is True, and pairing a wildcard origin with credentials
# would let any website issue credentialed requests against this API.
# NOTE(review): nothing visible in this file uses cookies or auth headers;
# if that changes, list explicit origins/methods/headers and re-enable it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or list specific origins, e.g. ["http://localhost:3000"]
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
59
+
60
# Request body schema for the ask endpoint: a single free-text question.
# (Documented with comments rather than a docstring so the generated
# OpenAPI schema stays exactly as before.)
class QuestionRequest(BaseModel):
    # The user's question, forwarded to the model prompt.
    question: str
62
 
63
# Response body schema for the ask endpoint: the generated answer text.
# (Documented with comments rather than a docstring so the generated
# OpenAPI schema stays exactly as before.)
class AnswerResponse(BaseModel):
    # Answer text produced for the submitted question.
    answer: str
65
+
66
@app.post("/api/ask", response_model=AnswerResponse)
def ask_question(request: QuestionRequest):
    # Delegate to respond() with the submitted question and wrap its result
    # in the dict shape that AnswerResponse validates.
    return {"answer": respond(request.question)}
70
+
71
if __name__ == "__main__":
    # Executed only when the file is run as a script: start uvicorn serving
    # the `app` object from module `app` on all interfaces, port 8000.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
    )