mateoluksenberg committed on
Commit ba0edd3 · verified · 1 Parent(s): 1f1f572

Update app.py

Files changed (1)
  1. app.py +183 -50
app.py CHANGED
@@ -1,32 +1,60 @@
import torch
from PIL import Image
+ import gradio as gr
+ import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
+ from threading import Thread
+
import pymupdf
import docx
from pptx import Presentation
+
from fastapi import FastAPI, File, UploadFile, HTTPException
- from typing import List, Dict
+ from fastapi.responses import HTMLResponse

app = FastAPI()

- # Model and tokenizer initialization
+ @app.post("/test/")
+ async def test_endpoint(message: dict):
+     if "text" not in message:
+         raise HTTPException(status_code=400, detail="Missing 'text' in request body")
+
+     response = {"message": f"Received your message: {message['text']}"}
+     return response
+
+
MODEL_LIST = ["nikravan/glm-4vq"]
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4vq"

+ TITLE = "<h1>AI CHAT DOCS</h1>"
+
+ DESCRIPTION = f"""
+ <center>
+ <p>
+ <br>
+ USANDO MODELO: <a href="https://hf.co/nikravan/glm-4vq">{MODEL_NAME}</a>
+ </center>"""
+
+ CSS = """
+ h1 {
+     text-align: center;
+     display: block;
+ }
+ """
+
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_ID,
-     torch_dtype=torch.bfloat16,
-     low_cpu_mem_usage=True,
-     trust_remote_code=True
- )
+
+

def extract_text(path):
    return open(path, 'r').read()

+
def extract_pdf(path):
    doc = pymupdf.open(path)
    text = ""
@@ -34,10 +62,15 @@ def extract_pdf(path):
        text += page.get_text()
    return text

+
def extract_docx(path):
    doc = docx.Document(path)
-     data = [paragraph.text for paragraph in doc.paragraphs]
-     return '\n\n'.join(data)
+     data = []
+     for paragraph in doc.paragraphs:
+         data.append(paragraph.text)
+     content = '\n\n'.join(data)
+     return content
+

def extract_pptx(path):
    prs = Presentation(path)
@@ -48,44 +81,49 @@ def extract_pptx(path):
                text += shape.text + "\n"
    return text

+
def mode_load(path):
-     file_type = path.split(".")[-1].lower()
-     if file_type in ["pdf", "txt", "py", "docx", "pptx"]:
-         if file_type == "pdf":
+     choice = ""
+     file_type = path.split(".")[-1]
+     print(file_type)
+     if file_type in ["pdf", "txt", "py", "docx", "pptx", "json", "cpp", "md"]:
+         if file_type.endswith("pdf"):
            content = extract_pdf(path)
-         elif file_type == "docx":
+         elif file_type.endswith("docx"):
            content = extract_docx(path)
-         elif file_type == "pptx":
+         elif file_type.endswith("pptx"):
            content = extract_pptx(path)
        else:
            content = extract_text(path)
-         return "doc", content[:5000]
+         choice = "doc"
+         print(content[:100])
+         return choice, content[:5000]
+
+
    elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
        content = Image.open(path).convert('RGB')
-         return "image", content
+         choice = "image"
+         return choice, content
+
    else:
-         raise HTTPException(status_code=400, detail="Unsupported file type")
+         raise gr.Error("Oops, unsupported files.")

- @app.post("/test/")
- async def test_endpoint(message: Dict[str, str]):
-     if "text" not in message:
-         raise HTTPException(status_code=400, detail="Missing 'text' in request body")
-
-     response = {"message": f"Received your message: {message['text']}"}
-     return response

- @app.post("/chat/")
- async def chat_endpoint(
-     message: Dict[str, str],
-     history: List[Dict[str, str]] = [],
-     temperature: float = 0.8,
-     max_length: int = 4096,
-     top_p: float = 1.0,
-     top_k: int = 10,
-     penalty: float = 1.0
- ):
+ @spaces.GPU()
+ def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
+
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         torch_dtype=torch.bfloat16,
+         low_cpu_mem_usage=True,
+         trust_remote_code=True
+     )
+
+     print(f'message is - {message}')
+     print(f'history is - {history}')
    conversation = []
-     if "files" in message and message["files"]:
+     prompt_files = []
+     if message["files"]:
        choice, contents = mode_load(message["files"][-1])
        if choice == "image":
            conversation.append({"role": "user", "image": contents, "content": message['text']})
@@ -94,26 +132,35 @@ async def chat_endpoint
            conversation.append({"role": "user", "content": format_msg})
    else:
        if len(history) == 0:
+             # raise gr.Error("Please upload an image first.")
+             contents = None
            conversation.append({"role": "user", "content": message['text']})
        else:
+             # image = Image.open(history[0][0][0])
            for prompt, answer in history:
                if answer is None:
+                     prompt_files.append(prompt[0])
                    conversation.extend([{"role": "user", "content": ""}, {"role": "assistant", "content": ""}])
                else:
                    conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
-             if len(history) > 0:
-                 choice, contents = mode_load(history[-1][0])
-                 if choice == "image":
-                     conversation.append({"role": "user", "image": contents, "content": message['text']})
-                 elif choice == "doc":
-                     format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
-                     conversation.append({"role": "user", "content": format_msg})
-                 else:
-                     conversation.append({"role": "user", "content": message['text']})
+             if len(prompt_files) > 0:
+                 choice, contents = mode_load(prompt_files[-1])
+             else:
+                 choice = ""
+                 conversation.append({"role": "user", "image": "", "content": message['text']})

-     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
+
+     if choice == "image":
+         conversation.append({"role": "user", "image": contents, "content": message['text']})
+     elif choice == "doc":
+         format_msg = contents + "\n\n\n" + "{} files uploaded.\n" + message['text']
+         conversation.append({"role": "user", "content": format_msg})
+     print(f"Conversation is -\n{conversation}")
+
+     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True,
+                                               return_tensors="pt", return_dict=True).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
+
    generate_kwargs = dict(
        max_length=max_length,
        streamer=streamer,
@@ -121,11 +168,97 @@ async def chat_endpoint
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
-         repetition_penalty=penalty
+         repetition_penalty=penalty,
+         eos_token_id=[151329, 151336, 151338],
    )
-
+     gen_kwargs = {**input_ids, **generate_kwargs}
+
    with torch.no_grad():
+         thread = Thread(target=model.generate, kwargs=gen_kwargs)
+         thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
-     return {"response": buffer}
+             yield buffer
+
+
+ chatbot = gr.Chatbot(
+     #rtl=True,
+ )
+ chat_input = gr.MultimodalTextbox(
+     interactive=True,
+     placeholder="Enter message or upload a file ...",
+     show_label=False,
+     #rtl=True,
+
+
+
+ )
+
+ EXAMPLES = [
+     [{"text": "Resumir Documento"}],
+     [{"text": "Explicar la Imagen"}],
+     [{"text": "¿De qué es la foto?", "files": ["perro.jpg"]}],
+     [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores numéricos son decimales de cuatro dígitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vacíos, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicción con sus respectivos valores.", }]
+ ]
+
+ with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
+     gr.HTML(TITLE)
+     gr.HTML(DESCRIPTION)
+     gr.ChatInterface(
+         fn=stream_chat,
+         multimodal=True,
+
+
+         textbox=chat_input,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Slider(
+                 minimum=0,
+                 maximum=1,
+                 step=0.1,
+                 value=0.8,
+                 label="Temperature",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=1024,
+                 maximum=8192,
+                 step=1,
+                 value=4096,
+                 label="Max Length",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.1,
+                 value=1.0,
+                 label="top_p",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=1,
+                 maximum=20,
+                 step=1,
+                 value=10,
+                 label="top_k",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=0.0,
+                 maximum=2.0,
+                 step=0.1,
+                 value=1.0,
+                 label="Repetition penalty",
+                 render=False,
+             ),
+         ],
+     ),
+     gr.Examples(EXAMPLES, [chat_input])
+
+ if __name__ == "__main__":
+
+     demo.queue(api_open=False).launch(show_api=False, share=False, )#server_name="0.0.0.0", )
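Note: because the interface is built with gr.ChatInterface(multimodal=True) and a gr.MultimodalTextbox, the message argument passed to stream_chat is a dict with a "text" field and a "files" list of uploaded file paths. A minimal sketch of that shape (the filename "perro.jpg" is just the placeholder already used in EXAMPLES above):

message = {
    "text": "¿De qué es la foto?",  # prompt typed by the user
    "files": ["perro.jpg"],         # paths of uploaded files; may be empty
}
# stream_chat routes the most recent upload through mode_load():
# choice, contents = mode_load(message["files"][-1])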
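stream_chat streams tokens by running model.generate() on a background thread while the caller iterates a TextIteratorStreamer. A self-contained sketch of that pattern, using "sshleifer/tiny-gpt2" purely as a stand-in checkpoint (the Space itself loads "nikravan/glm-4vq" with trust_remote_code=True):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "sshleifer/tiny-gpt2"  # stand-in checkpoint for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, world", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until decoding finishes, so it runs in a worker thread
# while this thread consumes partial text from the streamer as it arrives.
thread = Thread(target=model.generate, kwargs={**inputs, "max_new_tokens": 32, "streamer": streamer})
thread.start()

buffer = ""
for new_text in streamer:
    buffer += new_text  # a generator would yield buffer here, as stream_chat does
print(buffer)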
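The /test/ route kept at the top of the file expects a JSON body with a "text" field. A request sketch, assuming the FastAPI app is served separately (for example with uvicorn app:app on port 8000), since demo.launch() at the bottom only starts the Gradio interface:

import requests

resp = requests.post(
    "http://localhost:8000/test/",  # assumed host/port for a local uvicorn run
    json={"text": "hello"},
)
print(resp.json())  # -> {"message": "Received your message: hello"}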