Spaces:
Build error
Build error
File size: 2,129 Bytes
562d551 4b85d9d 562d551 39e12e7 1c022e5 562d551 b212dd9 3c06e38 3ce9b38 562d551 1c022e5 dc9a7be 1c022e5 562d551 26a52e2 39e12e7 14cb31c 56eec5a 1c022e5 dc9a7be 4b85d9d 562d551 1c022e5 dc9a7be 562d551 dc9a7be 562d551 cf94a33 562d551 39e12e7 1c022e5 562d551 27ff797 562d551 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from transformers import pipeline
import torch
import json
def _resource_cache(func):
    """Wrap *func* so its return value is built once and reused across reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would reload the model weights each time a slider moves.
    Newer Streamlit versions provide ``st.cache_resource`` for shared,
    unhashable objects such as models. Older versions only have ``st.cache``,
    where ``allow_output_mutation=True`` skips hashing the large returned
    object.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_resource_cache
def load_pipeline(model_ckpt, device):
    """Build a Hugging Face ``text-generation`` pipeline.

    Args:
        model_ckpt: Model checkpoint name or path passed to
            ``transformers.pipeline``.
        device: Device index passed to ``transformers.pipeline``. This script
            passes 0 (first GPU) or -1 (CPU).

    Returns:
        The pipeline object. It is cached by ``_resource_cache``, so it is
        created once per ``(model_ckpt, device)`` pair and reused by later
        reruns.
    """
    return pipeline('text-generation', model=model_ckpt, device=device)
@st.cache()
def load_examples():
    """Load the sidebar prompt examples from ``examples.json``.

    Returns:
        The parsed JSON document. The rest of the app reads the ``"name"``,
        ``"value"`` and ``"length"`` keys of each entry, so it is expected to be
        a list of objects with those keys.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Don't depend on the
    # platform's locale encoding, which is not UTF-8 everywhere (e.g. Windows).
    with open("examples.json", "r", encoding="utf-8") as f:
        return json.load(f)
st.set_page_config(page_icon=':parrot:', layout="wide")

# Probe CUDA once and derive both the pipeline device index
# (0 = first GPU, -1 = CPU) and the label shown in the sidebar.
use_gpu = torch.cuda.is_available()
device = 0 if use_gpu else -1
device_name = "GPU" if use_gpu else "CPU"

model_ckpt = "lvwerra/codeparrot"
pipe = load_pipeline(model_ckpt, device)

examples = load_examples()
example_names = [example["name"] for example in examples]
# Display name -> position in `examples`, used to look up the selected example.
name2id = {name: i for i, name in enumerate(example_names)}

# Seed the RNGs (transformers.set_seed) so sampled output is repeatable for
# the same prompt and settings.
set_seed(42)
# Keyword arguments forwarded to the text-generation pipeline call.
gen_kwargs = {}

st.title("CodeParrot 🦜")
st.markdown('##')

st.sidebar.header("Examples:")
selected_example = st.sidebar.selectbox("Select one of the following examples:", example_names)
# Look the selected example up once instead of once per field.
example = examples[name2id[selected_example]]
example_text = example["value"]
default_length = example["length"]

st.sidebar.header("Generation settings:")
gen_kwargs["do_sample"] = st.sidebar.radio("Decoding strategy", ["Greedy", "Sample"]) == "Sample"
gen_kwargs["max_new_tokens"] = st.sidebar.slider("Number of tokens to generate", value=default_length, min_value=8, step=8, max_value=256)
if gen_kwargs["do_sample"]:
    # transformers raises ValueError for temperature <= 0 when sampling, so
    # the slider starts one step above zero (use "Greedy" for deterministic
    # decoding instead).
    gen_kwargs["temperature"] = st.sidebar.slider("Temperature", value=0.2, min_value=0.05, max_value=2.0, step=0.05)
    # A value of 0 leaves top-k filtering disabled in transformers' generate().
    gen_kwargs["top_k"] = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=0)
    gen_kwargs["top_p"] = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.01, value=0.95)
st.sidebar.markdown(f"Device: _{device_name}_")

gen_prompt = st.text_area("Generate code with prompt:", value=example_text, height=220).strip()
# Generation runs only on the rerun triggered by clicking the button.
generate_requested = st.button("Generate code!")
if generate_requested:
    # Keep a spinner visible while the potentially slow pipeline call runs.
    with st.spinner("Generating code..."):
        outputs = pipe(gen_prompt, **gen_kwargs)
        # The pipeline returns a list of result dicts; show the first one's
        # "generated_text" (which by default includes the prompt itself).
        generated_text = outputs[0]['generated_text']
        st.code(generated_text)