Spaces:
Build error
Build error
File size: 2,342 Bytes
562d551 4b85d9d 562d551 39e12e7 1c022e5 562d551 1c022e5 dc9a7be 1c022e5 562d551 39e12e7 7e61ab1 562d551 1c022e5 dc9a7be 4b85d9d 562d551 39e12e7 1c022e5 dc9a7be 562d551 dc9a7be 562d551 cf94a33 562d551 39e12e7 1c022e5 562d551 27ff797 562d551 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from transformers import pipeline
import torch
import json
@st.cache(allow_output_mutation=True)
def load_tokenizer(model_ckpt):
    """Fetch the tokenizer for ``model_ckpt`` and memoize it via ``st.cache``.

    Streamlit re-executes the script on every interaction; caching keeps the
    tokenizer from being reloaded each time. ``allow_output_mutation=True``
    tells ``st.cache`` not to verify that the returned object is unchanged
    between runs.

    NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour of
    ``st.cache_resource``; migrate once the pinned Streamlit version allows.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    return tokenizer
@st.cache(allow_output_mutation=True)
def load_model(model_ckpt):
    """Load the causal language model for ``model_ckpt``, cached across reruns.

    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned model to detect mutation, which would be expensive for a large
    object and unnecessary here.
    """
    return AutoModelForCausalLM.from_pretrained(model_ckpt)
@st.cache()
def load_examples(path="examples.json"):
    """Read the sidebar's example prompts from a JSON file (cached).

    Args:
        path: JSON file to read; defaults to ``examples.json`` in the working
            directory, matching the original hard-coded location.

    Returns:
        The parsed JSON document. The calling script indexes each entry by
        ``"name"``, ``"value"`` and ``"length"`` — presumably a list of dicts;
        this function itself does not validate that shape.
    """
    # JSON is UTF-8 by specification; without an explicit encoding open()
    # falls back to the platform locale codec and can mis-decode non-ASCII.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
st.set_page_config(page_icon=':parrot:', layout="wide")

# transformers pipelines take device=-1 for CPU and a CUDA ordinal otherwise
# (0 = first GPU). The previous mapping (GPU -> 1, CPU -> 0) aimed at a second
# GPU when CUDA was present and at cuda:0 on CPU-only hosts.
has_cuda = torch.cuda.is_available()
device = 0 if has_cuda else -1
device_name = "GPU" if has_cuda else "CPU"

model_ckpt = "lvwerra/codeparrot"
tokenizer = load_tokenizer(model_ckpt)
model = load_model(model_ckpt)
examples = load_examples()
example_names = [example["name"] for example in examples]
# Display name -> index into ``examples``.
name2id = {name: i for i, name in enumerate(example_names)}

# The script re-executes on every interaction, so the seed is reset each run:
# identical prompt + settings reproduce the same sampled output.
set_seed(42)
gen_kwargs = {}

st.title("CodeParrot 🦜")
st.markdown('##')
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)

st.sidebar.header("Examples:")
selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
selected = examples[name2id[selected_example]]
example_text = selected["value"]
default_length = selected["length"]

st.sidebar.header("Generation settings:")
gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=default_length, min_value=8, step=8, max_value=256)
if gen_kwargs["do_sample"]:
    # Sampling needs a strictly positive temperature (transformers rejects
    # 0.0), so the slider starts one step above zero.
    gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value=0.2, min_value=0.05, max_value=2.0, step=0.05)
    # top_k=0 disables top-k filtering, leaving top-p (nucleus) sampling in charge.
    gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=0)
    gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
st.sidebar.markdown(f"Device: _{device_name}_")

gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220).strip()
if st.button("Generate code!"):
    with st.spinner("Generating code..."):
        generated_text = pipe(gen_prompt, **gen_kwargs)[0]['generated_text']
    st.code(generated_text)