Chaitanya Sagar Gurujula commited on
Commit
4810fdb
·
1 Parent(s): e59a831

added initial version

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. src/app.py +13 -3
.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/app.py CHANGED
@@ -10,11 +10,12 @@ from fastapi.templating import Jinja2Templates
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.responses import HTMLResponse
12
  from fastapi.staticfiles import StaticFiles
 
13
 
14
  # Get the absolute path to the templates directory
15
  TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
16
 
17
- MODEL_ID = "sagargurujula/text-generator"
18
 
19
  # Initialize FastAPI
20
  app = FastAPI(title="GPT Text Generator")
@@ -34,13 +35,22 @@ app.add_middleware(
34
  # Set device
35
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
36
 
 
 
 
 
 
 
 
 
37
  # Load model from Hugging Face Hub
38
  def load_model():
39
  try:
40
- # Download the model file from HF Hub
41
  model_path = hf_hub_download(
42
  repo_id=MODEL_ID,
43
- filename="best_model.pth"
 
44
  )
45
 
46
  # Initialize our custom GPT model
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.responses import HTMLResponse
12
  from fastapi.staticfiles import StaticFiles
13
+ from pathlib import Path
14
 
15
  # Get the absolute path to the templates directory
16
  TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
17
 
18
+ MODEL_ID = "sagargurujala/text-generator"
19
 
20
  # Initialize FastAPI
21
  app = FastAPI(title="GPT Text Generator")
 
35
  # Set device
36
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
37
 
38
+ # Create cache directory in the current working directory
39
+ cache_dir = Path("model_cache")
40
+ cache_dir.mkdir(exist_ok=True)
41
+
42
+ # Set environment variable for Hugging Face cache
43
+ os.environ['TRANSFORMERS_CACHE'] = str(cache_dir)
44
+ os.environ['HF_HOME'] = str(cache_dir)
45
+
46
  # Load model from Hugging Face Hub
47
  def load_model():
48
  try:
49
+ # Download the model file from HF Hub with custom cache directory
50
  model_path = hf_hub_download(
51
  repo_id=MODEL_ID,
52
+ filename="best_model.pth",
53
+ cache_dir=cache_dir
54
  )
55
 
56
  # Initialize our custom GPT model