cody82 commited on
Commit
983eb46
·
verified ·
1 Parent(s): 000988a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -16
app.py CHANGED
@@ -1,13 +1,14 @@
1
  import os
2
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
3
 
4
- from fastapi import FastAPI
5
- from pydantic import BaseModel
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
8
  import uvicorn
 
9
 
10
- # === Загрузка модели ===
11
  model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
12
  tokenizer = AutoTokenizer.from_pretrained(model_id)
13
  model = AutoModelForCausalLM.from_pretrained(model_id)
@@ -15,23 +16,14 @@ model = AutoModelForCausalLM.from_pretrained(model_id)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model.to(device)
17
 
18
- # Контекст для модели
19
  context = (
20
  "Университет Иннополис был основан в 2012 году. "
21
  "Это современный вуз в России, специализирующийся на IT и робототехнике, "
22
  "расположенный в городе Иннополис, Татарстан.\n"
23
  )
24
 
25
- # === FastAPI приложение ===
26
- app = FastAPI()
27
-
28
- class QuestionRequest(BaseModel):
29
- question: str
30
-
31
- @app.post("/ask")
32
- def generate_answer(request: QuestionRequest):
33
- """Обрабатывает POST-запрос с вопросом и возвращает ответ модели."""
34
- prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {request.question}\nОтвет:"
35
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
36
 
37
  with torch.no_grad():
@@ -51,8 +43,39 @@ def generate_answer(request: QuestionRequest):
51
  else:
52
  answer = output[len(prompt):].strip()
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  return {"answer": answer}
55
 
56
- # Точка входа для запуска сервера
 
 
 
57
  if __name__ == "__main__":
58
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
3
 
 
 
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import gradio as gr
7
+ from fastapi import FastAPI, Request
8
  import uvicorn
9
+ from fastapi.middleware.cors import CORSMiddleware
10
 
11
+ # === Модель ===
12
  model_id = "sberbank-ai/rugpt3medium_based_on_gpt2"
13
  tokenizer = AutoTokenizer.from_pretrained(model_id)
14
  model = AutoModelForCausalLM.from_pretrained(model_id)
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model.to(device)
18
 
 
19
  context = (
20
  "Университет Иннополис был основан в 2012 году. "
21
  "Это современный вуз в России, специализирующийся на IT и робототехнике, "
22
  "расположенный в городе Иннополис, Татарстан.\n"
23
  )
24
 
25
+ def generate_response(question):
26
+ prompt = f"Прочитай текст и ответь на вопрос:\n\n{context}\n\nВопрос: {question}\nОтвет:"
 
 
 
 
 
 
 
 
27
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
28
 
29
  with torch.no_grad():
 
43
  else:
44
  answer = output[len(prompt):].strip()
45
 
46
+ return answer
47
+
48
+ # === Gradio интерфейс ===
49
+ def chat_interface(message, history):
50
+ return generate_response(message)
51
+
52
+ demo = gr.ChatInterface(
53
+ fn=chat_interface,
54
+ title="Иннополис Бот",
55
+ description="Задавайте вопросы о Университете Иннополис"
56
+ )
57
+
58
+ # === FastAPI приложение ===
59
+ app = FastAPI()
60
+
61
+ # Настройка CORS
62
+ app.add_middleware(
63
+ CORSMiddleware,
64
+ allow_origins=["*"],
65
+ allow_methods=["*"],
66
+ allow_headers=["*"],
67
+ )
68
+
69
+ @app.post("/api/ask")
70
+ async def api_ask(request: Request):
71
+ data = await request.json()
72
+ question = data.get("question", "")
73
+ answer = generate_response(question)
74
  return {"answer": answer}
75
 
76
+ # === Для работы в Spaces ===
77
+ app = gr.mount_gradio_app(app, demo, path="/")
78
+
79
+ # === Для локального тестирования ===
80
  if __name__ == "__main__":
81
+ uvicorn.run(app, host="0.0.0.0", port=7860)