# app.py
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from typing import Optional
import uvicorn
import huggingface_hub
import os
from PIL import Image
import io
import base64
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="OmniVLM API",
    description="API for text and image processing using OmniVLM model",
    version="1.0.0"
)

# Add CORS middleware.
# NOTE: allow_origins=["*"] leaves the API wide open; restrict this to known
# origins in production, especially when allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Download the model from Hugging Face Hub
try:
    model_path = huggingface_hub.hf_hub_download(
        repo_id="NexaAIDev/OmniVLM-968M",
        filename="omnivision-text-optimized-llm-Q8_0.gguf"
    )
    logger.info(f"Model downloaded successfully to {model_path}")
except Exception as e:
    logger.error(f"Error downloading model: {e}")
    raise
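
# hf_hub_download caches the file locally (by default under ~/.cache/huggingface)
# and returns the cached path, so restarts skip the download.
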
# Initialize the model with the downloaded file
try:
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,    # context window size in tokens
        n_threads=4,   # CPU threads used for inference
        n_batch=512,   # tokens processed per batch during prompt evaluation
        verbose=True
    )
    logger.info("Model initialized successfully")
except Exception as e:
    logger.error(f"Error initializing model: {e}")
    raise

class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: Optional[int] = 100
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9

class GenerationResponse(BaseModel):
    generated_text: str
    error: Optional[str] = None

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_IMAGE_SIZE = 10 * 1024 * 1024 # 10MB
def allowed_file(filename):
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    try:
        output = llm(
            request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p
        )
        return GenerationResponse(generated_text=output["choices"][0]["text"])
    except Exception as e:
        logger.error(f"Error in text generation: {e}")
        return GenerationResponse(generated_text="", error=str(e))
@app.post("/process-image", response_model=GenerationResponse)
async def process_image(
    file: UploadFile = File(...),
    prompt: str = Form("Describe this image in detail"),
    max_tokens: int = Form(200),
    temperature: float = Form(0.7)
):
    try:
        # Validate file size
        file_content = await file.read()
        if len(file_content) > MAX_IMAGE_SIZE:
            raise HTTPException(status_code=400, detail="File too large")

        # Validate file type
        if not allowed_file(file.filename):
            raise HTTPException(status_code=400, detail="File type not allowed")

        # Process image
        try:
            image = Image.open(io.BytesIO(file_content))

            # Convert image to RGB if necessary
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Resize image if too large
            max_size = (1024, 1024)
            if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
                image.thumbnail(max_size, Image.Resampling.LANCZOS)

            # Re-encode as JPEG and convert to base64
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG", quality=85)
            img_str = base64.b64encode(buffered.getvalue()).decode()

            # Create prompt with the image inlined. NOTE: this <image> tag
            # convention assumes the model parses base64 image data out of the
            # prompt; a plain llama.cpp text completion would treat it as
            # ordinary text.
            full_prompt = f"""
<image>data:image/jpeg;base64,{img_str}</image>
{prompt}
"""
            logger.info("Processing image with prompt")

            # Generate description
            output = llm(
                full_prompt,
                max_tokens=max_tokens,
                temperature=temperature
            )
            return GenerationResponse(generated_text=output["choices"][0]["text"])
        except Exception as e:
            logger.error(f"Error processing image: {e}")
            raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
    except HTTPException as he:
        raise he
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return GenerationResponse(generated_text="", error=str(e))
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": llm is not None
    }
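
# Example: curl http://localhost:7860/health
# -> {"status": "healthy", "model_loaded": true}
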
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")