File size: 1,583 Bytes
d9ded01
b2c1928
ea0af80
 
812fce0
d9ded01
 
8374669
d36dc81
ba41c7f
d9ded01
ba41c7f
812fce0
ba41c7f
812fce0
ba41c7f
d36dc81
 
 
812fce0
eb66cb5
e2116c0
ea0af80
 
812fce0
e2116c0
ad9f174
e2116c0
ad9f174
812fce0
e2116c0
812fce0
ad9f174
 
 
 
 
e2116c0
812fce0
 
 
ad9f174
 
e2116c0
 
812fce0
e2116c0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_NAME = "bigcode/starcoderbase-3b"
# Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# The model and every input tensor are placed on this device.
device = "cpu"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)

# The tokenizer may ship without a pad token; generate_code passes
# tokenizer.pad_token_id to generate(), so fall back to the end-of-sequence
# token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# trust_remote_code is intentionally left at its default (False): enabling it
# lets the model repository run arbitrary Python on this machine at load time.
# NOTE(review): assumes the installed transformers release supports this
# checkpoint's architecture natively (StarCoderBase is a GPTBigCode model) —
# confirm against the pinned transformers version.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    torch_dtype=torch.float32,  # full precision for CPU compatibility
).to(device)

def generate_code(
    prompt: str,
    max_tokens: int = 256,
    *,
    max_prompt_tokens: int = 1024,
) -> str:
    """Generate a code continuation for *prompt* with the module-level model.

    The prompt is wrapped as ``"# Python\\n{prompt}\\n\\n"`` to cue the model
    that Python code should follow, then sampled with nucleus sampling
    (``top_p=0.95``, ``temperature=0.7``), so results are non-deterministic.

    Args:
        prompt: Text the model should continue from.
        max_tokens: Maximum number of new tokens to generate.
        max_prompt_tokens: Maximum number of prompt tokens fed to the model.
            Longer prompts keep only their *last* ``max_prompt_tokens``
            tokens, so the text nearest the generation point is preserved.

    Returns:
        Only the newly generated text (the prompt is excluded), with
        surrounding whitespace stripped.

    Raises:
        ValueError: If ``max_tokens`` or ``max_prompt_tokens`` is below 1.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    if max_prompt_tokens < 1:
        # Without this guard, a slice of ``-0:`` would silently keep everything.
        raise ValueError(f"max_prompt_tokens must be >= 1, got {max_prompt_tokens}")

    formatted_prompt = f"# Python\n{prompt}\n\n"  # cue that code follows

    encoded = tokenizer(formatted_prompt, return_tensors="pt")
    # Left-truncate: for a causal LM the end of the prompt is what the model
    # continues from, so drop tokens from the start rather than the end.
    input_ids = encoded["input_ids"][:, -max_prompt_tokens:].to(device)
    attention_mask = encoded["attention_mask"][:, -max_prompt_tokens:].to(device)

    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,  # sample instead of greedy decoding
        top_p=0.95,  # nucleus sampling
        temperature=0.7,
    )

    # generate() returns prompt + continuation; decode only the new tokens.
    # Comparing decoded strings is unreliable because a decode round-trip
    # need not reproduce the prompt text exactly.
    new_tokens = output[0, input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()