	Update models/qwen.py

models/qwen.py    CHANGED    +48 -33

@@ -3,31 +3,21 @@ import json
 from PIL import Image
 import torch
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
-from qwen_vl_utils import process_vision_info
 from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
+from qwen_vl_utils import process_vision_info
 
-# …
-…
-…
-…
-…
-    #attn_implementation="flash_attention_2"
-)
-
-min_pixels = 256 * 28 * 28
-max_pixels = 1080 * 28 * 28
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+# Globals (lazy-loaded at runtime)
+qwen_model = None
+qwen_processor = None
+ocr_model = None
+ocr_processor = None
 
-# Initialize Pix2Struct OCR model
-ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
-ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
 
-# Load prompt
 def load_prompt():
-    with open("prompts/prompt.txt", "r") as f:
+    with open("prompts/prompt.txt", "r", encoding="utf-8") as f:
         return f.read()
 
-
+
 def try_extract_json(text):
     try:
         return json.loads(text)
@@ -50,20 +40,25 @@ def try_extract_json(text):
         except json.JSONDecodeError:
             return text
 
-
+
 def extract_all_text_pix2struct(image: Image.Image):
-    …
+    global ocr_model, ocr_processor
+
+    if ocr_model is None or ocr_processor is None:
+        ocr_processor = Pix2StructProcessor.from_pretrained("google/pix2struct-textcaps-base")
+        ocr_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        ocr_model = ocr_model.to(device)
+
+    inputs = ocr_processor(images=image, return_tensors="pt").to(ocr_model.device)
     predictions = ocr_model.generate(**inputs, max_new_tokens=512)
-    …
-
+    return ocr_processor.decode(predictions[0], skip_special_tokens=True).strip()
+
 
-# Assign event/gateway names from OCR text
 def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
     if not ocr_text or not json_data:
         return json_data
 
-    lines = [line.strip() for line in ocr_text.split('\n') if line.strip()]
-
     def assign_best_guess(obj):
         if not obj.get("name") or obj["name"].strip() == "":
             obj["name"] = "(label unknown)"
@@ -76,9 +71,29 @@ def assign_event_gateway_names_from_ocr(json_data: dict, ocr_text: str):
 
     return json_data
 
-
+
 def run_model(image: Image.Image):
+    global qwen_model, qwen_processor
+
+    if qwen_model is None or qwen_processor is None:
+        qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+            # You can enable flash attention here if needed:
+            # attn_implementation="flash_attention_2"
+        )
+
+        min_pixels = 256 * 28 * 28
+        max_pixels = 1080 * 28 * 28
+        qwen_processor = AutoProcessor.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            min_pixels=min_pixels,
+            max_pixels=max_pixels
+        )
+
     prompt = load_prompt()
+
     messages = [
         {
             "role": "user",
@@ -89,21 +104,21 @@ def run_model(image: Image.Image):
         }
     ]
 
-    text = …
+    text = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
 
-    inputs = …
+    inputs = qwen_processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
         padding=True,
         return_tensors="pt"
-    ).to(…
+    ).to(qwen_model.device)
 
-    generated_ids = …
+    generated_ids = qwen_model.generate(**inputs, max_new_tokens=5000)
     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
 
-    output_text = …
+    output_text = qwen_processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
         clean_up_tokenization_spaces=False
@@ -111,11 +126,11 @@ def run_model(image: Image.Image):
 
     parsed_json = try_extract_json(output_text)
 
-    # …
+    # OCR post-processing
     ocr_text = extract_all_text_pix2struct(image)
     parsed_json = assign_event_gateway_names_from_ocr(parsed_json, ocr_text)
 
     return {
         "json": parsed_json,
         "raw": output_text
-    }
+    }
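
For quick verification, a minimal usage sketch of the updated entry point follows. The import path, the example image file, and the print calls are illustrative assumptions, not part of this commit; only run_model and its "json"/"raw" return shape come from the diff above.

# Minimal usage sketch (assumed caller; not part of this commit).
# After this change, importing models.qwen no longer loads any weights:
# run_model() loads Qwen2.5-VL and the Pix2Struct OCR model lazily on first call.
from PIL import Image

from models.qwen import run_model  # assumes the repository root is on sys.path

# load_prompt() reads "prompts/prompt.txt" relative to the working directory,
# so run this from the repository root.
image = Image.open("examples/diagram.png").convert("RGB")  # hypothetical test image
result = run_model(image)

print(result["raw"])   # raw decoded model output
print(result["json"])  # parsed JSON, or the raw text if JSON parsing failed

The first call is slow because both model loads now happen inside run_model and extract_all_text_pix2struct rather than at import time; later calls reuse the module-level globals.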