1inkusFace committed on
Commit
7665895
·
verified ·
1 Parent(s): 4239ac9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -10
app.py CHANGED
@@ -15,31 +15,26 @@ model = AutoModelForCausalLM.from_pretrained(
15
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) # Add this line for Qwen models
17
 
18
-
19
  @spaces.GPU(required=True)
20
  def generate_code(prompt):
21
  messages = [
22
  {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
23
  {"role": "user", "content": prompt}
24
  ]
25
-
26
  text = tokenizer.apply_chat_template(
27
  messages,
28
  tokenize=False,
29
  add_generation_prompt=True
30
  )
31
-
32
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
33
-
34
- generated_ids = model.generate(
35
- **model_inputs,
36
- max_new_tokens=512
37
- )
38
-
39
  generated_ids = [
40
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
41
  ]
42
-
43
  response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
44
  return response
45
 
 
15
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) # Add this line for Qwen models
17
 
 
18
  @spaces.GPU(required=True)
19
  def generate_code(prompt):
20
  messages = [
21
  {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
22
  {"role": "user", "content": prompt}
23
  ]
 
24
  text = tokenizer.apply_chat_template(
25
  messages,
26
  tokenize=False,
27
  add_generation_prompt=True
28
  )
 
29
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
30
+ with torch.no_grad():
31
+ generated_ids = model.generate(
32
+ **model_inputs,
33
+ max_new_tokens=1024
34
+ )
 
35
  generated_ids = [
36
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
37
  ]
 
38
  response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
39
  return response
40