import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
REPO = "mrm8488/falcoder-7b"  # checkpoint is loaded below, once token and device are known
description = """# <h1 style="text-align: center; color: white;"><span style="color: #F26207;">Code Completion with falcoder-7b</span></h1>
<span style="color: white; text-align: center;">Enter a prompt and click the button to let falcoder-7b complete your code.</span>"""
token = os.environ["HUB_TOKEN"]
device = "cuda" if torch.cuda.is_available() else "cpu"
PAD_TOKEN = "<|pad|>"
EOS_TOKEN = "<|endoftext|>"
UNK_TOKEN = "<|unk|>"
MAX_INPUT_TOKENS = 1024 # max tokens from context
tokenizer = AutoTokenizer.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True)
tokenizer.truncation_side = "left"  # if the prompt is truncated, keep its last MAX_INPUT_TOKENS tokens rather than the first ones
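# Hedged sketch (an assumption, not part of the original app): if the checkpoint's
# tokenizer config did not already register the special tokens defined above, they
# could be added explicitly, after which the model's embeddings would need resizing:
# tokenizer.add_special_tokens({"pad_token": PAD_TOKEN, "eos_token": EOS_TOKEN, "unk_token": UNK_TOKEN})
# model.resize_token_embeddings(len(tokenizer))  # call once the model is loaded below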
if device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True).to(device, dtype=torch.bfloat16)
else:
    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True)
model.eval()
custom_css = """
.gradio-container {
    background-color: #0D1525;
    color: white;
}
#orange-button {
    background: #F26207 !important;
    color: white;
}
.cm-gutters {
    border: none !important;
}
"""
def post_processing(prompt, completion):
    return prompt + completion
    # Unreachable alternative kept from an earlier revision: render the result as
    # HTML with the completion highlighted.
    # completion = "<span style='color: #499cd5;'>" + completion + "</span>"
    # prompt = "<span style='color: black;'>" + prompt + "</span>"
    # code_html = f"<hr><br><pre style='font-size: 14px'><code>{prompt}{completion}</code></pre><br><hr>"
    # return code_html
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
    # Truncate the prompt to MAX_INPUT_TOKENS if it is too long (left-truncated,
    # so the most recent context is kept).
    x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
    print("Prompt shape: ", x.shape)  # logged so prompt sizes show up in the Space logs in prod
    set_seed(seed)
    y = model.generate(x,
                       max_new_tokens=max_new_tokens,
                       temperature=temperature,
                       pad_token_id=tokenizer.pad_token_id,
                       eos_token_id=tokenizer.eos_token_id,
                       top_p=top_p,
                       top_k=top_k,
                       use_cache=use_cache,
                       repetition_penalty=repetition_penalty
                       )
    completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Strip the echoed prompt so only the newly generated text remains.
    completion = completion[len(prompt):]
    return post_processing(prompt, completion)
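# Example invocation (a sketch; left commented out because decoding a 7B model,
# especially on CPU, is slow at import time):
# print(code_generation("def fibonacci(n):", max_new_tokens=32))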
demo = gr.Blocks(
    css=custom_css
)
with demo:
    gr.Markdown(value=description)
    with gr.Row():
        input_col, settings_col = gr.Column(scale=6), gr.Column(scale=6)
        with input_col:
            code = gr.Code(lines=28, label="Input", value="def sieve_eratosthenes(n):")
        with settings_col:
            with gr.Accordion("Generation Settings", open=True):
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=128,
                    step=1,
                    value=48,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.9,
                    step=0.1,
                    value=1.0,
                    label="Repetition Penalty. 1.0 means no penalty.",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    value=42,
                    label="Random Seed",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.9,
                    label="Top P",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=4,
                    label="Top K",
                )
                use_cache = gr.Checkbox(
                    label="Use Cache",
                    value=True,
                )
    with gr.Row():
        run = gr.Button(elem_id="orange-button", value="Generate")
    # with gr.Row():
    #     _, middle_col_row_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
    #     with middle_col_row_2:
    #         output = gr.HTML(label="Generated Code")
    event = run.click(code_generation, [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty], code, api_name="predict")
demo.queue(max_size=40).launch()
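# Hedged usage sketch (the Space id "user/falcoder-7b-demo" is hypothetical): once
# deployed, the api_name="predict" endpoint wired above can be called remotely with
# gradio_client, passing the inputs in the same order run.click lists them:
#
# from gradio_client import Client
# client = Client("user/falcoder-7b-demo")
# result = client.predict(
#     "def sieve_eratosthenes(n):",  # code prompt
#     48,    # max_new_tokens
#     0.2,   # temperature
#     42,    # seed
#     0.9,   # top_p
#     4,     # top_k
#     True,  # use_cache
#     1.0,   # repetition_penalty
#     api_name="/predict",
# )
# print(result)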