noddysnots committed
Commit 1e69485 · verified · 1 Parent(s): 8e593da

Update app.py

Files changed (1)
  1. app.py +17 -12
app.py CHANGED
@@ -3,22 +3,27 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import torch
 import requests
 
+# Ensure flash_attn is installed before loading the model
+try:
+    import flash_attn
+except ImportError:
+    raise RuntimeError("Missing required dependency: flash_attn. Install with `pip install flash-attn`")
+
 # Load DeepSeek-R1 model with trust_remote_code enabled
 model_name = "deepseek-ai/DeepSeek-R1"
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
-# Ensure compatibility with `flash_attn` and force proper dtype
-try:
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,  # Forces float16 to prevent fp8 issue
-        device_map="auto",
-        trust_remote_code=True
-    )
-except ImportError as e:
-    raise RuntimeError("Missing required dependency: flash_attn. Install with `pip install flash_attn`") from e
-
-# Use a text-generation pipeline for better inference
+# Ensure compatibility and force execution on GPU if available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # Prevents fp8 errors
+    device_map="auto",
+    trust_remote_code=True
+)
+
+# Use a text-generation pipeline
 generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
 
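The commit makes flash_attn a hard requirement: app.py now raises a RuntimeError at startup if the import fails. As a point of comparison, here is a minimal sketch (not part of the commit) of a softer policy: probe for flash_attn and fall back to PyTorch's built-in SDPA attention through the `attn_implementation` argument that recent transformers releases accept in `from_pretrained`. The backend names and the version assumption are mine, not the author's.

# Sketch only; assumes a transformers version (>= 4.36) that accepts
# `attn_implementation`, and falls back instead of raising when flash_attn
# is absent.
import importlib.util

import torch
from transformers import AutoModelForCausalLM

# Use FlashAttention 2 when the package is importable, otherwise SDPA.
attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation=attn_impl,
)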
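For completeness, a quick smoke test of the resulting pipeline. One caveat: because the model is loaded with `device_map="auto"` (via accelerate), also passing `device=` to `pipeline` as app.py does may be rejected by transformers, so this sketch omits it. The prompt and generation parameters are illustrative, not from the repo.

# Illustrative smoke test; assumes app.py's model and tokenizer loaded
# successfully and that the hardware can actually host DeepSeek-R1.
from transformers import pipeline

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

outputs = generator(
    "Explain what flash attention does, in one sentence.",
    max_new_tokens=64,  # keep the check short
    do_sample=False,    # deterministic output for a reproducible check
)
print(outputs[0]["generated_text"])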