ColPali-multi / app.py
ginipick's picture
Update app.py
45b6f79 verified
raw
history blame
9.86 kB
import os
import subprocess
import sys

# Best-effort install of flash-attn at startup (the return code is deliberately
# not checked, so a failed install does not abort the app here).
# The child environment is the current environment plus the skip flag:
# passing a dict with only the flag would drop PATH, HOME and every other
# inherited variable for the pip process.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD is presumably read by
# flash-attn's build script to skip the local CUDA compile — confirm upstream.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
import spaces
import gradio as gr
import re
import torch
import os
import json
import time
from pydantic import BaseModel
from typing import Tuple
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
# Ask the Hugging Face Hub client to use hf_transfer for downloads.
# NOTE(review): this assignment runs after `transformers` has already been
# imported above; confirm the Hub client still reads the variable this late.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# ----------------------- Load model and processor ----------------------- #
# Single checkpoint id shared by the weights and their matching processor.
_CHECKPOINT = "Qwen/Qwen2.5-VL-7B-Instruct"

# bfloat16 weights with the FlashAttention-2 attention implementation;
# device placement is delegated to `device_map="auto"`.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    _CHECKPOINT,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(_CHECKPOINT)
# ----------------------- Pydantic model definition ----------------------- #
class GeneralRetrievalQuery(BaseModel):
    """Schema for the six string fields the retrieval prompt asks the model to emit.

    Each query field is paired with an explanation field; the field names
    match the keys of the JSON template embedded in the "general" prompt.
    """

    # Query covering the main subject of the document, plus its rationale.
    broad_topical_query: str
    broad_topical_explanation: str
    # Query focused on a particular fact or figure, plus its rationale.
    specific_detail_query: str
    specific_detail_explanation: str
    # Query tied to a chart/graph/image on the page, plus its rationale.
    visual_element_query: str
    visual_element_explanation: str
def extract_json_with_regex(text):
    """Return the body of the first ``` / ```json fenced block in *text*.

    Whitespace directly inside the fences is stripped by the pattern.
    Returns None when *text* contains no fenced block.
    """
    fenced = re.search(r'```(?:json)?\s*(.+?)\s*```', text, re.DOTALL)
    return fenced.group(1) if fenced is not None else None
def get_retrieval_prompt(prompt_name: str) -> Tuple[str, type[GeneralRetrievalQuery]]:
    """Return the prompt text and the response schema class for *prompt_name*.

    Args:
        prompt_name: Name of the prompt to fetch; only "general" is supported.

    Returns:
        A ``(prompt, schema)`` pair where ``schema`` is the
        ``GeneralRetrievalQuery`` class itself (not an instance).

    Raises:
        ValueError: If *prompt_name* is anything other than "general".
    """
    if prompt_name != "general":
        raise ValueError("Only 'general' prompt is available in this version")
    # The literal below is sent to the model verbatim, so its lines stay
    # flush-left rather than following the function's indentation.
    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.
Please generate 3 different types of retrieval queries:
1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present. Don't just reference the name of the visual element but generate a query which this illustration may help answer or be related to.
Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.
For each query, also provide a brief explanation of why this query would be effective in retrieving this document.
Format your response as a JSON object with the following structure:
{
"broad_topical_query": "Your query here",
"broad_topical_explanation": "Brief explanation",
"specific_detail_query": "Your query here",
"specific_detail_explanation": "Brief explanation",
"visual_element_query": "Your query here",
"visual_element_explanation": "Brief explanation"
}
If there are no relevant visual elements, replace the third query with another specific detail query.
Here is the document image to analyze:
<image>
Generate the queries based on this image and provide the response in the specified JSON format."""
    return prompt, GeneralRetrievalQuery
# Build the prompt once at import time; `_prep_data_for_input` reads the
# module-level `prompt`. `pydantic_model` is not referenced anywhere else in
# the visible part of this file.
prompt, pydantic_model = get_retrieval_prompt("general")
# ----------------------- Input data preprocessing ----------------------- #
def _prep_data_for_input(image):
    """Build the processor inputs for one user turn: *image* plus the prompt.

    The image and the module-level `prompt` are wrapped in a single-message
    chat, rendered through the processor's chat template, and then passed to
    the processor together with the extracted vision inputs. Returns whatever
    the processor call returns (requested as PyTorch tensors).
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
    conversation = [user_turn]

    # Render to a string (not token ids) and append the generation prompt.
    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)

    model_inputs = processor(
        text=[chat_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
    return model_inputs
# ----------------------- Output format conversion ----------------------- #
def _table_cell(value) -> str:
    """Return *value* as text that is safe inside one Markdown table cell."""
    # A raw newline would end the table row and a raw pipe would open a new
    # column, so flatten the former and escape the latter.
    flattened = " ".join(str(value).splitlines())
    return flattened.replace("|", "\\|")


def format_output(data: dict, output_format: str) -> str:
    """Render parsed model output as JSON, Markdown paragraphs, or a Markdown table.

    Args:
        data: Parsed JSON payload, normally a dict of field name -> text.
        output_format: One of "JSON", "Markdown" or "Table". Any other value
            falls back to JSON.

    Returns:
        The formatted text. Payloads that are not dicts (for example a JSON
        list) cannot be rendered field by field and are always dumped as JSON.
    """
    if output_format not in ("Markdown", "Table") or not isinstance(data, dict):
        return json.dumps(data, indent=2, ensure_ascii=False)

    # "broad_topical_query" -> "Broad Topical Query"
    labelled = [
        (str(key).replace("_", " ").title(), value) for key, value in data.items()
    ]

    if output_format == "Markdown":
        # One bold-labelled paragraph per field.
        return "\n\n".join(f"**{label}:** {value}" for label, value in labelled)

    # Simple two-column Markdown table.
    rows = ["| Field | Content |", "|---|---|"]
    rows.extend(
        f"| {_table_cell(label)} | {_table_cell(value)} |" for label, value in labelled
    )
    return "\n".join(rows)
# ----------------------- Response generation ----------------------- #
@spaces.GPU
def generate_response(image, output_format: str = "JSON", max_new_tokens: int = 200):
    """Generate retrieval queries for a document image and format them.

    Args:
        image: The uploaded document image (the UI passes a PIL image).
        output_format: "JSON", "Markdown" or "Table"; forwarded to
            `format_output`.
        max_new_tokens: Generation budget passed to `model.generate`.
            NOTE(review): 200 may be tight for six JSON fields — confirm that
            responses are not being truncated before the closing fence.

    Returns:
        The formatted queries, or the raw decoded model text when it cannot be
        parsed/formatted as JSON (a Gradio warning is shown in that case).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message instead of an obscure preprocessing error.
        raise gr.Error("Please upload a document image first.")

    inputs = _prep_data_for_input(image).to("cuda")
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Strip the prompt-length prefix from each sequence so that only the newly
    # generated tokens are decoded.
    completion_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Prefer the fenced ```json block; otherwise try the whole response.
    candidate = extract_json_with_regex(output_text) or output_text
    try:
        return format_output(json.loads(candidate), output_format)
    except (AttributeError, TypeError, ValueError):
        # json.JSONDecodeError is a ValueError; AttributeError/TypeError cover
        # a payload that parsed but is not shaped like the expected object.
        gr.Warning("Failed to parse JSON from output")
        return output_text
# ----------------------- Interface title and description ----------------------- #
# Page title, also rendered as the <h1> header of the UI.
title = "Elegant ColPali Query Generator using Qwen2.5-VL"

# User-facing Markdown (Korean). The previous text was mojibake — UTF-8 bytes
# decoded with the wrong code page — and has been restored to readable Korean.
description = """**ColPali**는 문서 검색에 최적화된 멀티모달 접근법입니다.
이 인터페이스는 [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) 모델을 사용하여, 문서 이미지로부터 관련 검색 쿼리를 생성합니다.
- **Broad Topical Query:** 문서의 주요 주제를 포괄하는 쿼리
- **Specific Detail Query:** 문서 내 특정 사실이나 수치를 포함한 쿼리
- **Visual Element Query:** 문서의 시각적 요소(예: 차트, 그래프 등)를 기반으로 한 쿼리
아래 예제를 참고하여, 문서 이미지에 적합한 쿼리를 생성해 보세요.
더 자세한 정보는 [블로그 포스트](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html)를 참조하세요.
"""

# Sample document pages bundled with the app.
examples = [
    "examples/Approche_no_13_1977.pdf_page_22.jpg",
    "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
]
# ----------------------- Custom CSS ----------------------- #
# Stylesheet handed to `gr.Blocks(css=...)` below: page background and font,
# the centered header/footer, button colors, and gallery item styling.
custom_css = """
body {
background: #f7f9fb;
font-family: 'Segoe UI', sans-serif;
color: #333;
}
header {
text-align: center;
padding: 20px;
margin-bottom: 20px;
}
header h1 {
font-size: 3em;
color: #2c3e50;
}
.gradio-container {
padding: 20px;
}
.gr-button {
background-color: #3498db !important;
color: #fff !important;
border: none !important;
padding: 10px 20px !important;
border-radius: 5px !important;
font-size: 1em !important;
}
.gr-button:hover {
background-color: #2980b9 !important;
}
.gr-gallery-item {
border-radius: 10px;
overflow: hidden;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
footer {
text-align: center;
padding: 20px 0;
font-size: 0.9em;
color: #555;
}
"""
# ----------------------- Gradio interface layout ----------------------- #
# Single-tab UI: upload a page image, pick an output format, generate queries.
with gr.Blocks(css=custom_css, title=title) as demo:
    with gr.Column(variant="panel"):
        gr.Markdown(f"<header><h1>{title}</h1></header>")
        gr.Markdown(description)
        with gr.Tabs():
            with gr.TabItem("Query Generation"):
                gr.Markdown("### Generate Retrieval Queries from a Document Image")
                with gr.Row():
                    image_input = gr.Image(label="Upload Document Image", type="pil")
                with gr.Row():
                    # Output format selector.
                    output_format = gr.Radio(
                        choices=["JSON", "Markdown", "Table"],
                        value="JSON",
                        label="Output Format",
                        info="Select the desired output format.",
                    )
                generate_button = gr.Button("Generate Query")
                output_text = gr.Textbox(label="Generated Query", lines=10)
                with gr.Accordion("Examples", open=False):
                    # Reuse the module-level `examples` list instead of
                    # repeating the same paths inline.
                    gr.Examples(
                        label="Query Examples",
                        examples=examples,
                        inputs=image_input,
                    )
                generate_button.click(
                    fn=generate_response,
                    inputs=[image_input, output_format],
                    outputs=output_text,
                )
        gr.Markdown("<footer>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")

demo.launch()