m1k3wn committed · Commit 5fc0c7a · verified · 1 Parent(s): 78a09b4

Update app.py

Files changed (1):
  1. app.py +27 -24
app.py CHANGED
@@ -1,57 +1,60 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import T5Tokenizer, T5ForConditionalGeneration  # Note: Using specific T5 classes
 import logging
 import os
-import torch
 
-# Set up logging
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
-# Initialize FastAPI
 app = FastAPI()
-
-# Get HF token
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
-# Define models
 MODELS = {
     "nidra-v1": "m1k3wn/nidra-v1",
     "nidra-v2": "m1k3wn/nidra-v2"
 }
 
-# Simple request model
 class PredictionRequest(BaseModel):
     inputs: str
     model: str = "nidra-v1"
 
-# Simple response model
 class PredictionResponse(BaseModel):
     generated_text: str
 
-@app.get("/")
-async def root():
-    return {"message": "Dream Interpretation API", "status": "running"}
-
-@app.get("/health")
-async def health():
-    return {"status": "healthy"}
-
 @app.post("/predict", response_model=PredictionResponse)
 async def predict(request: PredictionRequest):
     try:
-        # Load model
+        logger.info(f"Loading model: {request.model}")
         model_path = MODELS[request.model]
-        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=HF_TOKEN)
 
-        # Process input
+        # Use T5-specific classes instead of Auto classes
+        tokenizer = T5Tokenizer.from_pretrained(
+            model_path,
+            token=HF_TOKEN,
+            legacy=True  # Try with legacy mode first
+        )
+
+        model = T5ForConditionalGeneration.from_pretrained(
+            model_path,
+            token=HF_TOKEN
+        )
+
         full_input = "Interpret this dream: " + request.inputs
+        logger.info(f"Processing: {full_input}")
+
+        inputs = tokenizer(
+            full_input,
+            return_tensors="pt",
+            truncation=True,
+            max_length=512
+        )
+
+        outputs = model.generate(**inputs, max_length=200)
         result = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
         return PredictionResponse(generated_text=result)
+
     except Exception as e:
+        logger.error(f"Error: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
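
For reference, a minimal sketch of how a client might exercise the updated /predict endpoint. It is not part of the commit: it assumes the app is served locally (e.g. with `uvicorn app:app --port 8000`) and that the `requests` package is installed; the host, port, and sample dream text are illustrative.

# Hypothetical client call against a local run of this app
# (assumed to be started with: uvicorn app:app --port 8000).
import requests

response = requests.post(
    "http://localhost:8000/predict",  # assumed local host/port
    json={"inputs": "I was flying over my childhood home", "model": "nidra-v1"},
    timeout=120,  # the handler loads the model inside the request, so allow plenty of time
)
response.raise_for_status()
print(response.json()["generated_text"])

Note that the handler still loads the tokenizer and model on every request; loading them once at startup (or caching per model name) would avoid repeating that cost on each call.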