Apollo_GenAI / app.py
VishalD1234's picture
Update app.py
d064df9 verified
raw
history blame
3.6 kB
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Checkpoint identifier handed to the `from_pretrained` loaders.
MODEL_PATH = "THUDM/cogvlm2-video-llama3-chat"

# Device string and tensor dtype are both derived from CUDA availability:
#   - no CUDA            -> 'cpu'  / float16
#   - CUDA, major < 8    -> 'cuda' / float16
#   - CUDA, major >= 8   -> 'cuda' / bfloat16
if torch.cuda.is_available():
    DEVICE = 'cuda'
    if torch.cuda.get_device_capability()[0] >= 8:
        TORCH_TYPE = torch.bfloat16
    else:
        TORCH_TYPE = torch.float16
else:
    DEVICE = 'cpu'
    TORCH_TYPE = torch.float16
def load_model():
    """Build the 4-bit quantized model and its tokenizer.

    Both are loaded from ``MODEL_PATH`` with ``trust_remote_code=True``. The
    model is configured for 4-bit NF4 quantization with double quantization,
    uses ``TORCH_TYPE`` as both its torch dtype and its 4-bit compute dtype,
    is placed with ``device_map="auto"`` and is put into eval mode.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    bnb_settings = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": TORCH_TYPE,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
    }
    quant_cfg = BitsAndBytesConfig(**bnb_settings)

    text_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

    vlm = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=quant_cfg,
        device_map="auto",
    )
    # Inference only: switch to eval mode before handing the model out.
    vlm = vlm.eval()
    return vlm, text_tokenizer
def predict_image(prompt, image, temperature, model, tokenizer):
    """Generate a text response for one image and one textual prompt.

    Decoding is greedy (``do_sample=False``), so sampling-only options are
    not passed to ``generate``.

    Args:
        prompt: Query text handed to the model's conversation builder.
        image: Image object exposing ``convert("RGB")`` (a PIL image in this
            app).
        temperature: Accepted for interface compatibility only. It is not
            forwarded to ``generate``: with sampling disabled it would have
            no effect on the output.
        model: Model exposing ``build_conversation_input_ids`` and
            ``generate``.
        tokenizer: Tokenizer used to build the inputs and decode the output.

    Returns:
        The decoded continuation, with the prompt tokens sliced off and
        special tokens skipped.
    """
    rgb_image = image.convert("RGB")  # Ensure image is in RGB format

    # Let the model build its own conversation-formatted inputs.
    features = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[rgb_image],
        history=[],
        template_version='chat',
    )

    # unsqueeze(0) adds a leading batch dimension to each tensor.
    batch = {
        'input_ids': features['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': features['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': features['attention_mask'].unsqueeze(0).to(DEVICE),
        # NOTE(review): nested-list layout is presumably [batch][image] as
        # expected by the model's remote code - confirm.
        'images': [[features['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }

    # Greedy decoding only. The previous top_k / top_p / temperature entries
    # are sampling-only settings and were ignored because do_sample is False;
    # leaving them out keeps the output identical and the configuration honest.
    gen_kwargs = {
        "max_new_tokens": 512,
        # NOTE(review): hard-coded pad token id - presumably specific to this
        # checkpoint's tokenizer; confirm.
        "pad_token_id": 128002,
        "do_sample": False,
    }

    with torch.no_grad():
        generated = model.generate(**batch, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = batch['input_ids'].shape[1]
    continuation = generated[:, prompt_length:]
    return tokenizer.decode(continuation[0], skip_special_tokens=True)
# Load the model and tokenizer once, at module import, so that `inference`
# reuses the same objects on every call instead of reloading them.
model, tokenizer = load_model()
def inference(image):
    """Describe the uploaded image, returning text for the output box.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        The generated description; a request to upload an image when
        ``image`` is ``None``; or an error message if generation raises.
    """
    # Explicit None check: "nothing uploaded" is the only case to reject here,
    # and it avoids depending on the truth value of an image object.
    if image is None:
        return "Please upload an image first."

    prompt = "Describe the image and the components observed in the given input image."
    temperature = 0.3
    try:
        return predict_image(prompt, image, temperature, model, tokenizer)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the
        # handler, but keep the full traceback in the server log.
        import traceback

        traceback.print_exc()
        return f"An error occurred during analysis: {e}"
def create_interface():
    """Assemble the Gradio Blocks UI for the Image Description System.

    The page is a Markdown heading followed by a two-column row: the first
    column holds the image upload and the trigger button, the second the text
    output. Clicking the button calls ``inference`` with the uploaded image
    and writes the result into the output box.

    Returns:
        The constructed ``gr.Blocks`` object; launching is left to the caller.
    """
    heading = (
        "\n"
        "# Image Description System\n"
        "Upload an image, and the system will describe the image and its components.\n"
    )

    with gr.Blocks() as page:
        gr.Markdown(heading)

        with gr.Row():
            # Input side: upload widget plus the button that triggers analysis.
            with gr.Column():
                upload = gr.Image(label="Upload Image", type="pil")
                describe_button = gr.Button("Describe Image", variant="primary")

            # Output side: the generated description.
            with gr.Column():
                description_box = gr.Textbox(label="Image Description", lines=10)

        describe_button.click(
            fn=inference,
            inputs=[upload],
            outputs=[description_box],
        )

    return page
if __name__ == "__main__":
    # Script entry point: build the UI, enable queuing, then launch it with
    # share=True.
    demo = create_interface()
    queued_app = demo.queue()
    queued_app.launch(share=True)