Update app.py
app.py CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import logging
 from typing import Optional, Dict, Any
+import os
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -10,6 +11,11 @@ logger = logging.getLogger(__name__)
 
 app = FastAPI(title="Dream Interpretation API")
 
+# Get HF token from environment variable
+HF_TOKEN = os.environ.get("HF_TOKEN")
+if not HF_TOKEN:
+    raise ValueError("HF_TOKEN environment variable must be set")
+
 # Define the model names
 MODELS = {
     "nidra-v1": "m1k3wn/nidra-v1",
@@ -35,9 +41,10 @@ def load_model(model_name: str):
     logger.info(f"Loading {model_name}...")
     try:
         model_path = MODELS[model_name]
-        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_path,
+            token=HF_TOKEN,
             device_map="auto",
             torch_dtype="auto"
         )