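"""Gradio demo: describe an uploaded image with a 4-bit-quantized CogVLM2 model."""
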
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_PATH = "THUDM/cogvlm2-video-llama3-chat"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
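# bfloat16 requires compute capability >= 8 (Ampere or newer); otherwise fall back to float16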
TORCH_TYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

def load_model():
    """Loads the pre-trained model and tokenizer with quantization configurations."""
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=TORCH_TYPE,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto"
    ).eval()
    
    return model, tokenizer

def predict_image(prompt, image, temperature, model, tokenizer):
    """Generates predictions based on the image and textual prompt."""
    image = image.convert("RGB")  # Normalize to RGB (handles RGBA and grayscale uploads)
    
    # Build the multimodal conversation input (tokenized prompt plus preprocessed image)
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[image],
        history=[],
        template_version='chat'
    )
    
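    # Add a batch dimension and move tensors to the target device and dtype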
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    
    gen_kwargs = {
        "max_new_tokens": 512,
        "pad_token_id": 128002,  # reserved Llama-3 token used as pad in the CogVLM2 demos
        # Sampling must be enabled for temperature/top_p to take effect;
        # with do_sample=False decoding is greedy and both are silently ignored.
        "do_sample": temperature > 0,
        "top_p": 0.1,
        "temperature": temperature,
    }
    
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
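        # Drop the prompt tokens so only the newly generated text is decoded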
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return response

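# Load once at startup; Gradio callbacks reuse the same model across requests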
model, tokenizer = load_model()

def inference(image):
    """Generates a description of the input image."""
    try:
        if image is None:
            return "Please upload an image first."
        
        prompt = "Describe the image and the components observed in it."
        temperature = 0.3
        response = predict_image(prompt, image, temperature, model, tokenizer)
        
        return response
    except Exception as e:
        return f"An error occurred during analysis: {str(e)}"

def create_interface():
    """Creates the Gradio interface for Image Description System."""
    with gr.Blocks() as demo:
        gr.Markdown("""
        # Image Description System
        Upload an image, and the system will describe the image and its components.
        """)
        
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Image", type="pil")
                analyze_btn = gr.Button("Describe Image", variant="primary")
            
            with gr.Column():
                output = gr.Textbox(label="Image Description", lines=10)
        
        analyze_btn.click(
            fn=inference,
            inputs=[image_input],
            outputs=[output]
        )
    
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.queue().launch(share=True)