m1k3wn committed on
Commit
2580a1e
·
verified ·
1 Parent(s): a4b1bdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -1,18 +1,18 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import T5Tokenizer, T5ForConditionalGeneration # Changed to specific classes
4
  import logging
5
  import os
6
  import sys
7
 
8
- @app.get("/version")
9
- async def version():
10
- return {"python_version": sys.version}
11
 
 
12
  logging.basicConfig(level=logging.DEBUG)
13
  logger = logging.getLogger(__name__)
14
 
15
- app = FastAPI()
16
  HF_TOKEN = os.environ.get("HF_TOKEN")
17
 
18
  MODELS = {
@@ -27,6 +27,14 @@ class PredictionRequest(BaseModel):
27
  class PredictionResponse(BaseModel):
28
  generated_text: str
29
 
 
 
 
 
 
 
 
 
30
  @app.post("/predict", response_model=PredictionResponse)
31
  async def predict(request: PredictionRequest):
32
  try:
@@ -38,7 +46,7 @@ async def predict(request: PredictionRequest):
38
  tokenizer = T5Tokenizer.from_pretrained(
39
  model_path,
40
  token=HF_TOKEN,
41
- local_files_only=False, # Force download if needed
42
  return_special_tokens_mask=True
43
  )
44
  logger.info("Tokenizer loaded successfully")
@@ -47,7 +55,7 @@ async def predict(request: PredictionRequest):
47
  model = T5ForConditionalGeneration.from_pretrained(
48
  model_path,
49
  token=HF_TOKEN,
50
- local_files_only=False # Force download if needed
51
  )
52
  logger.info("Model loaded successfully")
53
 
@@ -76,11 +84,6 @@ async def predict(request: PredictionRequest):
76
  except Exception as e:
77
  logger.error(f"Error: {str(e)}")
78
  logger.error(f"Error type: {type(e)}")
79
- # Log the full traceback
80
  import traceback
81
  logger.error(f"Traceback: {traceback.format_exc()}")
82
- raise HTTPException(status_code=500, detail=str(e))
83
-
84
- @app.get("/health")
85
- async def health():
86
- return {"status": "healthy"}
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  import logging
5
  import os
6
  import sys
7
 
8
# Create the FastAPI application first so route decorators below can attach to it.
app = FastAPI()

# Verbose logging to surface tokenizer/model download and load problems.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Hugging Face access token from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
17
 
18
  MODELS = {
 
27
class PredictionResponse(BaseModel):
    """Response schema for /predict: wraps the model's output text."""

    # Text produced by the T5 model for the submitted input.
    generated_text: str
29
 
30
@app.get("/version")
async def version():
    """Report the running interpreter's version string (deployment debugging aid)."""
    payload = {"python_version": sys.version}
    return payload
33
+
34
@app.get("/health")
async def health():
    """Liveness probe: unconditionally reports the service as healthy."""
    status_payload = {"status": "healthy"}
    return status_payload
37
+
38
  @app.post("/predict", response_model=PredictionResponse)
39
  async def predict(request: PredictionRequest):
40
  try:
 
46
  tokenizer = T5Tokenizer.from_pretrained(
47
  model_path,
48
  token=HF_TOKEN,
49
+ local_files_only=False,
50
  return_special_tokens_mask=True
51
  )
52
  logger.info("Tokenizer loaded successfully")
 
55
  model = T5ForConditionalGeneration.from_pretrained(
56
  model_path,
57
  token=HF_TOKEN,
58
+ local_files_only=False
59
  )
60
  logger.info("Model loaded successfully")
61
 
 
84
  except Exception as e:
85
  logger.error(f"Error: {str(e)}")
86
  logger.error(f"Error type: {type(e)}")
 
87
  import traceback
88
  logger.error(f"Traceback: {traceback.format_exc()}")
89
+ raise HTTPException(status_code=500, detail=str(e))