# app.py
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama
from typing import Optional
import uvicorn
import huggingface_hub
import os
from PIL import Image
import io
import base64
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="OmniVLM API",
    description="API for text and image processing using OmniVLM model",
    version="1.0.0"
)

# Add CORS middleware
# Note: browsers refuse credentialed requests to a wildcard origin, so
# allow_origins should list explicit origins if credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Download the model from Hugging Face Hub
try:
    model_path = huggingface_hub.hf_hub_download(
        repo_id="NexaAIDev/OmniVLM-968M",
        filename="omnivision-text-optimized-llm-Q8_0.gguf"
    )
    logger.info(f"Model downloaded successfully to {model_path}")
except Exception as e:
    logger.error(f"Error downloading model: {e}")
    raise

# Initialize the model with the downloaded file
try:
    llm = Llama(
        model_path=model_path,
        n_ctx=2048,
        n_threads=4,
        n_batch=512,
        verbose=True
    )
    logger.info("Model initialized successfully")
except Exception as e:
    logger.error(f"Error initializing model: {e}")
    raise

class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: Optional[int] = 100
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9

class GenerationResponse(BaseModel):
    generated_text: str
    error: Optional[str] = None

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_IMAGE_SIZE = 10 * 1024 * 1024  # 10MB

def allowed_file(filename: str) -> bool:
    """Return True if the filename has an extension on the allow-list."""
    return '.' in filename and \
           filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
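# Quick illustration of the helper above:
#   allowed_file("photo.jpg")  -> True
#   allowed_file("notes.txt")  -> False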

@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    try:
        output = llm(
            request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p
        )
        
        return GenerationResponse(generated_text=output["choices"][0]["text"])
    except Exception as e:
        logger.error(f"Error in text generation: {e}")
        return GenerationResponse(generated_text="", error=str(e))
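# Example request against /generate (assumes the server is running locally on
# the default port 7860 set below):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, world", "max_tokens": 50}'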

@app.post("/process-image", response_model=GenerationResponse)
async def process_image(
    file: UploadFile = File(...),
    prompt: str = Form("Describe this image in detail"),
    max_tokens: int = Form(200),
    temperature: float = Form(0.7)
):
    try:
        # Validate file size
        file_content = await file.read()
        file_size = len(file_content)
        
        if file_size > MAX_IMAGE_SIZE:
            raise HTTPException(status_code=400, detail="File too large")
        
        # Validate file type
        if not allowed_file(file.filename):
            raise HTTPException(status_code=400, detail="File type not allowed")
        
        # Process image
        try:
            image = Image.open(io.BytesIO(file_content))
            
            # Convert image to RGB if necessary
            if image.mode != 'RGB':
                image = image.convert('RGB')
            
            # Resize image if too large
            max_size = (1024, 1024)
            if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
                image.thumbnail(max_size, Image.Resampling.LANCZOS)
            
            # Convert to base64
            buffered = io.BytesIO()
            image.save(buffered, format="JPEG", quality=85)
            img_str = base64.b64encode(buffered.getvalue()).decode()
            
            # Build the prompt with the image inlined as a base64 data URI.
            # Note: this assumes the model build interprets inline <image>
            # tags; llama-cpp-python's plain completion API treats the prompt
            # as text and does not decode images itself.
            full_prompt = (
                f"<image>data:image/jpeg;base64,{img_str}</image>\n"
                f"{prompt}"
            )
            
            logger.info("Processing image with prompt")
            # Generate description
            output = llm(
                full_prompt,
                max_tokens=max_tokens,
                temperature=temperature
            )
            
            return GenerationResponse(generated_text=output["choices"][0]["text"])
            
        except Exception as e:
            logger.error(f"Error processing image: {e}")
            raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
            
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return GenerationResponse(generated_text="", error=str(e))
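# Example multipart request against /process-image ("test.jpg" is a
# placeholder; field names match the Form parameters above):
#   curl -X POST http://localhost:7860/process-image \
#        -F "file=@test.jpg" \
#        -F "prompt=Describe this image in detail" \
#        -F "max_tokens=200"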

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": llm is not None
    }
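# Example health probe:
#   curl http://localhost:7860/health
#   -> {"status": "healthy", "model_loaded": true}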

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
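# Alternatively, the standard uvicorn CLI can serve the app directly (the
# module name "app" assumes this file is saved as app.py, per the header):
#   uvicorn app:app --host 0.0.0.0 --port 7860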