hackergeek commited on
Commit
2b35f7d
·
verified ·
1 Parent(s): d2144a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
+ from peft import PeftModel
4
+
5
+ class DeepSeekLoraCPUInference:
6
+ def __init__(self, base_model="deepseek-ai/deepseek-r1", fine_tuned_model="./deepseek_lora_finetuned"):
7
+ self.tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model)
8
+
9
+ # Load model in 4-bit on CPU (if no GPU is available)
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ quant_config = BitsAndBytesConfig(
12
+ load_in_4bit=True if device == "cuda" else False, # Use 4-bit only if GPU is available
13
+ bnb_4bit_compute_dtype=torch.bfloat16,
14
+ bnb_4bit_quant_type="nf4",
15
+ bnb_4bit_use_double_quant=True
16
+ )
17
+
18
+ self.model = AutoModelForCausalLM.from_pretrained(
19
+ base_model,
20
+ quantization_config=quant_config if device == "cuda" else None,
21
+ device_map=device
22
+ )
23
+
24
+ # Load fine-tuned LoRA model
25
+ self.model = PeftModel.from_pretrained(self.model, fine_tuned_model)
26
+ self.model.to(device)
27
+ self.model.eval()
28
+
29
+ def generate_text(self, prompt, max_length=200):
30
+ """Generates text efficiently using CPU or GPU."""
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(device)
33
+
34
+ with torch.no_grad():
35
+ output = self.model.generate(
36
+ **inputs,
37
+ max_length=max_length,
38
+ temperature=0.7,
39
+ top_p=0.9,
40
+ repetition_penalty=1.1
41
+ )
42
+
43
+ return self.tokenizer.decode(output[0], skip_special_tokens=True)
44
+
45
+ if __name__ == "__main__":
46
+ model = DeepSeekLoraCPUInference()
47
+
48
+ prompt = "The implications of AI in the next decade are"
49
+ generated_text = model.generate_text(prompt)
50
+
51
+ print("\nGenerated Text:\n", generated_text)