sachin committed
Commit · 475b0b9 · 1 Parent(s): b9d432e
test-gemma-llm
Browse files
- .gitignore +176 -0
- Dockerfile +36 -0
- requirements.txt +10 -0
- src/server/auth.py +21 -0
- src/server/gemma_llm.py +186 -0
- src/server/logging_config.py +35 -0
- src/server/main.py +367 -0
- src/server/text.py +3 -0
- src/server/tts_config.py +27 -0
.gitignore
ADDED
@@ -0,0 +1,176 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+venv_new/
+NeMo/
+asr_venv
+*.nemo
+# C extensions
+*.so
+*.mp3
+
+!kannada_female.wav
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# PyPI configuration file
+.pypirc
Dockerfile
ADDED
@@ -0,0 +1,36 @@
+FROM ubuntu:22.04
+WORKDIR /app
+
+RUN apt-get update && apt-get install -y \
+    python3 \
+    python3-pip \
+    git \
+    ffmpeg \
+    sudo \
+    wget libvips \
+    build-essential \
+    curl \
+    && ln -s /usr/bin/python3 /usr/bin/python \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+ENV PATH="/root/.cargo/bin:${PATH}"
+
+RUN pip install --upgrade pip setuptools setuptools-rust torch
+COPY requirements.txt .
+#RUN pip install --no-cache-dir torch==2.6.0 torchvision
+#RUN pip install --no-cache-dir transformers
+RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install git+https://github.com/ai4bharat/IndicF5.git
+
+COPY . .
+
+RUN useradd -ms /bin/bash appuser \
+    && chown -R appuser:appuser /app
+
+USER appuser
+
+EXPOSE 7860
+
+# Use an absolute path for clarity
+CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+transformers @ git+https://github.com/huggingface/[email protected]
+torch
+accelerate
+bitsandbytes
+pillow
+uvicorn
+fastapi
+pydantic_settings
+slowapi
+python-multipart
src/server/auth.py
ADDED
@@ -0,0 +1,21 @@
+from fastapi.security import APIKeyHeader
+from fastapi import HTTPException, status, Depends
+from pydantic_settings import BaseSettings
+from logging_config import logger
+
+class Settings(BaseSettings):
+    api_key: str
+    class Config:
+        env_file = ".env"
+
+settings = Settings()
+
+API_KEY_NAME = "X-API-Key"
+api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
+
+async def get_api_key(api_key: str = Depends(api_key_header)):
+    if api_key != settings.api_key:
+        logger.warning(f"Failed API key attempt: {api_key}")
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
+    logger.info("API key validated successfully")
+    return api_key
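
Every protected route reads the key from the X-API-Key header and compares it against the API_KEY value loaded from .env. A minimal client sketch (the URL and key below are placeholders for a local deployment):

import requests

API_URL = "http://localhost:7860"  # assumption: server running locally on the EXPOSEd port
headers = {"X-API-Key": "your-api-key"}  # placeholder; must match the server's API_KEY setting

resp = requests.post(f"{API_URL}/v1/load_all_models", headers=headers, timeout=300)
print(resp.status_code, resp.json())
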
src/server/gemma_llm.py
ADDED
@@ -0,0 +1,186 @@
+import torch
+from logging_config import logger
+from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from PIL import Image
+from fastapi import HTTPException
+
+
+class LLMManager:
+    def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
+        self.model_name = model_name
+        self.device = torch.device(device)
+        self.torch_dtype = torch.float16 if self.device.type != "cpu" else torch.float32
+        self.model = None
+        self.is_loaded = False
+        self.processor = None
+
+    def unload(self):
+        if self.is_loaded:
+            # Delete the model and processor to free memory
+            del self.model
+            del self.processor
+            # If using CUDA, clear the cache to free GPU memory
+            if self.device.type == "cuda":
+                torch.cuda.empty_cache()
+            self.is_loaded = False
+            logger.info(f"LLM {self.model_name} unloaded from {self.device}")
+
+    def load(self):
+        if not self.is_loaded:
+            #self.model_name = "google/gemma-3-4b-it"
+            self.model = Gemma3ForConditionalGeneration.from_pretrained(
+                self.model_name, device_map="auto"
+            ).eval()
+            self.processor = AutoProcessor.from_pretrained(self.model_name)
+            self.is_loaded = True
+            logger.info(f"LLM {self.model_name} loaded on {self.device}")
+
+    async def generate(self, prompt: str, max_tokens: int = 2048, temperature: float = 0.7) -> str:
+        if not self.is_loaded:
+            self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+
+        # Add text prompt to user content
+        messages_vlm[1]["content"].append({"type": "text", "text": prompt})
+
+        # Process the chat template with the processor
+        inputs_vlm = self.processor.apply_chat_template(
+            messages_vlm,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(self.model.device, dtype=torch.bfloat16)
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        # Generate response; temperature is accepted for API compatibility but
+        # unused while do_sample=False (greedy decoding)
+        with torch.inference_mode():
+            generation = self.model.generate(**inputs_vlm, max_new_tokens=max_tokens, do_sample=False)
+            generation = generation[0][input_len:]
+
+        # Decode the output
+        response = self.processor.decode(generation, skip_special_tokens=True)
+
+        return response
+
+    async def vision_query(self, image: Image.Image, query: str) -> str:
+        if not self.is_loaded:
+            self.load()
+
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarise your answer in max 2 lines."}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+
+        # Add text prompt to user content
+        messages_vlm[1]["content"].append({"type": "text", "text": query})
+
+        # Handle image if provided and valid
+        if image and image.size[0] > 0 and image.size[1] > 0:  # Check for valid dimensions
+            # Image is already a PIL Image, no need to read or reopen
+            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+            logger.info("Received valid image for processing")
+        else:
+            logger.info("No valid image provided, processing text only")
+
+        # Process the chat template with the processor
+        try:
+            inputs_vlm = self.processor.apply_chat_template(
+                messages_vlm,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt"
+            ).to(self.model.device, dtype=torch.bfloat16)
+        except Exception as e:
+            logger.error(f"Error in apply_chat_template: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        # Generate response
+        with torch.inference_mode():
+            generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
+            generation = generation[0][input_len:]
+
+        # Decode the output
+        decoded = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Chat Response: {decoded}")
+
+        return decoded
+
+    async def chat_v2(self, image: Image.Image, query: str) -> str:
+        if not self.is_loaded:
+            self.load()
+        # Construct the message structure
+        messages_vlm = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state"}]
+            },
+            {
+                "role": "user",
+                "content": []
+            }
+        ]
+
+        # Add text prompt to user content
+        messages_vlm[1]["content"].append({"type": "text", "text": query})
+
+        # Handle image only if provided and valid; the caller (main.py) has already
+        # read the upload and decoded it into a PIL Image
+        if image and image.size[0] > 0 and image.size[1] > 0:
+            messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
+            logger.info("Received valid image for processing")
+        else:
+            logger.info("No valid image provided, processing text only")
+
+        # Process the chat template with the processor
+        inputs_vlm = self.processor.apply_chat_template(
+            messages_vlm,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt"
+        ).to(self.model.device, dtype=torch.bfloat16)
+
+        input_len = inputs_vlm["input_ids"].shape[-1]
+
+        # Generate response
+        with torch.inference_mode():
+            generation = self.model.generate(**inputs_vlm, max_new_tokens=100, do_sample=False)
+            generation = generation[0][input_len:]
+
+        # Decode the output
+        decoded = self.processor.decode(generation, skip_special_tokens=True)
+        logger.info(f"Chat Response: {decoded}")
+        return decoded
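
For a quick smoke test of LLMManager outside FastAPI, something like this sketch should work, assuming a machine with enough GPU memory and access to the gated google/gemma-3-4b-it weights:

import asyncio
from gemma_llm import LLMManager

async def main():
    manager = LLMManager("google/gemma-3-4b-it")  # same default as Settings.llm_model_name
    # load() is called lazily inside generate() on first use
    answer = await manager.generate("What is the capital of Karnataka?", max_tokens=64)
    print(answer)
    manager.unload()  # release GPU memory when done

asyncio.run(main())
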
src/server/logging_config.py
ADDED
@@ -0,0 +1,35 @@
+import logging
+import logging.config
+from logging.handlers import RotatingFileHandler
+from tts_config import config
+
+logging_config = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "simple": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"},
+    },
+    "handlers": {
+        "stdout": {
+            "class": "logging.StreamHandler",
+            "formatter": "simple",
+            "stream": "ext://sys.stdout",
+        },
+        "file": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "formatter": "simple",
+            "filename": "dhwani_api.log",
+            "maxBytes": 10 * 1024 * 1024,  # 10MB
+            "backupCount": 5,
+        },
+    },
+    "loggers": {
+        "root": {
+            "level": config.log_level.upper(),
+            "handlers": ["stdout", "file"],
+        },
+    },
+}
+
+logging.config.dictConfig(logging_config)
+logger = logging.getLogger("indic_all_server")
src/server/main.py
ADDED
@@ -0,0 +1,367 @@
+import argparse
+import io
+import os
+from time import time
+from typing import List
+
+import tempfile
+import uvicorn
+from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
+from PIL import Image
+from pydantic import BaseModel, field_validator
+from pydantic_settings import BaseSettings
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
+import requests
+
+from logging_config import logger
+from tts_config import SPEED, ResponseFormat, config as tts_config
+from gemma_llm import LLMManager
+from auth import get_api_key, settings as auth_settings
+
+# Supported language codes
+SUPPORTED_LANGUAGES = {
+    "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
+    "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
+    "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
+    "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
+    "kan_Knda", "ory_Orya"
+}
+
+class Settings(BaseSettings):
+    llm_model_name: str = "google/gemma-3-4b-it"
+    max_tokens: int = 512
+    host: str = "0.0.0.0"
+    port: int = 7860
+    chat_rate_limit: str = "100/minute"
+    speech_rate_limit: str = "5/minute"
+
+    @field_validator("chat_rate_limit", "speech_rate_limit")
+    def validate_rate_limit(cls, v):
+        if not v.count("/") == 1 or not v.split("/")[0].isdigit():
+            raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
+        return v
+
+    class Config:
+        env_file = ".env"
+
+settings = Settings()
+
+app = FastAPI(
+    title="Dhwani API",
+    description="AI Chat API supporting Indian languages",
+    version="1.0.0",
+    redirect_slashes=False,
+)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=False,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+limiter = Limiter(key_func=get_remote_address)
+app.state.limiter = limiter
+app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
+llm_manager = LLMManager(settings.llm_model_name)
+
+class ChatRequest(BaseModel):
+    prompt: str
+    src_lang: str = "kan_Knda"  # Default to Kannada
+    tgt_lang: str = "kan_Knda"  # Default to Kannada
+
+    @field_validator("prompt")
+    def prompt_must_be_valid(cls, v):
+        if len(v) > 1000:
+            raise ValueError("Prompt cannot exceed 1000 characters")
+        return v.strip()
+
+    @field_validator("src_lang", "tgt_lang")
+    def validate_language(cls, v):
+        if v not in SUPPORTED_LANGUAGES:
+            raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+        return v
+
+class ChatResponse(BaseModel):
+    response: str
+
+class TranslationRequest(BaseModel):
+    sentences: List[str]
+    src_lang: str
+    tgt_lang: str
+
+    @field_validator("src_lang", "tgt_lang")
+    def validate_language(cls, v):
+        if v not in SUPPORTED_LANGUAGES:
+            raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+        return v
+
+class TranslationResponse(BaseModel):
+    translations: List[str]
+
+async def call_external_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
+    external_url = "https://gaganyatri-dhwani-server.hf.space/v1/translate"
+    payload = {
+        "sentences": sentences,
+        "src_lang": src_lang,
+        "tgt_lang": tgt_lang
+    }
+    try:
+        response = requests.post(
+            external_url,
+            json=payload,
+            headers={
+                "accept": "application/json",
+                "Content-Type": "application/json"
+            },
+            timeout=10
+        )
+        response.raise_for_status()
+        translations = response.json().get("translations", [])
+        if not translations or len(translations) != len(sentences):
+            logger.warning(f"Unexpected response format: {response.json()}")
+            raise ValueError("Invalid response from translation service")
+        return translations
+    except requests.Timeout:
+        logger.error("Translation request timed out")
+        raise HTTPException(status_code=504, detail="Translation service timeout")
+    except requests.RequestException as e:
+        logger.error(f"Error during translation: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
+    except ValueError as e:
+        logger.error(f"Invalid response: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/v1/health")
+async def health_check():
+    return {"status": "healthy", "model": settings.llm_model_name}
+
+@app.get("/")
+async def home():
+    return RedirectResponse(url="/docs")
+
+@app.post("/v1/unload_all_models")
+async def unload_all_models(api_key: str = Depends(get_api_key)):
+    try:
+        logger.info("Starting to unload all models...")
+        llm_manager.unload()
+        logger.info("All models unloaded successfully")
+        return {"status": "success", "message": "All models unloaded"}
+    except Exception as e:
+        logger.error(f"Error unloading models: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
+
+@app.post("/v1/load_all_models")
+async def load_all_models(api_key: str = Depends(get_api_key)):
+    try:
+        logger.info("Starting to load all models...")
+        llm_manager.load()
+        logger.info("All models loaded successfully")
+        return {"status": "success", "message": "All models loaded"}
+    except Exception as e:
+        logger.error(f"Error loading models: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
+
+@app.post("/v1/translate", response_model=TranslationResponse)
+async def translate(request: TranslationRequest):
+    logger.info(f"Received translation request: {request.model_dump()}")
+    try:
+        translations = await call_external_translation(
+            sentences=request.sentences,
+            src_lang=request.src_lang,
+            tgt_lang=request.tgt_lang
+        )
+        logger.info(f"Translation successful: {translations}")
+        return TranslationResponse(translations=translations)
+    except HTTPException as e:
+        raise e
+    except Exception as e:
+        logger.error(f"Unexpected error during translation: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
+
+@app.post("/v1/chat", response_model=ChatResponse)
+@limiter.limit(settings.chat_rate_limit)
+async def chat(request: Request, chat_request: ChatRequest, api_key: str = Depends(get_api_key)):
+    if not chat_request.prompt:
+        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+    logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
+    try:
+        # Translate prompt to English if src_lang is not English
+        if chat_request.src_lang != "eng_Latn":
+            translated_prompt = await call_external_translation(
+                sentences=[chat_request.prompt],
+                src_lang=chat_request.src_lang,
+                tgt_lang="eng_Latn"
+            )
+            prompt_to_process = translated_prompt[0]
+            logger.info(f"Translated prompt to English: {prompt_to_process}")
+        else:
+            prompt_to_process = chat_request.prompt
+            logger.info("Prompt already in English, no translation needed")
+
+        # Generate response in English
+        response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+        logger.info(f"Generated English response: {response}")
+
+        # Translate response to target language if tgt_lang is not English
+        if chat_request.tgt_lang != "eng_Latn":
+            translated_response = await call_external_translation(
+                sentences=[response],
+                src_lang="eng_Latn",
+                tgt_lang=chat_request.tgt_lang
+            )
+            final_response = translated_response[0]
+            logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
+        else:
+            final_response = response
+            logger.info("Response kept in English, no translation needed")
+
+        return ChatResponse(response=final_response)
+    except Exception as e:
+        logger.error(f"Error processing request: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+@app.post("/v1/visual_query/")
+async def visual_query(
+    file: UploadFile = File(...),
+    query: str = Body(...),
+    src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
+    tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
+    api_key: str = Depends(get_api_key)
+):
+    try:
+        image = Image.open(file.file)
+        if image.size == (0, 0):
+            raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
+
+        # Translate query to English if src_lang is not English
+        if src_lang != "eng_Latn":
+            translated_query = await call_external_translation(
+                sentences=[query],
+                src_lang=src_lang,
+                tgt_lang="eng_Latn"
+            )
+            query_to_process = translated_query[0]
+            logger.info(f"Translated query to English: {query_to_process}")
+        else:
+            query_to_process = query
+            logger.info("Query already in English, no translation needed")
+
+        # Generate response in English
+        answer = await llm_manager.vision_query(image, query_to_process)
+        logger.info(f"Generated English answer: {answer}")
+
+        # Translate answer to target language if tgt_lang is not English
+        if tgt_lang != "eng_Latn":
+            translated_answer = await call_external_translation(
+                sentences=[answer],
+                src_lang="eng_Latn",
+                tgt_lang=tgt_lang
+            )
+            final_answer = translated_answer[0]
+            logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
+        else:
+            final_answer = answer
+            logger.info("Answer kept in English, no translation needed")
+
+        return {"answer": final_answer}
+    except Exception as e:
+        logger.error(f"Error processing request: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+@app.post("/v1/chat_v2", response_model=ChatResponse)
+@limiter.limit(settings.chat_rate_limit)
+async def chat_v2(
+    request: Request,
+    prompt: str = Form(...),
+    image: UploadFile = File(default=None),
+    src_lang: str = Form("kan_Knda"),
+    tgt_lang: str = Form("kan_Knda"),
+    api_key: str = Depends(get_api_key)
+):
+    if not prompt:
+        raise HTTPException(status_code=400, detail="Prompt cannot be empty")
+    if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
+        raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
+
+    logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
+
+    try:
+        if image:
+            image_data = await image.read()
+            if not image_data:
+                raise HTTPException(status_code=400, detail="Uploaded image is empty")
+            img = Image.open(io.BytesIO(image_data))
+
+            # Translate prompt to English if src_lang is not English
+            if src_lang != "eng_Latn":
+                translated_prompt = await call_external_translation(
+                    sentences=[prompt],
+                    src_lang=src_lang,
+                    tgt_lang="eng_Latn"
+                )
+                prompt_to_process = translated_prompt[0]
+                logger.info(f"Translated prompt to English: {prompt_to_process}")
+            else:
+                prompt_to_process = prompt
+                logger.info("Prompt already in English, no translation needed")
+
+            decoded = await llm_manager.chat_v2(img, prompt_to_process)
+            logger.info(f"Generated English response: {decoded}")
+
+            # Translate response to target language if tgt_lang is not English
+            if tgt_lang != "eng_Latn":
+                translated_response = await call_external_translation(
+                    sentences=[decoded],
+                    src_lang="eng_Latn",
+                    tgt_lang=tgt_lang
+                )
+                final_response = translated_response[0]
+                logger.info(f"Translated response to {tgt_lang}: {final_response}")
+            else:
+                final_response = decoded
+                logger.info("Response kept in English, no translation needed")
+        else:
+            # Translate prompt to English if src_lang is not English
+            if src_lang != "eng_Latn":
+                translated_prompt = await call_external_translation(
+                    sentences=[prompt],
+                    src_lang=src_lang,
+                    tgt_lang="eng_Latn"
+                )
+                prompt_to_process = translated_prompt[0]
+                logger.info(f"Translated prompt to English: {prompt_to_process}")
+            else:
+                prompt_to_process = prompt
+                logger.info("Prompt already in English, no translation needed")
+
+            decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
+            logger.info(f"Generated English response: {decoded}")
+
+            # Translate response to target language if tgt_lang is not English
+            if tgt_lang != "eng_Latn":
+                translated_response = await call_external_translation(
+                    sentences=[decoded],
+                    src_lang="eng_Latn",
+                    tgt_lang=tgt_lang
+                )
+                final_response = translated_response[0]
+                logger.info(f"Translated response to {tgt_lang}: {final_response}")
+            else:
+                final_response = decoded
+                logger.info("Response kept in English, no translation needed")
+
+        return ChatResponse(response=final_response)
+    except Exception as e:
+        logger.error(f"Error processing request: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run the FastAPI server.")
+    parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
+    parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
+    args = parser.parse_args()
+    uvicorn.run(app, host=args.host, port=args.port)
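
Once the container is up, the /v1/chat route can be exercised end to end. A sketch assuming the server is reachable on localhost:7860 and an API key has been configured:

import requests

headers = {"X-API-Key": "your-api-key"}  # placeholder
payload = {
    "prompt": "ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ ಯಾವುದು?",  # "What is the capital of Karnataka?"
    "src_lang": "kan_Knda",  # prompt is translated to English before generation
    "tgt_lang": "kan_Knda",  # response is translated back after generation
}
resp = requests.post("http://localhost:7860/v1/chat", json=payload, headers=headers, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])
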
src/server/text.py
ADDED
@@ -0,0 +1,3 @@
+def chunk_text(text: str, chunk_size: int = 15) -> list[str]:
+    words = text.split()
+    return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
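
chunk_text splits text into fixed-size windows of whitespace-separated words, which is useful for feeding long passages to a TTS model piece by piece. For example:

from text import chunk_text

passage = " ".join(f"word{i}" for i in range(35))
chunks = chunk_text(passage, chunk_size=15)
print(len(chunks))  # 3 chunks of 15, 15, and 5 words
print(chunks[-1])   # "word30 word31 word32 word33 word34"
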
src/server/tts_config.py
ADDED
@@ -0,0 +1,27 @@
+import enum
+from pydantic_settings import BaseSettings
+
+SPEED = 1.0
+
+class StrEnum(str, enum.Enum):
+    def __str__(self):
+        return str(self.value)
+
+class ResponseFormat(StrEnum):
+    MP3 = "mp3"
+    FLAC = "flac"
+    WAV = "wav"
+
+class Config(BaseSettings):
+    log_level: str = "info"
+    model: str = "ai4bharat/indic-parler-tts"
+    max_models: int = 1
+    lazy_load_model: bool = False  # Unused now, as all models are lazy-loaded
+    input: str = "ನಿಮ್ಮ ಇನ್ಪುಟ್ ಪಠ್ಯವನ್ನು ಇಲ್ಲಿ ಸೇರಿಸಿ"
+    voice: str = (
+        "Female speaks with a high pitch at a normal pace in a clear, close-sounding environment. "
+        "Her neutral tone is captured with excellent audio quality."
+    )
+    response_format: ResponseFormat = ResponseFormat.MP3
+
+config = Config()
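
Because Config extends pydantic's BaseSettings, each field can be overridden through environment variables (matched case-insensitively by default) or a .env file, with no code changes. A sketch of hypothetical overrides:

import os

# Hypothetical overrides; set these before tts_config is imported.
os.environ["LOG_LEVEL"] = "debug"      # raises the root logger level in logging_config
os.environ["RESPONSE_FORMAT"] = "wav"  # must be one of: mp3, flac, wav

from tts_config import config
print(config.log_level, config.response_format)  # -> "debug wav" (StrEnum prints its value)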