"""ColPali query generator: a Gradio Space around Qwen2.5-VL-7B-Instruct.

An uploaded document-page image is sent to the model together with a fixed
instruction prompt; the JSON object in the reply is extracted and rendered as
JSON, Markdown, or a Markdown table.
"""

import importlib.util
import json
import os
import re
import subprocess
import sys
import time  # noqa: F401 -- unused; kept from the original import list
from typing import Tuple, Type

# Install flash-attn at startup with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
# (intended to skip building flash-attn's CUDA kernels from source).
# The current environment is extended rather than replaced: passing only the
# one variable would hide PATH, HOME, virtualenv settings, etc. from pip.
# ``sys.executable -m pip`` installs into the interpreter running this app, and
# an argument list avoids going through a shell. The return code is not
# checked (same as before); a failed install is expected to surface when the
# model is loaded with attn_implementation="flash_attention_2" below.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

# huggingface_hub reads HF_HUB_ENABLE_HF_TRANSFER once, when it is first
# imported (gradio and transformers import it), so the variable must be set
# before the third-party imports below -- setting it afterwards has no effect.
# Only enable it when the optional ``hf_transfer`` package is installed,
# because huggingface_hub refuses to download when the flag is on but the
# package is missing.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import spaces  # kept ahead of torch, as in the original import order
import gradio as gr
import torch
from PIL import Image  # noqa: F401 -- unused; kept from the original import list
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# ----------------------- Model and processor loading ----------------------- #
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# Upper bound on generated tokens per request. The reply is a (usually fenced)
# JSON object with six string fields; if it is cut off at this limit, parsing
# fails and the raw text is shown instead.
MAX_NEW_TOKENS = 200

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)


# ----------------------- Pydantic model definition ----------------------- #
class GeneralRetrievalQuery(BaseModel):
    """Fields of the JSON object the "general" prompt asks the model to return.

    Returned by ``get_retrieval_prompt`` alongside the prompt text; model
    replies are not validated against it in this file.
    """

    broad_topical_query: str
    broad_topical_explanation: str
    specific_detail_query: str
    specific_detail_explanation: str
    visual_element_query: str
    visual_element_explanation: str


# Matches a ``` or ```json fenced block; DOTALL lets the body span lines and
# the lazy ``.+?`` stops at the first closing fence.
_FENCED_JSON_RE = re.compile(r"```(?:json)?\s*(.+?)\s*```", re.DOTALL)


def extract_json_with_regex(text):
    """Return the contents of the first Markdown code fence in *text*.

    Whitespace just inside the fence is dropped. Returns None when *text*
    contains no fenced block.
    """
    match = _FENCED_JSON_RE.search(text)
    return match.group(1) if match else None


def get_retrieval_prompt(prompt_name: str) -> Tuple[str, Type[GeneralRetrievalQuery]]:
    """Return the instruction prompt and the response schema class.

    Args:
        prompt_name: Name of the prompt; only "general" is defined.

    Returns:
        ``(prompt_text, GeneralRetrievalQuery)`` -- the class itself, not an
        instance.

    Raises:
        ValueError: if *prompt_name* is not "general".
    """
    if prompt_name != "general":
        raise ValueError("Only 'general' prompt is available in this version")

    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.

Please generate 3 different types of retrieval queries:

1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present. Don't just reference the name of the visual element but generate a query which this illustration may help answer or be related to.

Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.

For each query, also provide a brief explanation of why this query would be effective in retrieving this document.

Format your response as a JSON object with the following structure:

{
  "broad_topical_query": "Your query here",
  "broad_topical_explanation": "Brief explanation",
  "specific_detail_query": "Your query here",
  "specific_detail_explanation": "Brief explanation",
  "visual_element_query": "Your query here",
  "visual_element_explanation": "Brief explanation"
}

If there are no relevant visual elements, replace the third query with another specific detail query.

Here is the document image to analyze:

Generate the queries based on this image and provide the response in the specified JSON format."""

    return prompt, GeneralRetrievalQuery


prompt, pydantic_model = get_retrieval_prompt("general")


# ----------------------- Input preprocessing ----------------------- #
def _prep_data_for_input(image):
    """Build model inputs for *image* paired with the module-level ``prompt``.

    The image and prompt form a single user chat turn. The chat template is
    applied untokenized with a generation prompt appended, and
    ``process_vision_info`` extracts the image/video inputs for the processor.
    Returns the processor output as PyTorch tensors.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )


# ----------------------- Output formatting ----------------------- #
def _field_label(key) -> str:
    """Turn a snake_case key into a title-cased label ("broad_topical_query" -> "Broad Topical Query")."""
    return key.replace("_", " ").title()


def _table_cell(value) -> str:
    """Make *value* safe for one Markdown table cell.

    A literal ``|`` would start a new column and a newline would end the row,
    so pipes are escaped and newlines flattened to spaces.
    """
    return str(value).replace("|", "\\|").replace("\r", " ").replace("\n", " ")


def format_output(data: dict, output_format: str) -> str:
    """Render the parsed model reply *data* as text.

    Args:
        data: Dictionary parsed from the model's JSON reply.
        output_format: One of "JSON", "Markdown" or "Table"; any other value
            falls back to JSON.

    Returns:
        Pretty-printed JSON, one bold "Label: value" paragraph per field, or a
        two-column (Field / Content) Markdown table.
    """
    if output_format == "Markdown":
        # One Markdown paragraph per field.
        return "\n\n".join(
            f"**{_field_label(key)}:** {value}" for key, value in data.items()
        )
    if output_format == "Table":
        # Simple two-column Markdown table.
        rows = ["| Field | Content |", "|---|---|"]
        rows.extend(
            f"| {_table_cell(_field_label(key))} | {_table_cell(value)} |"
            for key, value in data.items()
        )
        return "\n".join(rows)
    # "JSON" and any unrecognised format.
    return json.dumps(data, indent=2, ensure_ascii=False)


# ----------------------- Response generation ----------------------- #
@spaces.GPU
def generate_response(image, output_format: str = "JSON"):
    """Generate retrieval queries for *image* and format them.

    Args:
        image: PIL image from the upload component (None if nothing uploaded).
        output_format: Passed to ``format_output``.

    Returns:
        The formatted queries when the reply parses as a JSON object;
        otherwise a Gradio warning is shown and the raw model text returned.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload a document image first.")

    inputs = _prep_data_for_input(image).to("cuda")
    generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # Prefer the contents of a ``` fence; fall back to the whole reply.
    candidate = extract_json_with_regex(output_text) or output_text
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        parsed = None
    if not isinstance(parsed, dict):
        # Invalid JSON (e.g. truncated at MAX_NEW_TOKENS) or a non-object value
        # that format_output cannot render field by field.
        gr.Warning("Failed to parse JSON from output")
        return output_text
    return format_output(parsed, output_format)


# ----------------------- Interface title and description ----------------------- #
title = "Elegant ColPali Query Generator using Qwen2.5-VL"
description = """**ColPali**는 문서 검색에 최적화된 멀티모달 접근법입니다.
이 인터페이스는 [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) 모델을 사용하여, 문서 이미지로부터 관련 검색 쿼리를 생성합니다.

- **Broad Topical Query:** 문서의 주요 주제를 포괄하는 쿼리
- **Specific Detail Query:** 문서 내 특정 사실이나 수치를 포함한 쿼리
- **Visual Element Query:** 문서의 시각적 요소(예: 차트, 그래프 등)를 기반으로 한 쿼리

아래 예제를 참고하여, 문서 이미지에 적합한 쿼리를 생성해 보세요.
더 자세한 정보는 [블로그 포스트](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html)를 참조하세요.
"""

examples = [
    "examples/Approche_no_13_1977.pdf_page_22.jpg",
    "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
]

# ----------------------- Custom CSS ----------------------- #
custom_css = """
body {
    background: #f7f9fb;
    font-family: 'Segoe UI', sans-serif;
    color: #333;
}
header {
    text-align: center;
    padding: 20px;
    margin-bottom: 20px;
}
header h1 {
    font-size: 3em;
    color: #2c3e50;
}
.gradio-container {
    padding: 20px;
}
.gr-button {
    background-color: #3498db !important;
    color: #fff !important;
    border: none !important;
    padding: 10px 20px !important;
    border-radius: 5px !important;
    font-size: 1em !important;
}
.gr-button:hover {
    background-color: #2980b9 !important;
}
.gr-gallery-item {
    border-radius: 10px;
    overflow: hidden;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
footer {
    text-align: center;
    padding: 20px 0;
    font-size: 0.9em;
    color: #555;
}
"""

# ----------------------- Gradio interface ----------------------- #
with gr.Blocks(css=custom_css, title=title) as demo:
    with gr.Column(variant="panel"):
        # NOTE(review): the markup around the title was lost when this file was
        # flattened; <header>/<h1> is a reconstruction matching the "header"
        # and "header h1" selectors in custom_css -- confirm.
        gr.Markdown(f"<header><h1>{title}</h1></header>")
        gr.Markdown(description)

    with gr.Tabs():
        with gr.TabItem("Query Generation"):
            gr.Markdown("### Generate Retrieval Queries from a Document Image")
            with gr.Row():
                image_input = gr.Image(label="Upload Document Image", type="pil")
            with gr.Row():
                # Output-format selector, passed through to generate_response.
                output_format = gr.Radio(
                    choices=["JSON", "Markdown", "Table"],
                    value="JSON",
                    label="Output Format",
                    info="Select the desired output format.",
                )
            generate_button = gr.Button("Generate Query")
            output_text = gr.Textbox(label="Generated Query", lines=10)
            with gr.Accordion("Examples", open=False):
                gr.Examples(
                    label="Query Examples",
                    examples=examples,
                    inputs=image_input,
                )
            generate_button.click(
                fn=generate_response,
                inputs=[image_input, output_format],
                outputs=output_text,
            )

    # NOTE(review): footer content appears to have been lost when this file
    # was flattened (custom_css styles a "footer" element); left empty as found.
    gr.Markdown("")

demo.launch()