import torch
from logging_config import logger
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
from fastapi import HTTPException, UploadFile
from io import BytesIO
class LLMManager:
def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
self.model_name = model_name
self.device = torch.device(device)
        # Use bfloat16 on accelerators to match the dtype the inputs are cast to below; fall back to float32 on CPU
        self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
self.model = None
self.is_loaded = False
self.processor = None
def unload(self):
if self.is_loaded:
            # Drop references so the model and processor can be garbage collected
            self.model = None
            self.processor = None
# If using CUDA, clear the cache to free GPU memory
if self.device.type == "cuda":
torch.cuda.empty_cache()
self.is_loaded = False
logger.info(f"LLM {self.model_name} unloaded from {self.device}")
def load(self):
if not self.is_loaded:
#self.model_name = "google/gemma-3-4b-it"
            self.model = Gemma3ForConditionalGeneration.from_pretrained(
                self.model_name, device_map="auto", torch_dtype=self.torch_dtype
            ).eval()
self.processor = AutoProcessor.from_pretrained(self.model_name)
self.is_loaded = True
logger.info(f"LLM {self.model_name} loaded on {self.device}")
async def generate(self, prompt: str, max_tokens: int = 2048, temperature: float = 0.7) -> str:
if not self.is_loaded:
self.load()
messages_vlm = [
{
"role": "system",
"content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and karnataka as base state, Provide a concise response in one sentence maximum."}]
},
{
"role": "user",
"content": []
}
]
# Add text prompt to user content
messages_vlm[1]["content"].append({"type": "text", "text": prompt})
# Process the chat template with the processor
inputs_vlm = self.processor.apply_chat_template(
messages_vlm,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
        ).to(self.model.device, dtype=self.torch_dtype)
input_len = inputs_vlm["input_ids"].shape[-1]
# Generate response
with torch.inference_mode():
            # Honour the caller-supplied generation settings instead of hard-coded values
            generation = self.model.generate(**inputs_vlm, max_new_tokens=max_tokens, do_sample=temperature > 0, temperature=temperature)
generation = generation[0][input_len:]
# Decode the output
response = self.processor.decode(generation, skip_special_tokens=True)
return response
async def vision_query(self, image: Image.Image, query: str) -> str:
if not self.is_loaded:
self.load()
messages_vlm = [
{
"role": "system",
"content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarise your answer in max 2 lines."}]
},
{
"role": "user",
"content": []
}
]
# Add text prompt to user content
messages_vlm[1]["content"].append({"type": "text", "text": query})
# Handle image if provided and valid
if image and image.size[0] > 0 and image.size[1] > 0: # Check for valid dimensions
# Image is already a PIL Image, no need to read or reopen
messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
logger.info(f"Received valid image for processing")
else:
logger.info("No valid image provided, processing text only")
# Process the chat template with the processor
try:
inputs_vlm = self.processor.apply_chat_template(
messages_vlm,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
            ).to(self.model.device, dtype=self.torch_dtype)
except Exception as e:
logger.error(f"Error in apply_chat_template: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
input_len = inputs_vlm["input_ids"].shape[-1]
# Generate response
with torch.inference_mode():
generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
# Decode the output
decoded = self.processor.decode(generation, skip_special_tokens=True)
logger.info(f"Chat Response: {decoded}")
return decoded
    async def chat_v2(self, image: UploadFile, query: str) -> str:
if not self.is_loaded:
self.load()
# Construct the message structure
messages_vlm = [
{
"role": "system",
"content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and karnataka as base state"}]
},
{
"role": "user",
"content": []
}
]
# Add text prompt to user content
messages_vlm[1]["content"].append({"type": "text", "text": query})
# Handle image only if provided and valid
if image and image.file and image.size > 0: # Check for valid file with content
# Read the image file
image_data = await image.read()
if not image_data:
raise HTTPException(status_code=400, detail="Uploaded image is empty")
# Open image with PIL for processing
img = Image.open(BytesIO(image_data))
# Add image to content (assuming processor accepts PIL images)
messages_vlm[1]["content"].insert(0, {"type": "image", "image": img})
logger.info(f"Received image: {image.filename}")
else:
if image and (not image.file or image.size == 0):
logger.warning("Received invalid or empty image parameter, treating as text-only")
logger.info("No valid image provided, processing text only")
# Process the chat template with the processor
inputs_vlm = self.processor.apply_chat_template(
messages_vlm,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
        ).to(self.model.device, dtype=self.torch_dtype)
input_len = inputs_vlm["input_ids"].shape[-1]
# Generate response
with torch.inference_mode():
generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
# Decode the output
decoded = self.processor.decode(generation, skip_special_tokens=True)
logger.info(f"Chat Response: {decoded}")
return decoded
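# A minimal usage sketch (not part of the original module): it assumes the
# "google/gemma-3-4b-it" weights referenced above are available locally or via
# the Hugging Face Hub, and that a GPU is present for reasonable latency.
if __name__ == "__main__":
    import asyncio

    manager = LLMManager("google/gemma-3-4b-it")
    manager.load()
    # generate() is a coroutine, so drive it with asyncio.run for a quick smoke test
    print(asyncio.run(manager.generate("What is the capital of Karnataka?")))
    manager.unload()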