# app.py
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from typing import Optional
import uvicorn
import huggingface_hub
import os
from PIL import Image
import io
import base64
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="OmniVLM API",
    description="API for text and image processing using the OmniVLM model",
    version="1.0.0",
)
# Add CORS middleware (allows all origins; restrict in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Download the model weights from the Hugging Face Hub
try:
    model_path = huggingface_hub.hf_hub_download(
        repo_id="NexaAIDev/OmniVLM-968M",
        filename="omnivision-text-optimized-llm-Q8_0.gguf",
    )
    logger.info(f"Model downloaded successfully to {model_path}")
except Exception as e:
    logger.error(f"Error downloading model: {e}")
    raise
# Initialize the model with the downloaded file
try:
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,    # context window size in tokens
        n_threads=4,   # CPU threads used for inference
        n_batch=512,   # prompt tokens processed per batch
        verbose=True,
    )
    logger.info("Model initialized successfully")
except Exception as e:
    logger.error(f"Error initializing model: {e}")
    raise
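# NOTE: a plain `Llama` instance consumes text tokens only. For genuine image
# understanding with llama-cpp-python, the model is usually paired with a
# projector GGUF via a multimodal chat handler. A minimal sketch, assuming the
# repo ships a projector file (the filename below is hypothetical):
#
#   from llama_cpp.llama_chat_format import Llava15ChatHandler
#
#   proj_path = huggingface_hub.hf_hub_download(
#       repo_id="NexaAIDev/OmniVLM-968M",
#       filename="projector.gguf",  # hypothetical filename
#   )
#   llm = Llama(
#       model_path=model_path,
#       chat_handler=Llava15ChatHandler(clip_model_path=proj_path),
#       n_ctx=2048,
#   )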
class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: Optional[int] = 100
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9

class GenerationResponse(BaseModel):
    generated_text: str
    error: Optional[str] = None
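# Example /generate request body (illustrative values):
#   {"prompt": "Write a haiku about the sea.", "max_tokens": 64,
#    "temperature": 0.7, "top_p": 0.9}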
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
def allowed_file(filename: str) -> bool:
    """Return True if the filename has an allowed image extension."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
try:
output = llm(
request.prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p
)
return GenerationResponse(generated_text=output["choices"][0]["text"])
except Exception as e:
logger.error(f"Error in text generation: {e}")
return GenerationResponse(generated_text="", error=str(e))
@app.post("/process-image", response_model=GenerationResponse)
async def process_image(
file: UploadFile = File(...),
prompt: str = Form("Describe this image in detail"),
max_tokens: int = Form(200),
temperature: float = Form(0.7)
):
try:
# Validate file size
file_size = 0
file_content = await file.read()
file_size = len(file_content)
if file_size > MAX_IMAGE_SIZE:
raise HTTPException(status_code=400, detail="File too large")
# Validate file type
if not allowed_file(file.filename):
raise HTTPException(status_code=400, detail="File type not allowed")
# Process image
try:
image = Image.open(io.BytesIO(file_content))
# Convert image to RGB if necessary
if image.mode != 'RGB':
image = image.convert('RGB')
# Resize image if too large
max_size = (1024, 1024)
if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
image.thumbnail(max_size, Image.Resampling.LANCZOS)
# Convert to base64
buffered = io.BytesIO()
image.save(buffered, format="JPEG", quality=85)
img_str = base64.b64encode(buffered.getvalue()).decode()
# Create prompt with image
full_prompt = f"""
<image>data:image/jpeg;base64,{img_str}</image>
{prompt}
"""
logger.info("Processing image with prompt")
# Generate description
output = llm(
full_prompt,
max_tokens=max_tokens,
temperature=temperature
)
return GenerationResponse(generated_text=output["choices"][0]["text"])
except Exception as e:
logger.error(f"Error processing image: {e}")
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
except HTTPException as he:
raise he
except Exception as e:
logger.error(f"Unexpected error: {e}")
return GenerationResponse(generated_text="", error=str(e))
@app.get("/health")
async def health_check():
return {
"status": "healthy",
"model_loaded": llm is not None
}
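# Example: curl http://localhost:7860/health
#   -> {"status": "healthy", "model_loaded": true}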
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
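# Run locally with:  python app.py
# or equivalently:   uvicorn app:app --host 0.0.0.0 --port 7860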