chegde committed on
Commit ab95480 · verified · 1 Parent(s): e4b2909

Update app.py


Optimized inference with Flash Attention (FA) and KV caching.

Files changed (1)
  1. app.py +34 -6
app.py CHANGED
@@ -10,8 +10,17 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 # Try loading the model with explicit error handling
 try:
-    veri_model = AutoModelForCausalLM.from_pretrained(veri_model_path, device_map="auto", torch_dtype="auto")
     veri_tokenizer = AutoTokenizer.from_pretrained(veri_model_path)
+
+    veri_model = AutoModelForCausalLM.from_pretrained(
+        veri_model_path,
+        device_map="auto",
+        torch_dtype="auto",
+        trust_remote_code=True,
+        use_cache=True,  # Enable KV caching
+        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
+    )
+
 except Exception as e:
     print(f"Model loading error: {e}")
     veri_model = None
@@ -26,27 +35,46 @@ def generate_response(user_message, history):
         return history
 
     # Simple generation without streaming first
-    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud, to answer Verilog coding questions. Make sure your input and output interface has the same names as described in the question. Please start your Verilog code with CODE BEGIN and end with CODE END."
+    system_message = "You are VeriThoughts, a helpful assistant that thinks step by step. You are finetuned from a Qwen model, created by Alibaba Cloud. If you are asked a Verilog question, make sure your input and output interface has the same names as described in the question. If you are asked to generate code, please start your Verilog code with CODE BEGIN and end with CODE END."
 
-    # Create conversation history
     conversation = f"System: {system_message}\n"
-    for h in history:
+    recent_history = history[-3:] if len(history) > 3 else history
+
+    for h in recent_history:
         conversation += f"User: {h[0]}\nAssistant: {h[1]}\n"
     conversation += f"User: {user_message}\nAssistant:"
 
-    inputs = veri_tokenizer(conversation, return_tensors="pt", truncation=True, max_length=2048).to(device)
+    inputs = veri_tokenizer(
+        conversation,
+        return_tensors="pt",
+        truncation=True,
+        max_length=2048,
+        padding=True,
+        return_attention_mask=True
+    ).to(device)
 
     with torch.no_grad():
         outputs = veri_model.generate(
             **inputs,
             max_new_tokens=4096,
-            temperature=0.6,
+            temperature=0.7,
             top_p=0.95,
             do_sample=True,
+            top_k=50,  # Top-k sampling for efficiency
+            # pad_token_id=veri_tokenizer.eos_token_id,
+            # eos_token_id=veri_tokenizer.eos_token_id,
+            use_cache=True,  # Enable KV caching for faster generation
+            repetition_penalty=1.1,  # Reduce repetition
+            length_penalty=1.0,
+            early_stopping=True,  # Only affects beam search; a no-op with num_beams=1
+            num_beams=1,  # Single beam (no beam search) for speed
             pad_token_id=veri_tokenizer.eos_token_id
         )
 
     response = veri_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     # Return updated history
     return history + [[user_message, response.strip()]]
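
The attn_implementation choice in this commit assumes the flash-attn package is installed on every CUDA machine; if it is not, from_pretrained raises an error at load time, the try/except catches it, and veri_model is silently left as None. Below is a minimal sketch (not part of the commit) of a more defensive backend choice, assuming a transformers release new enough to accept "sdpa" (roughly 4.36+); the helper pick_attn_implementation and the placeholder model path are illustrative names, not from app.py.

import importlib.util

import torch
from transformers import AutoModelForCausalLM

veri_model_path = "CHANGE_ME"  # placeholder; app.py defines the real checkpoint


def pick_attn_implementation() -> str:
    """Illustrative helper: prefer FlashAttention-2, otherwise fall back to PyTorch SDPA."""
    # find_spec only checks that flash-attn is importable; it does not verify that the
    # GPU generation or dtype is supported, so load errors are still possible on old GPUs.
    if torch.cuda.is_available() and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"  # scaled-dot-product attention, built into PyTorch 2.x, works on CPU and GPU


veri_model = AutoModelForCausalLM.from_pretrained(
    veri_model_path,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    use_cache=True,                                  # KV cache on, matching the commit
    attn_implementation=pick_attn_implementation(),  # never requests an uninstalled backend
)

Note that in recent transformers releases use_cache=True is already the default for generate(), so most of the measured speedup here should come from FlashAttention-2 during prefill and from trimming the prompt to the last three turns of history.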