Commit c8f3971 · verified · 1 Parent(s): d0c5413
Committed by mateoluksenberg

Update app.py

Files changed (1):
  1. app.py +41 -79

app.py CHANGED
@@ -5,17 +5,27 @@ import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import os
 from threading import Thread
-from fastapi import FastAPI, UploadFile, File, Form
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import Optional, List
-import logging
 
-import fitz  # PyMuPDF
+import pymupdf
 import docx
 from pptx import Presentation
 
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import HTMLResponse
+
+app = FastAPI()
+
+@app.post("/test/")
+async def test_endpoint(message: dict):
+    if "text" not in message:
+        raise HTTPException(status_code=400, detail="Missing 'text' in request body")
+
+    response = {"message": f"Received your message: {message['text']}"}
+    return response
+
+
 MODEL_LIST = ["nikravan/glm-4vq"]
+
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL_ID = MODEL_LIST[0]
 MODEL_NAME = "GLM-4vq"
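Note: `app = FastAPI()` and the `/test/` route now live at module level, but nothing in this revision actually serves them; the `uvicorn.run` call is deleted in the last hunk, so the endpoint is reachable only if the app is mounted or run separately. (The new import line does fix a latent bug: the previous revision raised `HTTPException` without ever importing it.) Assuming the app were served on port 8000 as in the previous revision, a minimal smoke test, a sketch rather than part of this commit, would be:

    import requests

    # assumes the FastAPI app is being served on localhost:8000
    r = requests.post("http://localhost:8000/test/", json={"text": "hello"})
    print(r.json())  # {"message": "Received your message: hello"}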
@@ -36,18 +46,23 @@ h1 {
 }
 """
 
+
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
+
+
 def extract_text(path):
     return open(path, 'r').read()
 
+
 def extract_pdf(path):
-    doc = fitz.open(path)
+    doc = pymupdf.open(path)
     text = ""
     for page in doc:
         text += page.get_text()
     return text
 
+
 def extract_docx(path):
     doc = docx.Document(path)
     data = []
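The `fitz` to `pymupdf` change tracks PyMuPDF's rename of its top-level module; in recent PyMuPDF releases both names import the same library, so `extract_pdf` behaves exactly as before. A minimal sketch (the file path is a placeholder):

    import pymupdf  # `fitz` remains a legacy alias for the same module

    doc = pymupdf.open("example.pdf")
    text = "".join(page.get_text() for page in doc)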
@@ -56,6 +71,7 @@ def extract_docx(path):
     content = '\n\n'.join(data)
     return content
 
+
 def extract_pptx(path):
     prs = Presentation(path)
     text = ""
@@ -65,6 +81,7 @@
                 text += shape.text + "\n"
     return text
 
+
 def mode_load(path):
     choice = ""
     file_type = path.split(".")[-1]
@@ -82,6 +99,7 @@ def mode_load(path):
         print(content[:100])
         return choice, content[:5000]
 
+
     elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
         content = Image.open(path).convert('RGB')
         choice = "image"
@@ -90,6 +108,7 @@
     else:
         raise gr.Error("Oops, unsupported files.")
 
+
 @spaces.GPU()
 def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
 
@@ -113,9 +132,11 @@ def stream_chat(message, history: list, temperature: float, max_length: int, top
         conversation.append({"role": "user", "content": format_msg})
     else:
         if len(history) == 0:
+            # raise gr.Error("Please upload an image first.")
             contents = None
             conversation.append({"role": "user", "content": message['text']})
         else:
+            # image = Image.open(history[0][0][0])
             for prompt, answer in history:
                 if answer is None:
                     prompt_files.append(prompt[0])
@@ -128,6 +149,7 @@ def stream_chat(message, history: list, temperature: float, max_length: int, top
             choice = ""
             conversation.append({"role": "user", "image": "", "content": message['text']})
 
+
     if choice == "image":
         conversation.append({"role": "user", "image": contents, "content": message['text']})
     elif choice == "doc":
@@ -159,11 +181,18 @@ def stream_chat(message, history: list, temperature: float, max_length: int, top
         buffer += new_text
         yield buffer
 
-chatbot = gr.Chatbot()
+
+chatbot = gr.Chatbot(
+    #rtl=True,
+)
 chat_input = gr.MultimodalTextbox(
     interactive=True,
     placeholder="Enter message or upload a file ...",
     show_label=False,
+    #rtl=True,
+
+
+
 )
 
 EXAMPLES = [
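Note: the commented-out `rtl=True` arguments suggest right-to-left rendering was being tried; recent Gradio versions expose an `rtl` flag on both `gr.Chatbot` and `gr.MultimodalTextbox`, so re-enabling it would be a one-line change per component, e.g. (a sketch):

    chatbot = gr.Chatbot(rtl=True)  # render chat messages right-to-left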
@@ -173,80 +202,14 @@ EXAMPLES = [
     [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores numéricos son decimales de cuatro dígitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vacíos, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicción con sus respectivos valores.", }]
 ]
 
-app = FastAPI()
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-class ChatMessage(BaseModel):
-    text: str
-    history: Optional[List] = []
-    temperature: float = 0.8
-    max_length: int = 4096
-    top_p: float = 1.0
-    top_k: int = 10
-    penalty: float = 1.0
-
-
-@app.post("/test/")
-async def test_endpoint(message: dict):
-    logging.info(f"Received message: {message}")
-    if "text" not in message:
-        raise HTTPException(status_code=400, detail="Missing 'text' in request body")
-
-    response = {"message": f"Received your message: {message['text']}"}
-    return response
-
-@app.post("/chat/")
-async def chat_endpoint(message: ChatMessage, file: Optional[UploadFile] = None):
-    conversation = []
-    if file:
-        path = f"/tmp/{file.filename}"
-        with open(path, "wb") as f:
-            f.write(await file.read())
-        choice, contents = mode_load(path)
-        if choice == "image":
-            conversation.append({"role": "user", "image": contents, "content": message.text})
-        elif choice == "doc":
-            format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message.text
-            conversation.append({"role": "user", "content": format_msg})
-    else:
-        conversation.append({"role": "user", "content": message.text})
-
-    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
-                                              return_tensors="pt", return_dict=True).to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        max_length=message.max_length,
-        streamer=streamer,
-        do_sample=True,
-        top_p=message.top_p,
-        top_k=message.top_k,
-        temperature=message.temperature,
-        repetition_penalty=message.penalty,
-        eos_token_id=[151329, 151336, 151338],
-    )
-    gen_kwargs = {**input_ids, **generate_kwargs}
-
-    with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
-        thread.start()
-        buffer = ""
-        for new_text in streamer:
-            buffer += new_text
-    return {"response": buffer}
-
 with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
     gr.HTML(TITLE)
     gr.HTML(DESCRIPTION)
     gr.ChatInterface(
         fn=stream_chat,
         multimodal=True,
+
+
         textbox=chat_input,
         chatbot=chatbot,
         fill_height=True,
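Note: this hunk removes the CORS middleware, the `ChatMessage` model, and the blocking `/chat/` REST route (along with the `pydantic`, `typing`, and `logging` imports that supported them), so external callers lose programmatic access to the model. For reference, the request shape the deleted endpoint accepted, per the `ChatMessage` defaults (a sketch; after this commit the route returns 404):

    import requests

    resp = requests.post(
        "http://localhost:8000/chat/",
        json={"text": "Hello", "temperature": 0.8, "max_length": 4096,
              "top_p": 1.0, "top_k": 10, "penalty": 1.0},
    )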
@@ -297,6 +260,5 @@ with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
     gr.Examples(EXAMPLES, [chat_input])
 
 if __name__ == "__main__":
-    demo.queue(api_open=False).launch(show_api=False, share=False)
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+
+    demo.queue(api_open=False).launch(show_api=False, share=False, )#server_name="0.0.0.0", )
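Note: in the old `__main__` block, `uvicorn.run` sat after the blocking `launch()` call and so effectively never executed; dropping it makes that explicit, but it also means the FastAPI `app` defined above is never served. If the intent is to expose both the REST routes and the Gradio UI, one option (a sketch, not part of this commit) is to mount the Blocks app on the FastAPI instance and serve the combined app:

    import uvicorn
    import gradio as gr

    # mount the Gradio UI onto the existing FastAPI app; "/" is an assumed path
    app = gr.mount_gradio_app(app, demo, path="/")

    if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=8000)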
 