Spaces:
Runtime error
Runtime error
File size: 3,597 Bytes
077ae35 33101ab 077ae35 d064df9 077ae35 6051509 495550d 6051509 495550d 6051509 33101ab 495550d 33101ab 6051509 33101ab 6051509 495550d 6051509 33101ab 6051509 495550d 6051509 495550d 6051509 495550d 33101ab 667a64e 33101ab 667a64e 33101ab 495550d 33101ab 495550d 667a64e 6b8495c 667a64e 33101ab 667a64e 33101ab 667a64e 33101ab 667a64e 33101ab 495550d 667a64e 33101ab 667a64e 495550d 667a64e 33101ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import logging

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Hugging Face Hub id of the CogVLM2-Video (Llama-3 based) chat checkpoint.
MODEL_PATH = "THUDM/cogvlm2-video-llama3-chat"

# Choose the execution device and the floating-point dtype the rest of the
# file uses when loading the model and casting its inputs.
if torch.cuda.is_available():
    DEVICE = "cuda"
    # bfloat16 only when the GPU's major compute capability is >= 8
    # (Ampere or newer); older GPUs fall back to float16.
    TORCH_TYPE = (
        torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    )
else:
    DEVICE = "cpu"
    TORCH_TYPE = torch.float16
def load_model():
    """Load the checkpoint at MODEL_PATH, 4-bit quantized, together with its tokenizer.

    Quantization settings: NF4 4-bit weights, double quantization enabled,
    compute dtype TORCH_TYPE. Weights are placed with ``device_map="auto"``.
    NOTE(review): bitsandbytes 4-bit loading generally needs a CUDA GPU;
    CPU-only hosts are expected to fail here.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has been switched to eval mode.
    """
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=TORCH_TYPE,
    )
    tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        device_map="auto",
        quantization_config=nf4_config,
        trust_remote_code=True,  # the checkpoint ships custom modeling code
        torch_dtype=TORCH_TYPE,
    )
    loaded.eval()
    return loaded, tok
def _image_to_video_tensor(image):
    """Wrap a still image as a one-frame uint8 video tensor shaped (C, T=1, H, W).

    NOTE(review): assumes the CogVLM2-Video remote code expects the same
    (C, T, H, W) uint8 layout as the upstream demo's ``load_video`` output.
    Confirm against the model's ``build_conversation_input_ids``.
    """
    frame = torch.from_numpy(np.array(image.convert("RGB")))  # (H, W, 3) uint8
    return frame.permute(2, 0, 1).unsqueeze(1)  # -> (3, 1, H, W)


def _generation_kwargs(temperature, do_sample, max_new_tokens):
    """Build ``model.generate`` kwargs, sending sampling knobs only when sampling."""
    # 128002 is the hard-coded pad token id kept from the original code;
    # presumably a Llama-3 reserved special token. TODO confirm against the tokenizer.
    kwargs = {"max_new_tokens": max_new_tokens, "pad_token_id": 128002}
    if do_sample and temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=0.1)
    else:
        # Greedy decoding: temperature/top_p/top_k would be ignored (and warned about).
        kwargs["do_sample"] = False
    return kwargs


def predict_image(prompt, image, temperature, model, tokenizer, *, do_sample=False, max_new_tokens=512):
    """Generate a text response for *image* conditioned on *prompt*.

    The checkpoint is a video model, so the still image is wrapped as a
    one-frame clip before it is passed to ``model.build_conversation_input_ids``.

    Args:
        prompt: User query sent to the model.
        image: PIL image in any mode; it is converted to RGB here.
        temperature: Sampling temperature. It is used only when ``do_sample`` is True
            and ``temperature > 0``. Under the default greedy decoding it has no effect.
        model: Model object as returned by ``load_model``.
        tokenizer: Tokenizer as returned by ``load_model``.
        do_sample: Sample instead of decoding greedily. Defaults to False (greedy).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded response with special tokens stripped.

    Raises:
        ValueError: If *image* is None.
    """
    if image is None:
        raise ValueError("predict_image() requires an image, got None")

    video = _image_to_video_tensor(image)
    conv = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[video],
        history=[],
        template_version="chat",
    )
    # Add a leading batch dimension and move everything to the target device.
    inputs = {
        "input_ids": conv["input_ids"].unsqueeze(0).to(DEVICE),
        "token_type_ids": conv["token_type_ids"].unsqueeze(0).to(DEVICE),
        "attention_mask": conv["attention_mask"].unsqueeze(0).to(DEVICE),
        "images": [[conv["images"][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = _generation_kwargs(temperature, do_sample, max_new_tokens)

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # generate() returns prompt + continuation; keep only the new tokens.
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens[0], skip_special_tokens=True)
# Load the weights once at import time; inference() reads these module-level
# globals on every request.
(model, tokenizer) = load_model()
def inference(image):
    """Gradio click handler: describe the uploaded image.

    Args:
        image: PIL image from the upload widget, or None when nothing was uploaded.

    Returns:
        The model's description. If no image was supplied or generation
        raised, a user-facing message is returned instead.
    """
    if image is None:
        return "Please upload an image first."
    prompt = "Describe the image and the components observed in the given input image."
    temperature = 0.3
    try:
        return predict_image(prompt, image, temperature, model, tokenizer)
    except Exception as e:
        # UI boundary: keep the request alive, but record the full traceback
        # instead of silently discarding it.
        logging.getLogger(__name__).exception("Image description failed")
        return f"An error occurred during analysis: {e}"
def create_interface():
    """Assemble the Gradio Blocks UI for the image-description app.

    Layout: a header, then one row with two columns. The left column holds the
    image upload and the "Describe Image" button. The right column holds the
    description textbox. Clicking the button runs ``inference`` on the uploaded image.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    header_md = (
        "\n"
        "# Image Description System\n"
        "Upload an image, and the system will describe the image and its components.\n"
    )
    with gr.Blocks() as demo:
        gr.Markdown(header_md)
        with gr.Row():
            with gr.Column():
                uploaded = gr.Image(label="Upload Image", type="pil")
                describe_button = gr.Button("Describe Image", variant="primary")
            with gr.Column():
                description_box = gr.Textbox(label="Image Description", lines=10)
        describe_button.click(fn=inference, inputs=[uploaded], outputs=[description_box])
    return demo
if __name__ == "__main__":
    # Build the UI, route its events through Gradio's queue, and serve it
    # with share=True (a public share link is requested).
    demo = create_interface()
    app = demo.queue()
    app.launch(share=True)
|