|
import html
import logging
import os
from urllib.error import HTTPError

import gradio as gr
from gradio.components import Textbox, Button, Slider, Checkbox
from AinaTheme import theme

from rag import RAG
from utils import setup
|
|
|
# Default value (and Clear-button reset value) of the "Max tokens" slider.
MAX_NEW_TOKENS = 700
# The "Model parameters" accordion is shown unless this environment variable is
# set to anything other than the exact string "True".
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"

# Runs before any configuration is read from the environment below.
# NOTE(review): presumably prepares the environment/config — see utils.setup.
setup()

# Read up front so a missing value fails with an explicit message instead of
# ``int(None)`` raising an opaque TypeError at import time.
_rerank_number_contexts = os.getenv("RERANK_NUMBER_CONTEXTS")
if _rerank_number_contexts is None:
    raise RuntimeError(
        "The RERANK_NUMBER_CONTEXTS environment variable must be set to an integer."
    )

# Single RAG pipeline shared by every request (used by generate()), configured
# entirely from environment variables.
rag = RAG(
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"),
    model_name=os.getenv("MODEL"),
    rerank_model=os.getenv("RERANK_MODEL"),
    rerank_number_contexts=int(_rerank_number_contexts),
)
|
|
|
|
|
def generate(prompt, model_parameters):
    """Query the RAG pipeline for ``prompt``.

    Args:
        prompt: The user question, forwarded to ``rag.get_response``.
        model_parameters: Generation settings, forwarded unchanged to
            ``rag.get_response``.

    Returns:
        The ``(output, context, source)`` triple produced by
        ``rag.get_response``, or ``(None, None, None)`` if the call fails.
        Failures are not raised: they are logged and reported to the user
        through ``gr.Warning``.
    """
    logger = logging.getLogger(__name__)
    try:
        output, context, source = rag.get_response(prompt, model_parameters)
    except HTTPError as err:
        logger.warning("Inference endpoint returned HTTP %s", err.code)
        if err.code == 400:
            # Per the message below, a 400 is treated as "outside service hours".
            gr.Warning(
                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
            )
        else:
            # Previously any other HTTP status failed silently.
            gr.Warning(
                "Inference endpoint is not available right now. Please try again later."
            )
    except Exception:
        # Catch ordinary errors only; KeyboardInterrupt/SystemExit must propagate.
        logger.exception("RAG inference failed")
        gr.Warning(
            "Inference endpoint is not available right now. Please try again later."
        )
    else:
        return output, context, source
    return None, None, None
|
|
|
|
|
def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
    """Handle the Submit button: run the RAG query and format its results.

    Args:
        input_: Text of the question box (may be ``None`` or blank).
        num_chunks: Value of the "Number of chunks" slider, sent as ``NUM_CHUNKS``.
        max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature:
            Generation settings taken from the "Model parameters" controls.

    Returns:
        A 3-tuple ``(output, sources_markup, context)`` for the output box,
        the sources Markdown and the full-context Markdown. ``sources_markup``
        is an HTML string with one escaped link per source URL. When the input
        is blank, or inference fails, the missing values are ``None`` / ``""``.
    """
    if not input_ or not input_.strip():
        gr.Warning("Not possible to inference an empty input")
        # The event has three outputs, so three values must be returned.
        return None, None, None

    model_parameters = {
        "NUM_CHUNKS": num_chunks,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature,
    }

    output, context, source = generate(input_, model_parameters)

    # generate() returns source=None on failure; render no links in that case.
    # URLs come from retrieved documents, so escape them before embedding in HTML.
    sources_markup = "".join(
        f'<a href="{html.escape(str(url))}" target="_blank">{html.escape(str(url))}</a><br>'
        for url in (source or ())
    )

    return output, sources_markup, context
|
|
|
|
|
|
|
def change_interactive(text):
    """Enable the Submit button only while the input box contains text.

    Args:
        text: Current value of the input box. It may be ``None``, e.g. right
            after ``clear()`` resets the box.

    Returns:
        Two ``gr.update`` objects for ``(clear_btn, submit_btn)``. The Clear
        button is always interactive; the Submit button is interactive only
        when ``text`` is non-empty.
    """
    # Truthiness covers both "" and None; len(None) would raise TypeError.
    return gr.update(interactive=True), gr.update(interactive=bool(text))
|
|
|
|
|
def clear():
    """Reset every widget wired to the Clear button.

    Returns:
        An 11-tuple in the order of the Clear button's outputs:
        - four ``None`` values, which empty the input box, the answer box,
          the sources Markdown and the full-context Markdown;
        - fresh components carrying the default value of each model-parameter
          control.
    """
    emptied_fields = (None,) * 4
    parameter_defaults = (
        gr.Slider(value=2.0),             # number of chunks
        gr.Slider(value=MAX_NEW_TOKENS),  # max tokens
        gr.Slider(value=1.0),             # repetition penalty
        gr.Slider(value=50),              # top k
        gr.Slider(value=0.99),            # top p
        gr.Checkbox(value=False),         # do sample
        gr.Slider(value=0.35),            # temperature
    )
    return emptied_fields + parameter_defaults
|
|
|
|
|
def gradio_app():
    """Build the RAG demo interface and launch it with ``demo.launch(show_api=True)``.

    Layout:
    - a banner row;
    - an input column (question box, Clear/Submit buttons and an optional
      "Model parameters" accordion) beside an output column (answer, sources,
      full context);
    - a row of example questions.

    Submit runs ``submit_input`` and is exposed through the API as
    ``get-results``.
    """
    header_markdown = (
        "# TEST de Retrieval-Augmented Generation para proyectos de LangTechE\n"
        "🔍\n"
        "\n"
        "⚠️ **Advertencias**: Esta es una versión experimental. 👀\n"
    )
    example_questions = [
        "¿Se pueden transportar mascotas en el AVE?",
        "¿Cómo se crea un billete de día del Pase Móvil en la aplicación Rail Planner App?",
        "¿Cómo puedo solicitar la factura de un abono con posterioridad a la compra?",
    ]

    with gr.Blocks(theme=theme) as demo:
        with gr.Row():
            # NOTE(review): recent Gradio versions expect an integer ``scale`` and
            # warn on fractional values; kept as-is to preserve the current layout.
            with gr.Column(scale=0.1):
                gr.Image(
                    "rag_image.jpg",
                    elem_id="flor-banner",
                    scale=1,
                    height=256,
                    width=256,
                    show_label=False,
                    show_download_button=False,
                    show_share_button=False,
                )
            with gr.Column():
                gr.Markdown(header_markdown)

        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                input_ = Textbox(
                    lines=11,
                    label="Input",
                    placeholder="",
                )
                with gr.Row(variant="panel"):
                    clear_btn = Button("Clear")
                    # Starts disabled; change_interactive enables it once there is text.
                    submit_btn = Button("Submit", variant="primary", interactive=False)

                with gr.Row(variant="panel"):
                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        num_chunks = Slider(
                            minimum=1,
                            maximum=6,
                            step=1,
                            value=2,
                            label="Number of chunks",
                        )
                        max_new_tokens = Slider(
                            minimum=50,
                            maximum=2000,
                            step=1,
                            value=MAX_NEW_TOKENS,
                            label="Max tokens",
                        )
                        repetition_penalty = Slider(
                            minimum=0.1,
                            maximum=2.0,
                            step=0.1,
                            value=1.0,
                            label="Repetition penalty",
                        )
                        top_k = Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top k",
                        )
                        top_p = Slider(
                            minimum=0.01,
                            maximum=0.99,
                            value=0.99,
                            label="Top p",
                        )
                        do_sample = Checkbox(
                            value=False,
                            label="Do sample",
                        )
                        temperature = Slider(
                            minimum=0.1,
                            maximum=1,
                            value=0.35,
                            label="Temperature",
                        )

                # Order must match submit_input's parameters after input_ and the
                # trailing values returned by clear().
                parameter_components = [
                    num_chunks,
                    max_new_tokens,
                    repetition_penalty,
                    top_k,
                    top_p,
                    do_sample,
                    temperature,
                ]

            with gr.Column(variant="panel"):
                output = Textbox(
                    lines=10,
                    label="Output",
                    interactive=False,
                    show_copy_button=True,
                )
                with gr.Accordion("Sources and context:", open=False):
                    source_context = gr.Markdown(
                        label="Sources",
                        show_label=False,
                    )
                with gr.Accordion("See full context evaluation:", open=False):
                    context_evaluation = gr.Markdown(
                        label="Full context",
                        show_label=False,
                    )

        # The former client-side "character counter" handler was removed: it wrote
        # to an element id ('inputlenght') that does not exist in this layout, so it
        # threw a null-dereference in the browser on every keystroke.
        input_.change(
            fn=change_interactive,
            inputs=[input_],
            outputs=[clear_btn, submit_btn],
            api_name=False,
        )

        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[input_, output, source_context, context_evaluation] + parameter_components,
            queue=False,
            api_name=False,
        )

        submit_btn.click(
            fn=submit_input,
            inputs=[input_] + parameter_components,
            outputs=[output, source_context, context_evaluation],
            api_name="get-results",
        )

        with gr.Row():
            with gr.Column(scale=0.5):
                # One Examples table per question, as in the original layout.
                for question in example_questions:
                    gr.Examples(
                        examples=[[question]],
                        inputs=input_,
                        outputs=[output, source_context, context_evaluation],
                        fn=submit_input,
                        # Caching would call submit_input with only the question text,
                        # omitting its seven required parameter arguments.
                        cache_examples=False,
                    )

    demo.launch(show_api=True)
|
|
|
|
|
if __name__ == "__main__":
    # Build and launch the UI only when run as a script, not on import.
    gradio_app()