MienOlle commited on
Commit
95f8ea9
·
1 Parent(s): cd86ae0

More fixes to HF HOME cache

Browse files
Files changed (1) hide show
  1. main.py +8 -4
main.py CHANGED
@@ -6,22 +6,26 @@ from pydantic import BaseModel
6
  import os
7
 
8
  app = FastAPI()
9
- os.makedirs("/app/cache", exist_ok=True)
 
10
 
11
  model_path = hf_hub_download(repo_id="MienOlle/sentiment_analysis_api",
12
  filename="sentimentAnalysis.pth",
13
- cache_dir="/app/cache"
14
  )
15
  modelToken = token.from_pretrained("mdhugol/indonesia-bert-sentiment-classification")
16
  model = modelSC.from_pretrained("mdhugol/indonesia-bert-sentiment-classification", num_labels=3)
17
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
 
 
 
18
  model.eval()
19
 
20
  class TextInput(BaseModel):
21
  text: str
22
 
23
  def predict(input):
24
- inputs = modelToken(input, return_tensors="pt", padding=True, truncation=True, max_length=512)
25
 
26
  with torch.no_grad():
27
  outputs = model(**inputs)
 
6
  import os
7
 
8
  app = FastAPI()
9
+ os.environ["HF_HOME"] = "/tmp/huggingface"
10
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
11
 
12
  model_path = hf_hub_download(repo_id="MienOlle/sentiment_analysis_api",
13
  filename="sentimentAnalysis.pth",
14
+ cache_dir=os.environ["HF_HOME"]
15
  )
16
  modelToken = token.from_pretrained("mdhugol/indonesia-bert-sentiment-classification")
17
  model = modelSC.from_pretrained("mdhugol/indonesia-bert-sentiment-classification", num_labels=3)
18
+
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
21
+ model.to(device)
22
  model.eval()
23
 
24
  class TextInput(BaseModel):
25
  text: str
26
 
27
  def predict(input):
28
+ inputs = modelToken(input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
29
 
30
  with torch.no_grad():
31
  outputs = model(**inputs)