Update app.py
Optimized inference with FlashAttention (FA) and KV cache.
app.py CHANGED
@@ -10,8 +10,17 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 # Try loading the model with explicit error handling
 try:
-    veri_model = AutoModelForCausalLM.from_pretrained(veri_model_path, device_map="auto", torch_dtype="auto")
     veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
+
+    veri_model = AutoModelForCausalLM.from_pretrained(
+        veri_model_path,
+        device_map="auto",
+        torch_dtype="auto",
+        trust_remote_code=True,
+        use_cache=True,  # Enable KV caching
+        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
+    )
+
 except Exception as e:
     print(f"Model loading error: {e}")
     veri_model = None
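Note on the attn_implementation line above: requesting flash_attention_2 also requires the flash-attn package to be installed in the Space's environment, otherwise from_pretrained raises at load time and the except branch leaves veri_model as None. A minimal sketch of a more defensive selection, assuming a fallback to PyTorch's built-in SDPA attention; the checkpoint name is a placeholder, not the Space's veri_model_path:

# Reviewer sketch, not part of this commit: only request flash_attention_2 when the
# flash-attn package is importable, and fall back to PyTorch SDPA otherwise.
import importlib.util

import torch
from transformers import AutoModelForCausalLM

def pick_attn_implementation() -> str:
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"  # fused attention kernels on supported GPUs
    return "sdpa"  # PyTorch scaled_dot_product_attention, needs no extra install

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder checkpoint, stands in for veri_model_path
    device_map="auto",
    torch_dtype="auto",
    attn_implementation=pick_attn_implementation(),
)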
@@ -26,27 +35,46 @@ def generate_response(user_message, history):
         return history
 
     # Simple generation without streaming first
-    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud
+    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud. If you are asked a Verilog question, make sure your input and output interface has the same names as described in the question. If you are asked to generate code, please start your Verilog code with CODE BEGIN and end with CODE END."
 
-    # Create conversation history
     conversation = f"System: {system_message}\n"
-    for h in history:
+    recent_history = history[-3:] if len(history) > 3 else history
+
+    for h in recent_history:
         conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
     conversation += f"User: {user_message}\nAssistant:"
 
-    inputs = veri_tokenizer(
+    inputs = veri_tokenizer(
+        conversation,
+        return_tensors="pt",
+        truncation=True,
+        max_length=2048,
+        padding=True,
+        return_attention_mask=True
+    ).to(device)
 
     with torch.no_grad():
         outputs = veri_model.generate(
             **inputs,
             max_new_tokens=4096,
-            temperature=0.
+            temperature=0.7,
             top_p=0.95,
             do_sample=True,
+            top_k=50,  # Top-k sampling for efficiency
+            # pad_token_id=veri_tokenizer.eos_token_id,
+            # eos_token_id=veri_tokenizer.eos_token_id,
+            use_cache=True,  # Enable KV caching for faster generation
+            repetition_penalty=1.1,  # Reduce repetition
+            length_penalty=1.0,
+            early_stopping=True,  # Stop early when appropriate
+            num_beams=1,  # Greedy search for speed
             pad_token_id=veri_tokenizer.eos_token_id
         )
 
     response = veri_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     # Return updated history
     return history + [[user_message, response.strip()]]
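The prompt above is assembled by hand from System:/User:/Assistant: strings. Since the model is described as a Qwen fine-tune, the tokenizer may ship a chat template that matches the format the model was trained on; if so, apply_chat_template is the more idiomatic way to build the prompt. A sketch under that assumption, keeping the same last-three-turns window as the diff:

# Sketch only: build the prompt with the tokenizer's chat template instead of manual
# "System:/User:/Assistant:" strings. Assumes veri_tokenizer actually carries a chat
# template (true for stock Qwen tokenizers; worth verifying for this fine-tune).
def build_inputs(user_message, history, system_message):
    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history[-3:]:  # same three-turn window as the diff
        messages.append({"role": "user", "content": user_turn})
        messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": user_message})

    prompt = veri_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return veri_tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=2048
    ).to(device)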
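Two small caveats on the generate() call: use_cache=True is already the default for generation, and length_penalty and early_stopping only apply to beam search, so with num_beams=1 and do_sample=True this call performs sampling rather than greedy search. To confirm the speedup the commit message claims from the KV cache, a rough check is to time the same prompt with the cache on and off; the checkpoint below is a placeholder for the Space's model:

# Rough benchmark sketch (placeholder checkpoint, not the Space's veri_model_path):
# time one generation with the KV cache enabled and one with it disabled.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", torch_dtype="auto")

inputs = tok("Write a 4-bit up counter in Verilog.", return_tensors="pt").to(model.device)

for cached in (True, False):
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=128, do_sample=False,
                       use_cache=cached, pad_token_id=tok.eos_token_id)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"use_cache={cached}: {time.perf_counter() - start:.2f}s")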