mattcracker commited on
Commit
1461bea
·
verified ·
1 Parent(s): c24f398

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -33
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
- from fastapi import FastAPI, HTTPException
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import AutoTokenizer, AutoModel
5
  import torch
6
- import numpy as np
 
 
7
  from pydantic import BaseModel
8
  from typing import List, Dict, Any
9
  import time
@@ -26,20 +26,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
  model = AutoModel.from_pretrained(model_name)
27
  model.eval()
28
 
29
- # OpenAI 兼容的请求模型
30
  class EmbeddingRequest(BaseModel):
31
  input: List[str] | str
32
  model: str | None = model_name
33
  encoding_format: str | None = "float"
34
  user: str | None = None
35
 
36
- # OpenAI 兼容的响应模型
37
  class EmbeddingResponse(BaseModel):
38
  object: str = "list"
39
  data: List[Dict[str, Any]]
40
  model: str
41
  usage: Dict[str, int]
42
 
 
43
  def get_embedding(text: str) -> List[float]:
44
  inputs = tokenizer(
45
  text,
@@ -47,35 +46,29 @@ def get_embedding(text: str) -> List[float]:
47
  truncation=True,
48
  max_length=512,
49
  return_tensors="pt"
50
- )
51
 
52
  with torch.no_grad():
53
  outputs = model(**inputs)
54
- embeddings = outputs.last_hidden_state[:, 0, :].numpy()
55
 
56
  return embeddings[0].tolist()
57
 
58
- # OpenAI 兼容的 embeddings endpoint
59
  @app.post("/v1/embeddings", response_model=EmbeddingResponse)
 
60
  async def create_embeddings(request: EmbeddingRequest):
61
- start_time = time.time()
62
-
63
- # 处理输入
64
  if isinstance(request.input, str):
65
  input_texts = [request.input]
66
  else:
67
  input_texts = request.input
68
 
69
- # 获取嵌入向量
70
  embeddings = []
71
  total_tokens = 0
72
 
73
  for text in input_texts:
74
- # 计算 token 数量
75
  tokens = tokenizer.encode(text)
76
  total_tokens += len(tokens)
77
 
78
- # 获取嵌入向量
79
  embedding = get_embedding(text)
80
 
81
  embeddings.append({
@@ -95,14 +88,10 @@ async def create_embeddings(request: EmbeddingRequest):
95
 
96
  return response
97
 
98
- # Gradio 界面
99
  def gradio_embedding(text: str) -> Dict:
100
- # 创建与 OpenAI 兼容的请求
101
  request = EmbeddingRequest(input=text)
102
-
103
- # 调用 API 处理函数
104
  response = create_embeddings(request)
105
-
106
  return response.dict()
107
 
108
  # 创建 Gradio 界面
@@ -118,19 +107,9 @@ demo = gr.Interface(
118
  ]
119
  )
120
 
121
- # 启动服务
 
 
122
  if __name__ == "__main__":
123
  import uvicorn
124
-
125
- # 首先启动 Gradio
126
- demo.queue()
127
-
128
- # 然后启动 FastAPI
129
- config = uvicorn.Config(
130
- app=app,
131
- host="0.0.0.0",
132
- port=7860,
133
- log_level="info"
134
- )
135
- server = uvicorn.Server(config)
136
- server.run()
 
1
  import gradio as gr
 
 
2
  from transformers import AutoTokenizer, AutoModel
3
  import torch
4
+ import spaces
5
+ from fastapi import FastAPI
6
+ from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
8
  from typing import List, Dict, Any
9
  import time
 
26
  model = AutoModel.from_pretrained(model_name)
27
  model.eval()
28
 
 
29
  class EmbeddingRequest(BaseModel):
30
  input: List[str] | str
31
  model: str | None = model_name
32
  encoding_format: str | None = "float"
33
  user: str | None = None
34
 
 
35
  class EmbeddingResponse(BaseModel):
36
  object: str = "list"
37
  data: List[Dict[str, Any]]
38
  model: str
39
  usage: Dict[str, int]
40
 
41
+ @spaces.GPU()
42
  def get_embedding(text: str) -> List[float]:
43
  inputs = tokenizer(
44
  text,
 
46
  truncation=True,
47
  max_length=512,
48
  return_tensors="pt"
49
+ ).to(model.device)
50
 
51
  with torch.no_grad():
52
  outputs = model(**inputs)
53
+ embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
54
 
55
  return embeddings[0].tolist()
56
 
 
57
  @app.post("/v1/embeddings", response_model=EmbeddingResponse)
58
+ @spaces.GPU()
59
  async def create_embeddings(request: EmbeddingRequest):
 
 
 
60
  if isinstance(request.input, str):
61
  input_texts = [request.input]
62
  else:
63
  input_texts = request.input
64
 
 
65
  embeddings = []
66
  total_tokens = 0
67
 
68
  for text in input_texts:
 
69
  tokens = tokenizer.encode(text)
70
  total_tokens += len(tokens)
71
 
 
72
  embedding = get_embedding(text)
73
 
74
  embeddings.append({
 
88
 
89
  return response
90
 
91
+ @spaces.GPU()
92
  def gradio_embedding(text: str) -> Dict:
 
93
  request = EmbeddingRequest(input=text)
 
 
94
  response = create_embeddings(request)
 
95
  return response.dict()
96
 
97
  # 创建 Gradio 界面
 
107
  ]
108
  )
109
 
110
+ # 挂载 Gradio 应用到 FastAPI
111
+ app = gr.mount_gradio_app(app, demo, path="/")
112
+
113
  if __name__ == "__main__":
114
  import uvicorn
115
+ uvicorn.run(app, host="0.0.0.0", port=7860)