m1k3wn commited on
Commit
7394c77
·
verified ·
1 Parent(s): 2095fff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import logging
5
  from typing import Optional, Dict, Any
 
6
 
7
  # Set up logging
8
  logging.basicConfig(level=logging.INFO)
@@ -10,6 +11,11 @@ logger = logging.getLogger(__name__)
10
 
11
  app = FastAPI(title="Dream Interpretation API")
12
 
 
 
 
 
 
13
  # Define the model names
14
  MODELS = {
15
  "nidra-v1": "m1k3wn/nidra-v1",
@@ -35,9 +41,10 @@ def load_model(model_name: str):
35
  logger.info(f"Loading {model_name}...")
36
  try:
37
  model_path = MODELS[model_name]
38
- tokenizer = AutoTokenizer.from_pretrained(model_path)
39
  model = AutoModelForSeq2SeqLM.from_pretrained(
40
  model_path,
 
41
  device_map="auto",
42
  torch_dtype="auto"
43
  )
 
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import logging
5
  from typing import Optional, Dict, Any
6
+ import os
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
 
11
 
12
  app = FastAPI(title="Dream Interpretation API")
13
 
14
+ # Get HF token from environment variable
15
+ HF_TOKEN = os.environ.get("HF_TOKEN")
16
+ if not HF_TOKEN:
17
+ raise ValueError("HF_TOKEN environment variable must be set")
18
+
19
  # Define the model names
20
  MODELS = {
21
  "nidra-v1": "m1k3wn/nidra-v1",
 
41
  logger.info(f"Loading {model_name}...")
42
  try:
43
  model_path = MODELS[model_name]
44
+ tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
45
  model = AutoModelForSeq2SeqLM.from_pretrained(
46
  model_path,
47
+ token=HF_TOKEN,
48
  device_map="auto",
49
  torch_dtype="auto"
50
  )