mateoluksenberg commited on
Commit
9ebe610
verified
1 Parent(s): aba71d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -71
app.py CHANGED
@@ -5,8 +5,6 @@ import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import os
7
  from threading import Thread
8
- from fastapi import FastAPI
9
- import uvicorn
10
 
11
  import pymupdf
12
  import docx
@@ -35,11 +33,15 @@ h1 {
35
  }
36
  """
37
 
 
38
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
39
 
 
 
40
  def extract_text(path):
41
  return open(path, 'r').read()
42
 
 
43
  def extract_pdf(path):
44
  doc = pymupdf.open(path)
45
  text = ""
@@ -47,6 +49,7 @@ def extract_pdf(path):
47
  text += page.get_text()
48
  return text
49
 
 
50
  def extract_docx(path):
51
  doc = docx.Document(path)
52
  data = []
@@ -55,6 +58,7 @@ def extract_docx(path):
55
  content = '\n\n'.join(data)
56
  return content
57
 
 
58
  def extract_pptx(path):
59
  prs = Presentation(path)
60
  text = ""
@@ -64,6 +68,7 @@ def extract_pptx(path):
64
  text += shape.text + "\n"
65
  return text
66
 
 
67
  def mode_load(path):
68
  choice = ""
69
  file_type = path.split(".")[-1]
@@ -80,15 +85,20 @@ def mode_load(path):
80
  choice = "doc"
81
  print(content[:100])
82
  return choice, content[:5000]
 
 
83
  elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
84
  content = Image.open(path).convert('RGB')
85
  choice = "image"
86
  return choice, content
 
87
  else:
88
  raise gr.Error("Oops, unsupported files.")
89
 
 
90
  @spaces.GPU()
91
  def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
 
92
  model = AutoModelForCausalLM.from_pretrained(
93
  MODEL_ID,
94
  torch_dtype=torch.bfloat16,
@@ -126,6 +136,7 @@ def stream_chat(message, history: list, temperature: float, max_length: int, top
126
  choice = ""
127
  conversation.append({"role": "user", "image": "", "content": message['text']})
128
 
 
129
  if choice == "image":
130
  conversation.append({"role": "user", "image": contents, "content": message['text']})
131
  elif choice == "doc":
@@ -157,11 +168,18 @@ def stream_chat(message, history: list, temperature: float, max_length: int, top
157
  buffer += new_text
158
  yield buffer
159
 
160
- chatbot = gr.Chatbot()
 
 
 
161
  chat_input = gr.MultimodalTextbox(
162
  interactive=True,
163
  placeholder="Enter message or upload a file ...",
164
  show_label=False,
 
 
 
 
165
  )
166
 
167
  EXAMPLES = [
@@ -171,73 +189,63 @@ EXAMPLES = [
171
  [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores num茅ricos son decimales de cuatro d铆gitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vac铆os, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicci贸n con sus respectivos valores.", }]
172
  ]
173
 
174
- app = FastAPI()
175
-
176
- def test():
177
- return "Funci贸n test llamada con 茅xito"
178
-
179
- @app.get("/test")
180
- def call_test():
181
- return {"message": test()}
182
-
183
- def run_gradio():
184
- with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
185
- gr.HTML(TITLE)
186
- gr.HTML(DESCRIPTION)
187
- gr.ChatInterface(
188
- fn=stream_chat,
189
- multimodal=True,
190
- textbox=chat_input,
191
- chatbot=chatbot,
192
- fill_height=True,
193
- additional_inputs_accordion=gr.Accordion(label="鈿欙笍 Parameters", open=False, render=False),
194
- additional_inputs=[
195
- gr.Slider(
196
- minimum=0,
197
- maximum=1,
198
- step=0.1,
199
- value=0.8,
200
- label="Temperature",
201
- render=False,
202
- ),
203
- gr.Slider(
204
- minimum=1024,
205
- maximum=8192,
206
- step=1,
207
- value=4096,
208
- label="Max Length",
209
- render=False,
210
- ),
211
- gr.Slider(
212
- minimum=0.0,
213
- maximum=1.0,
214
- step=0.1,
215
- value=1.0,
216
- label="top_p",
217
- render=False,
218
- ),
219
- gr.Slider(
220
- minimum=1,
221
- maximum=20,
222
- step=1,
223
- value=10,
224
- label="top_k",
225
- render=False,
226
- ),
227
- gr.Slider(
228
- minimum=0.0,
229
- maximum=2.0,
230
- step=0.1,
231
- value=1.0,
232
- label="Repetition penalty",
233
- render=False,
234
- ),
235
- ],
236
- ),
237
- gr.Examples(EXAMPLES, [chat_input])
238
- demo.queue(api_open=False).launch(show_api=False, share=False)
239
 
240
  if __name__ == "__main__":
241
- gradio_thread = Thread(target=run_gradio)
242
- gradio_thread.start()
243
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import os
7
  from threading import Thread
 
 
8
 
9
  import pymupdf
10
  import docx
 
33
  }
34
  """
35
 
36
+
37
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
38
 
39
+
40
+
41
  def extract_text(path):
42
  return open(path, 'r').read()
43
 
44
+
45
  def extract_pdf(path):
46
  doc = pymupdf.open(path)
47
  text = ""
 
49
  text += page.get_text()
50
  return text
51
 
52
+
53
  def extract_docx(path):
54
  doc = docx.Document(path)
55
  data = []
 
58
  content = '\n\n'.join(data)
59
  return content
60
 
61
+
62
  def extract_pptx(path):
63
  prs = Presentation(path)
64
  text = ""
 
68
  text += shape.text + "\n"
69
  return text
70
 
71
+
72
  def mode_load(path):
73
  choice = ""
74
  file_type = path.split(".")[-1]
 
85
  choice = "doc"
86
  print(content[:100])
87
  return choice, content[:5000]
88
+
89
+
90
  elif file_type in ["png", "jpg", "jpeg", "bmp", "tiff", "webp"]:
91
  content = Image.open(path).convert('RGB')
92
  choice = "image"
93
  return choice, content
94
+
95
  else:
96
  raise gr.Error("Oops, unsupported files.")
97
 
98
+
99
  @spaces.GPU()
100
  def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
101
+
102
  model = AutoModelForCausalLM.from_pretrained(
103
  MODEL_ID,
104
  torch_dtype=torch.bfloat16,
 
136
  choice = ""
137
  conversation.append({"role": "user", "image": "", "content": message['text']})
138
 
139
+
140
  if choice == "image":
141
  conversation.append({"role": "user", "image": contents, "content": message['text']})
142
  elif choice == "doc":
 
168
  buffer += new_text
169
  yield buffer
170
 
171
+
172
+ chatbot = gr.Chatbot(
173
+ #rtl=True,
174
+ )
175
  chat_input = gr.MultimodalTextbox(
176
  interactive=True,
177
  placeholder="Enter message or upload a file ...",
178
  show_label=False,
179
+ #rtl=True,
180
+
181
+
182
+
183
  )
184
 
185
  EXAMPLES = [
 
189
  [{"text": "Quiero armar un JSON, solo el JSON sin texto, que contenga los datos de la primera mitad de la tabla de la imagen (las primeras 10 jurisdicciones 901-910). Ten en cuenta que los valores num茅ricos son decimales de cuatro d铆gitos. La tabla contiene las siguientes columnas: Codigo, Nombre, Fecha Inicio, Fecha Cese, Coeficiente Ingresos, Coeficiente Gastos y Coeficiente Unificado. La tabla puede contener valores vac铆os, en ese caso dejarlos como null. Cada fila de la tabla representa una jurisdicci贸n con sus respectivos valores.", }]
190
  ]
191
 
192
+ with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
193
+ gr.HTML(TITLE)
194
+ gr.HTML(DESCRIPTION)
195
+ gr.ChatInterface(
196
+ fn=stream_chat,
197
+ multimodal=True,
198
+
199
+
200
+ textbox=chat_input,
201
+ chatbot=chatbot,
202
+ fill_height=True,
203
+ additional_inputs_accordion=gr.Accordion(label="鈿欙笍 Parameters", open=False, render=False),
204
+ additional_inputs=[
205
+ gr.Slider(
206
+ minimum=0,
207
+ maximum=1,
208
+ step=0.1,
209
+ value=0.8,
210
+ label="Temperature",
211
+ render=False,
212
+ ),
213
+ gr.Slider(
214
+ minimum=1024,
215
+ maximum=8192,
216
+ step=1,
217
+ value=4096,
218
+ label="Max Length",
219
+ render=False,
220
+ ),
221
+ gr.Slider(
222
+ minimum=0.0,
223
+ maximum=1.0,
224
+ step=0.1,
225
+ value=1.0,
226
+ label="top_p",
227
+ render=False,
228
+ ),
229
+ gr.Slider(
230
+ minimum=1,
231
+ maximum=20,
232
+ step=1,
233
+ value=10,
234
+ label="top_k",
235
+ render=False,
236
+ ),
237
+ gr.Slider(
238
+ minimum=0.0,
239
+ maximum=2.0,
240
+ step=0.1,
241
+ value=1.0,
242
+ label="Repetition penalty",
243
+ render=False,
244
+ ),
245
+ ],
246
+ ),
247
+ gr.Examples(EXAMPLES, [chat_input])
 
 
 
 
 
 
 
 
 
248
 
249
  if __name__ == "__main__":
250
+
251
+ demo.queue(api_open=False).launch(show_api=False, share=False, )#server_name="0.0.0.0", )