# Molmo-4bit / app.py
# Page-header residue captured with the file (not part of the program):
#   zamal's picture | Update app.py | f82d3e1 verified | raw | history blame | 2.55 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import torch
import os
import subprocess
# Install GPU-related dependencies when the app starts. Both installs are
# best-effort: a non-zero exit status is ignored and start-up continues.
#
# FLASH_ATTENTION_SKIP_CUDA_BUILD is layered on top of the inherited
# environment. Passing a bare dict as `env` would replace the child's whole
# environment (PATH, HOME, virtualenv/proxy settings), not just add a variable.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
# NOTE(review): `bitsandbytes-cuda117` looks like a legacy CUDA-specific
# distribution -- confirm it ships the 4-bit kernels the model load relies on,
# or switch to the plain `bitsandbytes` package.
os.system('pip install -U bitsandbytes-cuda117')
# Repository of the pre-quantized (bitsandbytes 4-bit) Molmo checkpoint.
repo_name = "cyan2k/molmo-7B-D-bnb-4bit"

# Preferred device for tensors created by this app: CUDA when available,
# otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# trust_remote_code=True lets the repository supply its own processor/model
# classes; only use it with repositories you trust.
processor = AutoProcessor.from_pretrained(repo_name, trust_remote_code=True)

# Load the model in 4-bit; device_map="auto" already places the weights on
# the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    repo_name,
    device_map="auto",
    torch_dtype=torch.float16,
    load_in_4bit=True,
    trust_remote_code=True,
)
# Deliberately no `model.to(device)`: placement is handled by device_map, and
# transformers can refuse `.to()` on bitsandbytes-quantized models (depending
# on the installed versions), which would abort start-up.
def process_image_and_text(image, text):
    """Generate the model's answer to ``text`` about ``image``.

    Follows the inference flow documented on the Molmo model card:
    ``processor.process`` -> add a batch dimension on the model's device ->
    ``model.generate_from_batch`` -> decode only the newly generated tokens.

    Args:
        image: Image as a numpy array (the UI's image component is configured
            with ``type="numpy"``).
        text: The user's question / prompt.

    Returns:
        The generated answer as a string, without the echoed prompt.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    from transformers import GenerationConfig

    # Convert the numpy image to PIL; force RGB so RGBA or grayscale uploads
    # reach the processor in the colour mode it expects.
    pil_image = Image.fromarray(image).convert("RGB")

    # The processor returns unbatched tensors; move each one to the device the
    # model actually lives on and add a batch dimension of size 1.
    inputs = processor.process(images=[pil_image], text=text)
    inputs = {
        name: tensor.to(model.device).unsqueeze(0)
        for name, tensor in inputs.items()
    }
    # NOTE(review): the model is loaded in float16 -- confirm the float image
    # tensor needs no explicit cast (or autocast) for this checkpoint.

    output = model.generate_from_batch(
        inputs,
        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )

    # The output sequence starts with the prompt tokens; keep only what the
    # model generated after them so the reply does not echo the question.
    prompt_length = inputs["input_ids"].size(1)
    generated_tokens = output[0, prompt_length:]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
def chatbot(image, text, history):
    """Handle one chat turn and return the history to display.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        text: The user's question.
        history: List of ``(user_message, bot_message)`` pairs from earlier
            turns; ``None`` is treated as an empty history.

    Returns:
        The list of ``(user_message, bot_message)`` pairs to show in the chat.
    """
    if history is None:
        history = []

    if image is None:
        # The notice comes from the app, so it belongs in the bot slot of the
        # pair (a None user slot shows no user bubble). A new list is returned
        # so the transient notice is not written into the stored history.
        return history + [(None, "Please upload an image first.")]

    # Get response by processing the image and text
    response = process_image_and_text(image, text)

    # Append in place: the caller's history list is the conversation store,
    # so mutating it is what carries the turn over to the next call.
    history.append((text, response))
    return history
# Build the UI: a title, an image upload next to the chat window, then the
# question box, the submit button and the per-session history state.
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with Molmo-7B-4 Bit Quantized")

    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()

    text_input = gr.Textbox(placeholder="Ask a question about the image...")
    submit_button = gr.Button("Submit")
    state = gr.State([])

    # Clicking the button and pressing Enter in the textbox trigger the same
    # handler with identical wiring.
    for register_event in (submit_button.click, text_input.submit):
        register_event(
            fn=chatbot,
            inputs=[image_input, text_input, state],
            outputs=chatbot_output,
        )

# Start the app; share=True requests a public share link.
demo.launch(share=True)