m1k3wn committed · verified
Commit 5c94eeb · Parent(s): 2580a1e

Update app.py


debugging generation configs

Files changed (1)
  1. app.py +64 -2
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import T5Tokenizer, T5ForConditionalGeneration
+from transformers import T5Tokenizer, T5ForConditionalGeneration, GenerationConfig
 import logging
 import os
 import sys
@@ -20,9 +20,37 @@ MODELS = {
     "nidra-v2": "m1k3wn/nidra-v2"
 }
 
+# Define default generation configurations for each model
+DEFAULT_GENERATION_CONFIGS = {
+    "nidra-v1": {
+        "max_length": 300,
+        "min_length": 150,
+        "num_beams": 8,
+        "temperature": 0.55,
+        "do_sample": True,
+        "top_p": 0.95,
+        "repetition_penalty": 4.5,
+        "no_repeat_ngram_size": 4,
+        "early_stopping": True,
+        "length_penalty": 1.2,
+    },
+    "nidra-v2": {
+        "max_length": 300,
+        "min_length": 150,
+        "num_beams": 8,
+        "temperature": 0.4,
+        "do_sample": True,
+        "top_p": 0.95,
+        "repetition_penalty": 3.5,
+        "no_repeat_ngram_size": 4,
+        "early_stopping": True,
+        "length_penalty": 1.2,
+    }
+}
 class PredictionRequest(BaseModel):
     inputs: str
     model: str = "nidra-v1"
+    parameters: Optional[Dict[str, Any]] = None  # Allow custom parameters
 
 class PredictionResponse(BaseModel):
     generated_text: str
@@ -38,6 +66,10 @@ async def health():
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
     try:
+        # Validate model
+        if request.model not in MODELS:
+            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")
+
         logger.info(f"Loading model: {request.model}")
         model_path = MODELS[request.model]
 
@@ -58,6 +90,27 @@ async def predict(request: PredictionRequest):
             local_files_only=False
         )
         logger.info("Model loaded successfully")
+
+        # Priority: 1. Request parameters, 2. Model's saved generation_config, 3. Default configs
+        generation_params = DEFAULT_GENERATION_CONFIGS[request.model].copy()
+
+        # Try to load model's saved generation config
+        try:
+            model_generation_config = GenerationConfig.from_pretrained(model_path)
+            # Convert to dict to merge with default configs
+            generation_params.update({
+                k: v for k, v in model_generation_config.to_dict().items()
+                if v is not None
+            })
+        except Exception as config_load_error:
+            logger.warning(f"Could not load model's generation config: {config_load_error}")
+
+        # Override with request-specific parameters if provided
+        if request.parameters:
+            generation_params.update(request.parameters)
+
+        logger.info(f"Final Generation Parameters: {generation_params}")
+
 
         full_input = "Interpret this dream: " + request.inputs
         logger.info(f"Processing input: {full_input}")
@@ -73,7 +126,16 @@ async def predict(request: PredictionRequest):
         logger.info("Input tokenized successfully")
 
         logger.info("Generating output...")
-        outputs = model.generate(**inputs, max_length=200)
+
+        # Generate with final parameters
+        outputs = model.generate(
+            **inputs,
+            **{k: v for k, v in generation_params.items() if k in [
+                'max_length', 'min_length', 'do_sample', 'temperature',
+                'top_p', 'top_k', 'num_beams', 'no_repeat_ngram_size',
+                'repetition_penalty', 'early_stopping'
+            ]}
+        )
         logger.info("Output generated successfully")
 
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
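
With this change a caller can override the server-side defaults per request through the new parameters field (note that Optional, Dict, and Any come from typing and would need to be imported in app.py if they are not already). A minimal client sketch, assuming the API is reachable at a placeholder URL:

# Example request against the updated /predict endpoint (sketch, not part of the commit).
# The base URL is a placeholder; point it at the actual deployment.
import requests

payload = {
    "inputs": "I was flying over a city made of glass.",
    "model": "nidra-v2",
    # Optional overrides; these take priority over the model's saved
    # generation_config and the server-side defaults.
    "parameters": {
        "max_length": 200,
        "temperature": 0.7,
    },
}

response = requests.post("http://localhost:8000/predict", json=payload, timeout=120)
response.raise_for_status()
print(response.json()["generated_text"])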