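# Residue from the Hugging Face Space file page this file was copied from
# (author: zb9; latest commit 7276a9d "Update app.py", verified; size 1.63 kB).
# The page controls "raw", "history" and "blame" are not part of the program.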
# app.py in Hugging Face Space
import gradio as gr
from colpali_engine.models import ColQwen2, ColQwen2Processor
import torch
from PIL import Image
import logging
import os
# Logging: send INFO and above from every logger through the root handler so
# model-loading and per-request progress messages are visible, and create the
# named logger this app writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name="colqwen-api")
# Initialize model
logger.info("Loading ColQwen2 model...")
model = ColQwen2.from_pretrained(
"vidore/colqwen2-v1.0",
torch_dtype=torch.bfloat16,
device_map="auto",
)
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
model = model.eval()
logger.info("Model loaded successfully")
def process_image(image):
    """Compute one mean-pooled ColQwen2 embedding for a single image.

    Args:
        image: A ``PIL.Image.Image``, or an array accepted by
            ``PIL.Image.fromarray`` (what ``gr.Image()`` delivers by default).

    Returns:
        A JSON-serialisable dict with ``"embeddings"`` (a nested list of
        floats, one row per image) and ``"shape"`` (the array shape as a list).

    Raises:
        ValueError: If no image was supplied.
        Exception: Any error from preprocessing or the model is logged with
            its traceback and re-raised.
    """
    # Gradio passes None when the form is submitted without an upload.
    if image is None:
        raise ValueError("No image provided")
    try:
        logger.info("Processing image")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # ColQwen2Processor's image entry point is process_images(); calling
        # the processor directly with only `images=` fails because the
        # underlying Qwen2-VL processor expects text alongside images.
        batch = processor.process_images([image]).to(model.device)
        logger.info("Generating embeddings")
        with torch.no_grad():
            # ColQwen2's forward returns the multi-vector embedding tensor
            # directly (there is no `.last_hidden_state`). Upcast to float32:
            # it avoids summing many tokens in bfloat16, and NumPy has no
            # bfloat16 dtype for the conversion below.
            token_embeddings = model(**batch).float()
        # Mean-pool over the token axis, counting only non-padded positions.
        mask = batch["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
        pooled = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        embeddings = pooled.cpu().numpy()
        logger.info("Embeddings shape: %s", embeddings.shape)
        return {
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }
    except Exception as exc:
        logger.exception("Error: %s", exc)
        raise
# Gradio UI and HTTP endpoint: one image in, the dict built by process_image
# rendered as JSON out.
interface = gr.Interface(
    process_image,
    gr.Image(),
    gr.JSON(),
    title="ColQwen2 Embedding API",
    description="Generate embeddings from images using ColQwen2",
)

# Listen on all network interfaces so the server is reachable from outside the
# container. No server_port is passed, so the port is left to Gradio's own
# selection.
interface.launch(server_name="0.0.0.0")