sachin committed on
Commit 475b0b9 · 1 Parent(s): b9d432e

test-gemma-llm

.gitignore ADDED
@@ -0,0 +1,176 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
venv_new/
NeMo/
asr_venv
*.nemo
# C extensions
*.so
*.mp3

!kannada_female.wav
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc
Dockerfile ADDED
@@ -0,0 +1,36 @@
FROM ubuntu:22.04
WORKDIR /app

RUN apt-get update && apt-get install -y \
    python3 \
    python3-pip \
    git \
    ffmpeg \
    sudo \
    wget \
    libvips \
    build-essential \
    curl \
    && ln -s /usr/bin/python3 /usr/bin/python \
    && rm -rf /var/lib/apt/lists/*

RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
ENV PATH="/root/.cargo/bin:${PATH}"

RUN pip install --upgrade pip setuptools setuptools-rust torch
COPY requirements.txt .
#RUN pip install --no-cache-dir torch==2.6.0 torchvision
#RUN pip install --no-cache-dir transformers
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install git+https://github.com/ai4bharat/IndicF5.git

COPY . .

RUN useradd -ms /bin/bash appuser \
    && chown -R appuser:appuser /app

USER appuser

EXPOSE 7860

# Use absolute path for clarity
CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860"]
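
Once the image is built and a container is running with port 7860 published, a quick way to verify the server is up is to hit the /v1/health endpoint defined in src/server/main.py below. A minimal sketch, assuming the container port is mapped 1:1 to localhost:7860:

# health-check sketch; the localhost URL assumes a 1:1 port mapping (adjust if mapped differently)
import requests

resp = requests.get("http://localhost:7860/v1/health", timeout=10)
resp.raise_for_status()
print(resp.json())  # e.g. {"status": "healthy", "model": "google/gemma-3-4b-it"}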
requirements.txt ADDED
@@ -0,0 +1,10 @@
transformers @ git+https://github.com/huggingface/[email protected]
torch
accelerate
bitsandbytes
pillow
uvicorn
fastapi
pydantic_settings
slowapi
python-multipart
src/server/auth.py ADDED
@@ -0,0 +1,21 @@
from fastapi.security import APIKeyHeader
from fastapi import HTTPException, status, Depends
from pydantic_settings import BaseSettings
from logging_config import logger

class Settings(BaseSettings):
    api_key: str

    class Config:
        env_file = ".env"

settings = Settings()

API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

async def get_api_key(api_key: str = Depends(api_key_header)):
    if api_key != settings.api_key:
        logger.warning(f"Failed API key attempt: {api_key}")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
    logger.info("API key validated successfully")
    return api_key
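
The protected routes in src/server/main.py below take this dependency, so clients must send the key in the X-API-Key header. A hedged client-side sketch, assuming the server runs on localhost:7860 and the key is exported as the API_KEY environment variable (the same variable Settings reads from .env on the server side):

# client sketch; the endpoint and header name come from auth.py / main.py,
# while the localhost URL and API_KEY variable are assumptions for illustration
import os
import requests

headers = {"X-API-Key": os.environ["API_KEY"]}
payload = {"prompt": "Namaste", "src_lang": "eng_Latn", "tgt_lang": "kan_Knda"}
resp = requests.post("http://localhost:7860/v1/chat", json=payload, headers=headers, timeout=60)
print(resp.status_code, resp.json())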
src/server/gemma_llm.py ADDED
@@ -0,0 +1,186 @@
import torch
from logging_config import logger
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
from fastapi import HTTPException


class LLMManager:
    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        self.model_name = model_name
        self.device = torch.device(device)
        self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
        self.model = None
        self.is_loaded = False
        self.processor = None

    def unload(self):
        if self.is_loaded:
            # Delete the model and processor to free memory
            del self.model
            del self.processor
            # If using CUDA, clear the cache to free GPU memory
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
            self.is_loaded = False
            logger.info(f"LLM {self.model_name} unloaded from {self.device}")

    def load(self):
        if not self.is_loaded:
            #self.model_name = "google/gemma-3-4b-it"
            self.model = Gemma3ForConditionalGeneration.from_pretrained(
                self.model_name, device_map="auto"
            ).eval()
            self.processor = AutoProcessor.from_pretrained(self.model_name)
            self.is_loaded = True
            logger.info(f"LLM {self.model_name} loaded on {self.device}")

    async def generate(self, prompt: str, max_tokens: int = 2048, temperature: float = 0.7) -> str:
        if not self.is_loaded:
            self.load()

        messages_vlm = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as the base country and Karnataka as the base state. Provide a concise response in one sentence maximum."}]
            },
            {
                "role": "user",
                "content": []
            }
        ]

        # Add the text prompt to the user content
        messages_vlm[1]["content"].append({"type": "text", "text": prompt})

        # Tokenize the chat template with the processor
        inputs_vlm = self.processor.apply_chat_template(
            messages_vlm,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)

        input_len = inputs_vlm["input_ids"].shape[-1]

        # Greedy decoding; temperature is currently unused because do_sample=False
        with torch.inference_mode():
            generation = self.model.generate(**inputs_vlm, max_new_tokens=max_tokens, do_sample=False)
            generation = generation[0][input_len:]

        # Decode only the newly generated tokens
        response = self.processor.decode(generation, skip_special_tokens=True)
        return response

    async def vision_query(self, image: Image.Image, query: str) -> str:
        if not self.is_loaded:
            self.load()

        messages_vlm = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarise your answer in max 2 lines."}]
            },
            {
                "role": "user",
                "content": []
            }
        ]

        # Add the text prompt to the user content
        messages_vlm[1]["content"].append({"type": "text", "text": query})

        # Attach the image if provided and its dimensions are valid
        if image and image.size[0] > 0 and image.size[1] > 0:
            # The image is already a PIL Image; no need to read or reopen it
            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
            logger.info("Received valid image for processing")
        else:
            logger.info("No valid image provided, processing text only")

        # Tokenize the chat template with the processor
        try:
            inputs_vlm = self.processor.apply_chat_template(
                messages_vlm,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.model.device, dtype=torch.bfloat16)
        except Exception as e:
            logger.error(f"Error in apply_chat_template: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")

        input_len = inputs_vlm["input_ids"].shape[-1]

        # Generate the response
        with torch.inference_mode():
            generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]

        # Decode only the newly generated tokens
        decoded = self.processor.decode(generation, skip_special_tokens=True)
        logger.info(f"Chat Response: {decoded}")
        return decoded

    async def chat_v2(self, image: Image.Image, query: str) -> str:
        if not self.is_loaded:
            self.load()

        # Construct the message structure
        messages_vlm = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as the base country and Karnataka as the base state."}]
            },
            {
                "role": "user",
                "content": []
            }
        ]

        # Add the text prompt to the user content
        messages_vlm[1]["content"].append({"type": "text", "text": query})

        # Attach the image if provided and its dimensions are valid; the caller in
        # main.py has already read the upload and opened it as a PIL Image
        if image and image.size[0] > 0 and image.size[1] > 0:
            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
            logger.info("Received valid image for processing")
        else:
            logger.info("No valid image provided, processing text only")

        # Tokenize the chat template with the processor
        inputs_vlm = self.processor.apply_chat_template(
            messages_vlm,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self.model.device, dtype=torch.bfloat16)

        input_len = inputs_vlm["input_ids"].shape[-1]

        # Generate the response
        with torch.inference_mode():
            generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]

        # Decode only the newly generated tokens
        decoded = self.processor.decode(generation, skip_special_tokens=True)
        logger.info(f"Chat Response: {decoded}")
        return decoded
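
For experimenting with the manager outside FastAPI, a minimal async driver might look like this. A sketch only: it assumes it is run from src/server/ so the logging_config import above resolves, and that the environment has access to the google/gemma-3-4b-it weights.

# standalone driver sketch for LLMManager
import asyncio
from gemma_llm import LLMManager

async def main():
    manager = LLMManager("google/gemma-3-4b-it")
    reply = await manager.generate("What is the capital of Karnataka?", max_tokens=100)
    print(reply)
    manager.unload()  # release GPU memory when done

asyncio.run(main())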
src/server/logging_config.py ADDED
@@ -0,0 +1,35 @@
import logging
import logging.config
from logging.handlers import RotatingFileHandler
from tts_config import config

logging_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "simple": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"},
    },
    "handlers": {
        "stdout": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "stream": "ext://sys.stdout",
        },
        "file": {
            "class": "logging.handlers.RotatingFileHandler",
            "formatter": "simple",
            "filename": "dhwani_api.log",
            "maxBytes": 10 * 1024 * 1024,  # 10 MB
            "backupCount": 5,
        },
    },
    "loggers": {
        "root": {
            "level": config.log_level.upper(),
            "handlers": ["stdout", "file"],
        },
    },
}

logging.config.dictConfig(logging_config)
logger = logging.getLogger("indic_all_server")
src/server/main.py ADDED
@@ -0,0 +1,367 @@
import argparse
import io
import os
from time import time
from typing import List

import tempfile
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from PIL import Image
from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
import requests

from logging_config import logger
from tts_config import SPEED, ResponseFormat, config as tts_config
from gemma_llm import LLMManager
from auth import get_api_key, settings as auth_settings

# Supported language codes
SUPPORTED_LANGUAGES = {
    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
    "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
    "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
    "kan_Knda", "ory_Orya"
}

class Settings(BaseSettings):
    llm_model_name: str = "google/gemma-3-4b-it"
    max_tokens: int = 512
    host: str = "0.0.0.0"
    port: int = 7860
    chat_rate_limit: str = "100/minute"
    speech_rate_limit: str = "5/minute"

    @field_validator("chat_rate_limit", "speech_rate_limit")
    def validate_rate_limit(cls, v):
        if not v.count("/") == 1 or not v.split("/")[0].isdigit():
            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
        return v

    class Config:
        env_file = ".env"

settings = Settings()

app = FastAPI(
    title="Dhwani API",
    description="AI Chat API supporting Indian languages",
    version="1.0.0",
    redirect_slashes=False,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
# Return 429 responses instead of unhandled exceptions when a rate limit is hit
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

llm_manager = LLMManager(settings.llm_model_name)

class ChatRequest(BaseModel):
    prompt: str
    src_lang: str = "kan_Knda"  # Default to Kannada
    tgt_lang: str = "kan_Knda"  # Default to Kannada

    @field_validator("prompt")
    def prompt_must_be_valid(cls, v):
        if len(v) > 1000:
            raise ValueError("Prompt cannot exceed 1000 characters")
        return v.strip()

    @field_validator("src_lang", "tgt_lang")
    def validate_language(cls, v):
        if v not in SUPPORTED_LANGUAGES:
            raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
        return v

class ChatResponse(BaseModel):
    response: str

class TranslationRequest(BaseModel):
    sentences: List[str]
    src_lang: str
    tgt_lang: str

    @field_validator("src_lang", "tgt_lang")
    def validate_language(cls, v):
        if v not in SUPPORTED_LANGUAGES:
            raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
        return v

class TranslationResponse(BaseModel):
    translations: List[str]

async def call_external_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
    external_url = "https://gaganyatri-dhwani-server.hf.space/v1/translate"
    payload = {
        "sentences": sentences,
        "src_lang": src_lang,
        "tgt_lang": tgt_lang
    }
    try:
        response = requests.post(
            external_url,
            json=payload,
            headers={
                "accept": "application/json",
                "Content-Type": "application/json"
            },
            timeout=10
        )
        response.raise_for_status()
        translations = response.json().get("translations", [])
        if not translations or len(translations) != len(sentences):
            logger.warning(f"Unexpected response format: {response.json()}")
            raise ValueError("Invalid response from translation service")
        return translations
    except requests.Timeout:
        logger.error("Translation request timed out")
        raise HTTPException(status_code=504, detail="Translation service timeout")
    except requests.RequestException as e:
        logger.error(f"Error during translation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
    except ValueError as e:
        logger.error(f"Invalid response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/v1/health")
async def health_check():
    return {"status": "healthy", "model": settings.llm_model_name}

@app.get("/")
async def home():
    return RedirectResponse(url="/docs")

@app.post("/v1/unload_all_models")
async def unload_all_models(api_key: str = Depends(get_api_key)):
    try:
        logger.info("Starting to unload all models...")
        llm_manager.unload()
        logger.info("All models unloaded successfully")
        return {"status": "success", "message": "All models unloaded"}
    except Exception as e:
        logger.error(f"Error unloading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")

@app.post("/v1/load_all_models")
async def load_all_models(api_key: str = Depends(get_api_key)):
    try:
        logger.info("Starting to load all models...")
        llm_manager.load()
        logger.info("All models loaded successfully")
        return {"status": "success", "message": "All models loaded"}
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")

@app.post("/v1/translate", response_model=TranslationResponse)
async def translate(request: TranslationRequest):
    logger.info(f"Received translation request: {request.model_dump()}")
    try:
        translations = await call_external_translation(
            sentences=request.sentences,
            src_lang=request.src_lang,
            tgt_lang=request.tgt_lang
        )
        logger.info(f"Translation successful: {translations}")
        return TranslationResponse(translations=translations)
    except HTTPException as e:
        raise e
    except Exception as e:
        logger.error(f"Unexpected error during translation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")

@app.post("/v1/chat", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat(request: Request, chat_request: ChatRequest, api_key: str = Depends(get_api_key)):
    if not chat_request.prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
    try:
        # Translate the prompt to English if src_lang is not English
        if chat_request.src_lang != "eng_Latn":
            translated_prompt = await call_external_translation(
                sentences=[chat_request.prompt],
                src_lang=chat_request.src_lang,
                tgt_lang="eng_Latn"
            )
            prompt_to_process = translated_prompt[0]
            logger.info(f"Translated prompt to English: {prompt_to_process}")
        else:
            prompt_to_process = chat_request.prompt
            logger.info("Prompt already in English, no translation needed")

        # Generate the response in English
        response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
        logger.info(f"Generated English response: {response}")

        # Translate the response to the target language if tgt_lang is not English
        if chat_request.tgt_lang != "eng_Latn":
            translated_response = await call_external_translation(
                sentences=[response],
                src_lang="eng_Latn",
                tgt_lang=chat_request.tgt_lang
            )
            final_response = translated_response[0]
            logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
        else:
            final_response = response
            logger.info("Response kept in English, no translation needed")

        return ChatResponse(response=final_response)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

@app.post("/v1/visual_query/")
async def visual_query(
    file: UploadFile = File(...),
    query: str = Form(...),
    src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
    tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
    api_key: str = Depends(get_api_key)
):
    try:
        image = Image.open(file.file)
        if image.size == (0, 0):
            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")

        # Translate the query to English if src_lang is not English
        if src_lang != "eng_Latn":
            translated_query = await call_external_translation(
                sentences=[query],
                src_lang=src_lang,
                tgt_lang="eng_Latn"
            )
            query_to_process = translated_query[0]
            logger.info(f"Translated query to English: {query_to_process}")
        else:
            query_to_process = query
            logger.info("Query already in English, no translation needed")

        # Generate the answer in English
        answer = await llm_manager.vision_query(image, query_to_process)
        logger.info(f"Generated English answer: {answer}")

        # Translate the answer to the target language if tgt_lang is not English
        if tgt_lang != "eng_Latn":
            translated_answer = await call_external_translation(
                sentences=[answer],
                src_lang="eng_Latn",
                tgt_lang=tgt_lang
            )
            final_answer = translated_answer[0]
            logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
        else:
            final_answer = answer
            logger.info("Answer kept in English, no translation needed")

        return {"answer": final_answer}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

@app.post("/v1/chat_v2", response_model=ChatResponse)
@limiter.limit(settings.chat_rate_limit)
async def chat_v2(
    request: Request,
    prompt: str = Form(...),
    image: UploadFile = File(default=None),
    src_lang: str = Form("kan_Knda"),
    tgt_lang: str = Form("kan_Knda"),
    api_key: str = Depends(get_api_key)
):
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
    if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")

    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")

    try:
        if image:
            image_data = await image.read()
            if not image_data:
                raise HTTPException(status_code=400, detail="Uploaded image is empty")
            img = Image.open(io.BytesIO(image_data))

            # Translate the prompt to English if src_lang is not English
            if src_lang != "eng_Latn":
                translated_prompt = await call_external_translation(
                    sentences=[prompt],
                    src_lang=src_lang,
                    tgt_lang="eng_Latn"
                )
                prompt_to_process = translated_prompt[0]
                logger.info(f"Translated prompt to English: {prompt_to_process}")
            else:
                prompt_to_process = prompt
                logger.info("Prompt already in English, no translation needed")

            decoded = await llm_manager.chat_v2(img, prompt_to_process)
            logger.info(f"Generated English response: {decoded}")

            # Translate the response to the target language if tgt_lang is not English
            if tgt_lang != "eng_Latn":
                translated_response = await call_external_translation(
                    sentences=[decoded],
                    src_lang="eng_Latn",
                    tgt_lang=tgt_lang
                )
                final_response = translated_response[0]
                logger.info(f"Translated response to {tgt_lang}: {final_response}")
            else:
                final_response = decoded
                logger.info("Response kept in English, no translation needed")
        else:
            # Translate the prompt to English if src_lang is not English
            if src_lang != "eng_Latn":
                translated_prompt = await call_external_translation(
                    sentences=[prompt],
                    src_lang=src_lang,
                    tgt_lang="eng_Latn"
                )
                prompt_to_process = translated_prompt[0]
                logger.info(f"Translated prompt to English: {prompt_to_process}")
            else:
                prompt_to_process = prompt
                logger.info("Prompt already in English, no translation needed")

            decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
            logger.info(f"Generated English response: {decoded}")

            # Translate the response to the target language if tgt_lang is not English
            if tgt_lang != "eng_Latn":
                translated_response = await call_external_translation(
                    sentences=[decoded],
                    src_lang="eng_Latn",
                    tgt_lang=tgt_lang
                )
                final_response = translated_response[0]
                logger.info(f"Translated response to {tgt_lang}: {final_response}")
            else:
                final_response = decoded
                logger.info("Response kept in English, no translation needed")

        return ChatResponse(response=final_response)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the FastAPI server.")
    parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
    parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)
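
For the multipart endpoints, the image goes in as a file field and the query as a form field, with the language codes as query parameters. A hedged client sketch against /v1/visual_query/, where the image path and key variable are placeholders for illustration:

# visual_query client sketch; field names and endpoint come from the route above
import os
import requests

headers = {"X-API-Key": os.environ["API_KEY"]}
with open("test.jpg", "rb") as f:  # placeholder image path
    resp = requests.post(
        "http://localhost:7860/v1/visual_query/",
        params={"src_lang": "eng_Latn", "tgt_lang": "eng_Latn"},
        files={"file": ("test.jpg", f, "image/jpeg")},
        data={"query": "What is in this picture?"},
        headers=headers,
        timeout=120,
    )
print(resp.json())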
src/server/text.py ADDED
@@ -0,0 +1,3 @@
def chunk_text(text: str, chunk_size: int = 15) -> list[str]:
    # Split text into chunks of at most chunk_size words each
    words = text.split()
    return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
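
For example, splitting a nine-word sentence into three-word chunks:

from text import chunk_text

print(chunk_text("the quick brown fox jumps over the lazy dog", chunk_size=3))
# ['the quick brown', 'fox jumps over', 'the lazy dog']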
src/server/tts_config.py ADDED
@@ -0,0 +1,27 @@
import enum
from pydantic_settings import BaseSettings

SPEED = 1.0

class StrEnum(str, enum.Enum):
    def __str__(self):
        return str(self.value)

class ResponseFormat(StrEnum):
    MP3 = "mp3"
    FLAC = "flac"
    WAV = "wav"

class Config(BaseSettings):
    log_level: str = "info"
    model: str = "ai4bharat/indic-parler-tts"
    max_models: int = 1
    lazy_load_model: bool = False  # Unused now, as all models are lazy-loaded
    input: str = "ನಿಮ್ಮ ಇನ್‌ಪುಟ್ ಪಠ್ಯವನ್ನು ಇಲ್ಲಿ ಸೇರಿಸಿ"
    voice: str = (
        "Female speaks with a high pitch at a normal pace in a clear, close-sounding environment. "
        "Her neutral tone is captured with excellent audio quality."
    )
    response_format: ResponseFormat = ResponseFormat.MP3

config = Config()
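
Because Config extends BaseSettings, any field can be overridden from the environment; pydantic-settings matches variable names case-insensitively by default. A small sketch of that behaviour:

# override sketch; the LOG_LEVEL variable name maps to Config.log_level
import os
os.environ["LOG_LEVEL"] = "debug"

from tts_config import Config
print(Config().log_level)  # -> "debug"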