File size: 6,240 Bytes
8a0ec5f
 
 
 
 
e2ff706
8a0ec5f
 
e2ff706
2ea5c26
 
 
8a0ec5f
 
e2ff706
2ea5c26
 
8a0ec5f
 
2ea5c26
8a0ec5f
 
2ea5c26
8a0ec5f
e2ff706
8a0ec5f
2ea5c26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a0ec5f
e2ff706
8a0ec5f
2ea5c26
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr

def bot(input_message: str, db_info="", temperature=0.1, top_p=0.9, top_k=0, repetition_penalty=1.08):
    """Return a canned response instead of running a model.

    This is the stripped-down demo handler: every argument is accepted so the
    signature lines up with the UI inputs, but none of them is read. The value
    handed back is always the same fixed concert/singer schema description.

    Args:
        input_message: Natural-language question from the UI (ignored).
        db_info: Database schema text from the UI (ignored).
        temperature: Sampling temperature slider value (ignored).
        top_p: Nucleus-sampling slider value (ignored).
        top_k: Top-k slider value (ignored).
        repetition_penalty: Repetition-penalty slider value (ignored).

    Returns:
        The preset schema string, identical on every call.
    """
    # Adjacent literals are concatenated at compile time, so this is one
    # constant string; it is only split here to keep the lines readable.
    preset_output = (
        "| stadium : stadium_id , location , name , capacity , highest , lowest , average "
        "| singer : singer_id , name , country , song_name , song_release_year , age , is_male "
        "| concert : concert_id , concert_name , theme , stadium_id , year "
        "| singer_in_concert : concert_id , singer_id "
        "| concert.stadium_id = stadium.stadium_id "
        "| singer_in_concert.singer_id = singer.singer_id "
        "| singer_in_concert.concert_id = concert.concert_id |"
    )
    return preset_output

# Build the demo page: a code box for the result, the question and schema
# inputs, sampling sliders, clickable examples, an informational footer, and
# the button wiring that sends everything to `bot`.
with gr.Blocks(theme='gradio/soft') as demo:
    header = gr.HTML("""
        <h1 style="text-align: center">SQL Skeleton WizardCoder Demo</h1>
        <h3 style="text-align: center">🧙‍♂️ Generate SQL queries from Natural Language 🧙‍♂️</h3>
    """)

    output_box = gr.Code(label="Generated SQL", lines=2, interactive=True)
    input_text = gr.Textbox(lines=3, placeholder='Write your question here...', label='NL Input')
    db_info = gr.Textbox(lines=4, placeholder='Example: | table_01 : column_01 , column_02 | table_02 : column_01 , column_02 | ...', label='Database Info')

    with gr.Accordion("Hyperparameters", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
        top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
        top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, value=0, step=1)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.08, step=0.01)

    run_button = gr.Button("Generate SQL", variant="primary")

    # Components passed to `bot`, in the order of its positional parameters.
    # Defined once so the examples panel and the button cannot drift apart.
    bot_inputs = [input_text, db_info, temperature, top_p, top_k, repetition_penalty]

    # Every example question runs against the same concert/singer schema, so
    # the schema text is written once and paired with each question below.
    concert_schema = (
        "| stadium : stadium_id , location , name , capacity , highest , lowest , average "
        "| singer : singer_id , name , country , song_name , song_release_year , age , is_male "
        "| concert : concert_id , concert_name , theme , stadium_id , year "
        "| singer_in_concert : concert_id , singer_id "
        "| concert.stadium_id = stadium.stadium_id "
        "| singer_in_concert.singer_id = singer.singer_id "
        "| singer_in_concert.concert_id = concert.concert_id |"
    )
    example_questions = [
        "What is the average, minimum, and maximum age for all French singers?",
        "Show location and name for all stadiums with a capacity between 5000 and 10000.",
        "What are the number of concerts that occurred in the stadium with the largest capacity ?",
        "How many male singers performed in concerts in the year 2023?",
        "List the names of all singers who performed in a concert with the theme 'Rock'",
    ]

    with gr.Accordion("Examples", open=True):
        # Each row supplies only the first two inputs (question, schema); the
        # slider components are listed in `inputs` but get no example values.
        examples = gr.Examples(
            [[question, concert_schema] for question in example_questions],
            inputs=bot_inputs,
            fn=bot,
        )

    # Hugging Face repo ids interpolated into the footer links below.
    bitsandbytes_model = "richardr1126/spider-skeleton-wizard-coder-8bit"
    merged_model = "richardr1126/spider-skeleton-wizard-coder-merged"
    initial_model = "WizardLM/WizardCoder-15B-V1.0"
    finetuned_model = "richardr1126/spider-skeleton-wizard-coder-qlora"
    dataset = "richardr1126/spider-skeleton-context-instruct"

    footer = gr.HTML(f"""
        <p>🛠️ If you want you can <strong>duplicate this Space</strong>, then change the HF_MODEL_REPO spaces env variable to use any Transformers model.</p>
        <p>🌐 Leveraging the <a href='https://huggingface.co/{bitsandbytes_model}'><strong>bitsandbytes 8-bit version</strong></a> of <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a> model.</p>
        <p>🔗 How it's made: <a href='https://huggingface.co/{initial_model}'><strong>{initial_model}</strong></a> was finetuned to create <a href='https://huggingface.co/{finetuned_model}'><strong>{finetuned_model}</strong></a>, then merged together to create <a href='https://huggingface.co/{merged_model}'><strong>{merged_model}</strong></a>.</p>
        <p>📉 Fine-tuning was performed using QLoRA techniques on the <a href='https://huggingface.co/datasets/{dataset}'><strong>{dataset}</strong></a> dataset. You can view training metrics on the <a href='https://huggingface.co/{finetuned_model}'><strong>QLoRa adapter HF Repo</strong></a>.</p>
    """)

    # Clicking the button calls `bot` with the current input values and writes
    # its return value into the code box; `api_name` names the API endpoint.
    run_button.click(fn=bot, inputs=bot_inputs, outputs=output_box, api_name="txt2sql")

# Turn on request queuing (concurrency_count=1, at most 10 queued events),
# then start the web server for the app.
queued_demo = demo.queue(concurrency_count=1, max_size=10)
queued_demo.launch()