hantech commited on
Commit
c8a9051
·
verified ·
1 Parent(s): f7fe59b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -22
app.py CHANGED
@@ -7,21 +7,20 @@ from pdf2image import convert_from_path
7
  from PIL import Image
8
  from torch.utils.data import DataLoader
9
  from tqdm import tqdm
10
- import os
11
 
 
 
 
12
  if not os.path.exists('/tmp/gradio'):
13
  os.makedirs('/tmp/gradio')
14
-
15
- from colpali_engine.models import ColQwen2, ColQwen2Processor
16
 
17
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
18
 
19
  model = ColQwen2.from_pretrained(
20
- "manu/colqwen2-v1.0-alpha",
21
- torch_dtype=torch.bfloat16,
22
- device_map=device, # or "mps" if on Apple Silicon
23
- # attn_implementation="flash_attention_2", # should work on A100
24
- ).eval()
25
  processor = ColQwen2Processor.from_pretrained("manu/colqwen2-v1.0-alpha")
26
 
27
  def search(query: str, ds, images, k):
@@ -49,14 +48,16 @@ def search(query: str, ds, images, k):
49
 
50
  return results
51
 
52
-
53
  def index(files, ds):
 
 
 
54
  print("Converting files")
55
  images = convert_files(files)
56
  print(f"Files converted with {len(images)} images.")
57
- return index_gpu(images, ds)
58
-
59
-
60
 
61
  def convert_files(files):
62
  images = []
@@ -67,15 +68,11 @@ def convert_files(files):
67
  raise gr.Error("The number of images in the dataset should be less than 150.")
68
  return images
69
 
70
-
71
  def index_gpu(images, ds):
72
- """Example script to run inference with ColPali (ColQwen2)"""
73
-
74
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
75
  if device != model.device:
76
  model.to(device)
77
 
78
- # run inference - docs
79
  dataloader = DataLoader(
80
  images,
81
  batch_size=4,
@@ -90,17 +87,15 @@ def index_gpu(images, ds):
90
  ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
91
  return f"Uploaded and converted {len(images)} pages", ds, images
92
 
93
-
94
-
95
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
96
  gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
97
  gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
98
- ColPali is model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
99
 
100
  This demo allows you to upload PDF files and search for the most relevant pages based on your query.
101
  Refresh the page if you change documents !
102
 
103
- ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing english text. Performance is expected to drop for other page formats and languages.
104
  Other models will be released with better robustness towards different languages and document formats !
105
  """)
106
  with gr.Row():
@@ -118,12 +113,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
118
  query = gr.Textbox(placeholder="Enter your query here", label="Query")
119
  k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
120
 
121
-
122
  # Define the actions
 
123
  search_button = gr.Button("🔍 Search", variant="primary")
124
  output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
125
 
126
- convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
127
  search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
128
 
129
  if __name__ == "__main__":
 
7
  from PIL import Image
8
  from torch.utils.data import DataLoader
9
  from tqdm import tqdm
 
10
 
11
+ from colpali_engine.models import ColQwen2, ColQwen2Processor
12
+
13
+ # Ensure the temporary directory exists
14
  if not os.path.exists('/tmp/gradio'):
15
  os.makedirs('/tmp/gradio')
 
 
16
 
17
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
18
 
19
  model = ColQwen2.from_pretrained(
20
+ "manu/colqwen2-v1.0-alpha",
21
+ torch_dtype=torch.bfloat16,
22
+ device_map=device,
23
+ ).eval()
 
24
  processor = ColQwen2Processor.from_pretrained("manu/colqwen2-v1.0-alpha")
25
 
26
  def search(query: str, ds, images, k):
 
48
 
49
  return results
50
 
 
51
  def index(files, ds):
52
+ if not files:
53
+ return gr.Error("No files uploaded. Please upload PDF files to index."), ds, []
54
+
55
  print("Converting files")
56
  images = convert_files(files)
57
  print(f"Files converted with {len(images)} images.")
58
+ status_message, ds, images = index_gpu(images, ds)
59
+ print(f"Indexed {len(ds)} embeddings.")
60
+ return status_message, ds, images
61
 
62
  def convert_files(files):
63
  images = []
 
68
  raise gr.Error("The number of images in the dataset should be less than 150.")
69
  return images
70
 
 
71
  def index_gpu(images, ds):
 
 
72
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
73
  if device != model.device:
74
  model.to(device)
75
 
 
76
  dataloader = DataLoader(
77
  images,
78
  batch_size=4,
 
87
  ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
88
  return f"Uploaded and converted {len(images)} pages", ds, images
89
 
 
 
90
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
91
  gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
92
  gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
93
+ ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
94
 
95
  This demo allows you to upload PDF files and search for the most relevant pages based on your query.
96
  Refresh the page if you change documents !
97
 
98
+ ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
99
  Other models will be released with better robustness towards different languages and document formats !
100
  """)
101
  with gr.Row():
 
113
  query = gr.Textbox(placeholder="Enter your query here", label="Query")
114
  k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
115
 
 
116
  # Define the actions
117
+ convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
118
  search_button = gr.Button("🔍 Search", variant="primary")
119
  output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
120
 
 
121
  search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
122
 
123
  if __name__ == "__main__":