arya-ai-model committed · Commit e2116c0 · 1 Parent(s): ad9f174

updated model.py

Files changed (1)
  model.py  +18 -8
model.py CHANGED
@@ -5,7 +5,6 @@ import torch
 MODEL_NAME = "bigcode/starcoderbase-1b"
 HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
-# Force CPU mode
 device = "cpu"
 
 # Load tokenizer and model
@@ -18,23 +17,34 @@ if tokenizer.pad_token is None:
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
     token=HF_TOKEN,
-    torch_dtype=torch.float32,  # Use float32 for CPU
+    torch_dtype=torch.float32,  # Ensure compatibility with CPU
     trust_remote_code=True
-).to(device)  # Move model explicitly to CPU
+).to(device)
 
 def generate_code(prompt: str, max_tokens: int = 256):
+    formatted_prompt = f"# Python\n{prompt}\n\n"  # Ensure the model understands it's code
+
     inputs = tokenizer(
-        prompt,
+        formatted_prompt,
         return_tensors="pt",
         padding=True,
-        truncation=True,  # Allow truncation
-        max_length=1024  # Set a maximum length explicitly
+        truncation=True,
+        max_length=1024  # Explicit max length to prevent issues
     ).to(device)
 
     output = model.generate(
         **inputs,
         max_new_tokens=max_tokens,
-        pad_token_id=tokenizer.pad_token_id
+        pad_token_id=tokenizer.pad_token_id,
+        do_sample=True,   # Enable randomness for better outputs
+        top_p=0.95,       # Nucleus sampling to improve generation
+        temperature=0.7   # Control creativity
     )
 
-    return tokenizer.decode(output[0], skip_special_tokens=True)
+    generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    # Clean the output: remove the repeated prompt at the start
+    if generated_code.startswith(formatted_prompt):
+        generated_code = generated_code[len(formatted_prompt):]
+
+    return generated_code.strip()
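
For context, a minimal usage sketch of the updated generate_code follows; it is not part of the commit. It assumes model.py exposes generate_code exactly as in the diff above and that HUGGINGFACE_TOKEN is set before the module is imported (the module reads it at import time); the placeholder token value and the example prompt are purely illustrative.

# Hypothetical usage sketch (not part of this commit).
# Assumes HUGGINGFACE_TOKEN is available before model.py is imported,
# since the module reads it at import time via os.getenv.
import os

os.environ.setdefault("HUGGINGFACE_TOKEN", "<your-hf-token>")  # placeholder value

from model import generate_code

# Request a short completion; a smaller max_tokens keeps CPU latency manageable.
completion = generate_code("def fibonacci(n):", max_tokens=64)
print(completion)

Because the commit enables sampling (do_sample=True, top_p=0.95, temperature=0.7), repeated calls with the same prompt will generally return different completions.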