Update app.py
debugging generation configs
app.py
CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
 import logging
 import os
 import sys
@@ -20,9 +20,37 @@ MODELS = {
     "nidra-v2": "m1k3wn/nidra-v2"
 }

+# Define default generation configurations for each model
+DEFAULT_GENERATION_CONFIGS = {
+    "nidra-v1": {
+        "max_length": 300,
+        "min_length": 150,
+        "num_beams": 8,
+        "temperature": 0.55,
+        "do_sample": True,
+        "top_p": 0.95,
+        "repetition_penalty": 4.5,
+        "no_repeat_ngram_size": 4,
+        "early_stopping": True,
+        "length_penalty": 1.2,
+    },
+    "nidra-v2": {
+        "max_length": 300,
+        "min_length": 150,
+        "num_beams": 8,
+        "temperature": 0.4,
+        "do_sample": True,
+        "top_p": 0.95,
+        "repetition_penalty": 3.5,
+        "no_repeat_ngram_size": 4,
+        "early_stopping": True,
+        "length_penalty": 1.2,
+    }
+}
 class PredictionRequest(BaseModel):
     inputs: str
     model: str = "nidra-v1"
+    parameters: Optional[Dict[str, Any]] = None  # Allow custom parameters

 class PredictionResponse(BaseModel):
     generated_text: str
@@ -38,6 +66,10 @@ async def health():
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
     try:
+        # Validate model
+        if request.model not in MODELS:
+            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
+
         logger.info(f"Loading model: {request.model}")
         model_path = MODELS[request.model]

@@ -58,6 +90,27 @@ async def predict(request: PredictionRequest):
             local_files_only=False
         )
         logger.info("Model loaded successfully")
+
+        # Priority: 1. Request parameters, 2. Model's saved generation_config, 3. Default configs
+        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
+
+        # Try to load model's saved generation config
+        try:
+            model_generation_config = GenerationConfig.from_pretrained(model_path)
+            # Convert to dict to merge with default configs
+            generation_params.update({
+                k: v for k, v in model_generation_config.to_dict().items()
+                if v is not None
+            })
+        except Exception as config_load_error:
+            logger.warning(f"Could not load model's generation config: {config_load_error}")
+
+        # Override with request-specific parameters if provided
+        if request.parameters:
+            generation_params.update(request.parameters)
+
+        logger.info(f"Final Generation Parameters: {generation_params}")
+

         full_input = "Interpret this dream: " + request.inputs
         logger.info(f"Processing input: {full_input}")
@@ -73,7 +126,16 @@ async def predict(request: PredictionRequest):
         logger.info("Input tokenized successfully")

         logger.info("Generating output...")
-
+
+        # Generate with final parameters
+        outputs = model.generate(
+            **inputs,
+            **{k: v for k, v in generation_params.items() if k in [
+                'max_length', 'min_length', 'do_sample', 'temperature',
+                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
+                'repetition_penalty', 'early_stopping'
+            ]}
+        )
         logger.info("Output generated successfully")

         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
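A quick way to exercise the new `parameters` field is a request like the sketch below. This is a hypothetical client call, not part of the commit: the host/port and the example dream text are assumptions, while the payload keys mirror the `PredictionRequest` model above. Anything passed under `parameters` takes the highest priority in the merge order, overriding both the model's saved generation config and the hard-coded defaults.

# Hypothetical client call; assumes the app is served locally (e.g. via uvicorn on port 8000).
import requests

payload = {
    "inputs": "I was flying over a city made of glass.",  # example dream text (made up)
    "model": "nidra-v2",
    # Optional per-request overrides; these win over the saved and default generation configs.
    "parameters": {"max_length": 200, "temperature": 0.7},
}

response = requests.post("http://localhost:8000/predict", json=payload, timeout=120)
response.raise_for_status()
print(response.json()["generated_text"])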