m1k3wn commited on
Commit
d1a2d5d
·
verified ·
1 Parent(s): 5fc0c7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -7
app.py CHANGED
@@ -1,8 +1,9 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration # Note: Using specific T5 classes
4
  import logging
5
  import os
 
6
 
7
  logging.basicConfig(level=logging.DEBUG)
8
  logger = logging.getLogger(__name__)
@@ -15,6 +16,17 @@ MODELS = {
15
  "nidra-v2": "m1k3wn/nidra-v2"
16
  }
17
 
 
 
 
 
 
 
 
 
 
 
 
18
  class PredictionRequest(BaseModel):
19
  inputs: str
20
  model: str = "nidra-v1"
@@ -28,33 +40,56 @@ async def predict(request: PredictionRequest):
28
  logger.info(f"Loading model: {request.model}")
29
  model_path = MODELS[request.model]
30
 
31
- # Use T5-specific classes instead of Auto classes
32
  tokenizer = T5Tokenizer.from_pretrained(
33
  model_path,
34
  token=HF_TOKEN,
35
- legacy=True # Try with legacy mode first
 
 
 
 
 
 
36
  )
37
 
38
  model = T5ForConditionalGeneration.from_pretrained(
39
  model_path,
40
- token=HF_TOKEN
 
41
  )
42
 
43
  full_input = "Interpret this dream: " + request.inputs
44
  logger.info(f"Processing: {full_input}")
45
 
 
46
  inputs = tokenizer(
47
  full_input,
48
  return_tensors="pt",
49
  truncation=True,
50
- max_length=512
 
 
 
 
 
 
 
 
 
 
51
  )
52
 
53
- outputs = model.generate(**inputs, max_length=200)
54
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
55
 
56
  return PredictionResponse(generated_text=result)
57
 
58
  except Exception as e:
59
  logger.error(f"Error: {str(e)}")
60
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  import logging
5
  import os
6
+ import json
7
 
8
  logging.basicConfig(level=logging.DEBUG)
9
  logger = logging.getLogger(__name__)
 
16
  "nidra-v2": "m1k3wn/nidra-v2"
17
  }
18
 
19
+ # Define the tokenizer configuration explicitly
20
+ TOKENIZER_CONFIG = {
21
+ "model_max_length": 512,
22
+ "clean_up_tokenization_spaces": False,
23
+ "tokenizer_class": "T5Tokenizer",
24
+ "pad_token": "<pad>",
25
+ "eos_token": "</s>",
26
+ "unk_token": "<unk>",
27
+ "extra_ids": 100
28
+ }
29
+
30
  class PredictionRequest(BaseModel):
31
  inputs: str
32
  model: str = "nidra-v1"
 
40
  logger.info(f"Loading model: {request.model}")
41
  model_path = MODELS[request.model]
42
 
43
+ # Initialize tokenizer with explicit config
44
  tokenizer = T5Tokenizer.from_pretrained(
45
  model_path,
46
  token=HF_TOKEN,
47
+ model_max_length=TOKENIZER_CONFIG["model_max_length"],
48
+ clean_up_tokenization_spaces=TOKENIZER_CONFIG["clean_up_tokenization_spaces"],
49
+ pad_token=TOKENIZER_CONFIG["pad_token"],
50
+ eos_token=TOKENIZER_CONFIG["eos_token"],
51
+ unk_token=TOKENIZER_CONFIG["unk_token"],
52
+ extra_ids=TOKENIZER_CONFIG["extra_ids"],
53
+ use_fast=True # Try forcing the fast tokenizer
54
  )
55
 
56
  model = T5ForConditionalGeneration.from_pretrained(
57
  model_path,
58
+ token=HF_TOKEN,
59
+ torch_dtype="auto"
60
  )
61
 
62
  full_input = "Interpret this dream: " + request.inputs
63
  logger.info(f"Processing: {full_input}")
64
 
65
+ # Add explicit encoding parameters
66
  inputs = tokenizer(
67
  full_input,
68
  return_tensors="pt",
69
  truncation=True,
70
+ max_length=512,
71
+ padding=True,
72
+ add_special_tokens=True
73
+ )
74
+
75
+ outputs = model.generate(
76
+ **inputs,
77
+ max_length=200,
78
+ num_beams=4,
79
+ no_repeat_ngram_size=2,
80
+ length_penalty=1.0
81
  )
82
 
 
83
  result = tokenizer.decode(outputs[0], skip_special_tokens=True)
84
+ logger.info(f"Generated result: {result}")
85
 
86
  return PredictionResponse(generated_text=result)
87
 
88
  except Exception as e:
89
  logger.error(f"Error: {str(e)}")
90
+ raise HTTPException(status_code=500, detail=str(e))
91
+
92
+ # Add health check endpoint
93
+ @app.get("/health")
94
+ async def health():
95
+ return {"status": "healthy"}