|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pydantic import BaseModel |
|
from llama_cpp import Llama |
|
from typing import Optional |
|
import uvicorn |
|
import huggingface_hub |
|
import os |
|
from PIL import Image |
|
import io |
|
import base64 |
|
import logging |
|
|
|
|
|
# Configure root logging once at import time; INFO and above are emitted.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)

# Module-scoped logger used by the start-up code and the request handlers.
logger = logging.getLogger(__name__)
|
|
|
# Application metadata handed to the FastAPI constructor.
_API_METADATA = {
    "title": "OmniVLM API",
    "description": "API for text and image processing using OmniVLM model",
    "version": "1.0.0",
}

# The ASGI application object; routes and middleware are attached to it below.
app = FastAPI(**_API_METADATA)
|
|
|
|
|
# Allow browser clients served from any origin to call this API.
#
# Credentials (cookies / Authorization headers) are intentionally NOT allowed:
# the CORS rules forbid pairing credentialed requests with wildcard origins,
# methods or headers, and FastAPI's documentation states that none of those
# may be ["*"] when allow_credentials is True. No handler in this module reads
# cookies or auth headers, so disabling credentials keeps the open-to-everyone
# policy while making the configuration valid.
# NOTE(review): if a front end really needs credentialed requests, replace the
# wildcards with an explicit list of origins/methods/headers and re-enable
# allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Location of the model weights on the Hugging Face Hub.
_MODEL_REPO_ID = "NexaAIDev/OmniVLM-968M"
_MODEL_FILENAME = "omnivision-text-optimized-llm-Q8_0.gguf"

# Resolve a local path for the weights at import time. Any failure is logged
# and re-raised, so the module cannot be imported without the model file.
try:
    model_path = huggingface_hub.hf_hub_download(
        repo_id=_MODEL_REPO_ID,
        filename=_MODEL_FILENAME,
    )
    logger.info(f"Model downloaded successfully to {model_path}")
except Exception as download_error:
    logger.error(f"Error downloading model: {download_error}")
    raise
|
|
|
|
|
# Construction settings for the llama.cpp model wrapper, kept together so the
# tuning knobs are visible in one place.
_LLAMA_SETTINGS = {
    "n_ctx": 2048,
    "n_threads": 4,
    "n_batch": 512,
    "verbose": True,
}

# Load the model once at import time; the resulting ``llm`` object is shared by
# every request handler. A load failure is logged and re-raised.
try:
    llm = Llama(model_path=model_path, **_LLAMA_SETTINGS)
    logger.info("Model initialized successfully")
except Exception as init_error:
    logger.error(f"Error initializing model: {init_error}")
    raise
|
|
|
class GenerationRequest(BaseModel):
    """JSON body accepted by the ``/generate`` endpoint."""

    # Text handed to the model as the prompt (required).
    prompt: str

    # Sampling controls forwarded unchanged to the model call. Each may be
    # omitted by the client, in which case the default below applies.
    max_tokens: Optional[int] = 100
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
|
|
|
class GenerationResponse(BaseModel):
    """Response body returned by the text and image generation endpoints."""

    # Model output; the handlers set this to "" when they report a failure.
    generated_text: str

    # Exception message when a handler reports a failure in-band, else None.
    error: Optional[str] = None
|
|
|
# Upload constraints enforced by the image endpoint.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}  # lower-case, no leading dot
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10 MiB, in bytes


def allowed_file(filename: Optional[str]) -> bool:
    """Return True when *filename* ends in one of ``ALLOWED_EXTENSIONS``.

    The extension is whatever follows the last ``.`` and is compared
    case-insensitively. A missing (``None``) or empty name, or a name without
    any dot, is rejected instead of raising, because an upload is not
    guaranteed to carry a filename.
    """
    if not filename:
        return False
    # rpartition yields an empty separator when the name contains no dot.
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
|
|
|
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    """Run a text completion for ``request.prompt`` and return the first choice.

    Args:
        request: Prompt plus sampling settings (``max_tokens``,
            ``temperature``, ``top_p``) forwarded to the model call.

    Returns:
        A ``GenerationResponse`` holding the generated text. Failures are
        reported in-band: ``generated_text`` is empty and ``error`` carries
        the exception message (the HTTP status is not changed).

    NOTE(review): ``llm(...)`` is a synchronous call inside an ``async``
    handler, so the event loop is blocked for the duration of generation.
    """
    try:
        output = llm(
            request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )
        return GenerationResponse(generated_text=output["choices"][0]["text"])
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Error in text generation: %s", e)
        return GenerationResponse(generated_text="", error=str(e))
|
|
|
def _encode_image_as_jpeg_base64(raw_bytes: bytes, max_size: tuple = (1024, 1024)) -> str:
    """Decode an image, normalise it, and return it as base64-encoded JPEG text.

    The image is converted to RGB when it is in any other mode, shrunk in
    place so that it fits inside *max_size* (``thumbnail`` leaves images that
    already fit untouched), and re-encoded as JPEG at quality 85.
    """
    with Image.open(io.BytesIO(raw_bytes)) as source_image:
        if source_image.mode == 'RGB':
            image = source_image
        else:
            image = source_image.convert('RGB')

        image.thumbnail(max_size, Image.Resampling.LANCZOS)

        buffered = io.BytesIO()
        image.save(buffered, format="JPEG", quality=85)

    return base64.b64encode(buffered.getvalue()).decode()


@app.post("/process-image", response_model=GenerationResponse)
async def process_image(
    file: UploadFile = File(...),
    prompt: str = Form("Describe this image in detail"),
    max_tokens: int = Form(200),
    temperature: float = Form(0.7),
):
    """Validate an uploaded image, inline it into the prompt and run the model.

    Args:
        file: Uploaded image; must be at most ``MAX_IMAGE_SIZE`` bytes and
            carry a filename accepted by ``allowed_file``.
        prompt: Instruction placed after the inlined image.
        max_tokens: Generation limit forwarded to the model call.
        temperature: Sampling temperature forwarded to the model call.

    Returns:
        A ``GenerationResponse`` with the first completion choice. Unexpected
        failures outside image processing are reported in-band through the
        ``error`` field.

    Raises:
        HTTPException: 400 when the file is too large or its type is not
            allowed; 500 when decoding the image or running the model fails.

    NOTE(review): the image is embedded in the text prompt as a base64 data
    URI. That payload can be far larger than the 2048-token context the model
    is created with, and the call blocks the event loop while it runs —
    confirm this matches how the model is meant to receive images.
    """
    try:
        # Read one byte past the limit: enough to detect an oversized upload
        # without pulling an arbitrarily large file into memory.
        file_content = await file.read(MAX_IMAGE_SIZE + 1)
        if len(file_content) > MAX_IMAGE_SIZE:
            raise HTTPException(status_code=400, detail="File too large")

        # An upload may arrive without a filename; treat that as a disallowed
        # type rather than letting allowed_file fail on None.
        if not file.filename or not allowed_file(file.filename):
            raise HTTPException(status_code=400, detail="File type not allowed")

        try:
            img_str = _encode_image_as_jpeg_base64(file_content)

            full_prompt = (
                f"\n<image>data:image/jpeg;base64,{img_str}</image>\n{prompt}\n"
            )

            logger.info("Processing image with prompt")

            output = llm(
                full_prompt,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            return GenerationResponse(generated_text=output["choices"][0]["text"])
        except Exception as e:
            logger.exception("Error processing image: %s", e)
            raise HTTPException(
                status_code=500, detail=f"Error processing image: {str(e)}"
            ) from e

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        return GenerationResponse(generated_text="", error=str(e))
|
|
|
@app.get("/health")
async def health_check():
    """Report service liveness and whether the module-level model object is set."""
    model_ready = llm is not None
    return {"status": "healthy", "model_loaded": model_ready}
|
|
|
if __name__ == "__main__":
    # Listen on the port named by the PORT environment variable, or 7860 when
    # it is unset, and accept connections on all interfaces.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")