Spaces:
Build error
Build error
File size: 4,763 Bytes
8b15eea 07ce11e 8b15eea bb9a90d 8b15eea df1aa0b bb9a90d 8b15eea df1aa0b 8b15eea df1aa0b 8b15eea df1aa0b 8b15eea bb9a90d 8b15eea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import logging
import os
from pathlib import Path
from time import perf_counter
import gradio as gr
from jinja2 import Environment, FileSystemLoader
import requests
from backend.query_llm import generate
from backend.semantic_search import retriever
# Directory containing this file; used to locate the bundled templates.
proj_dir = Path(__file__).parent
# Setting up the logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set up the template environment with the templates directory
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
# Load the templates directly from the environment:
# `template` renders the text prompt that is passed to generate(),
# `template_html` renders the HTML view of the same prompt for the UI.
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')
# Initialize tokenizer (used by bot() to count prompt tokens).
# NOTE(review): AutoTokenizer is not imported anywhere in the visible file, so
# this line raises NameError at import time. Presumably
# `from transformers import AutoTokenizer` is missing -- confirm and add it.
tokenizer = AutoTokenizer.from_pretrained('inception-mbzuai/jais-13b-chat')
def check_endpoint_status():
    """Query the inference endpoint and summarise its state for the UI.

    The endpoint URL is read from the ``ENDPOINT_URL`` environment variable
    and the bearer token from ``BEARER``. The JSON response is expected to
    carry a ``status`` object with ``state`` and ``message`` keys; missing
    keys fall back to placeholder text.

    Returns:
        str: ``"Status: <state>`` / ``Message: <message>"`` on two lines when
        the request succeeds, or ``"Failed to get status: <reason>"`` when
        the URL is not configured, the request fails, or the body is not
        valid JSON. This function never raises for those cases because its
        return value is rendered directly in a textbox.
    """
    # Bound the wait: without a timeout a stalled endpoint would block the
    # caller indefinitely.
    request_timeout_seconds = 10

    api_url = os.getenv("ENDPOINT_URL")
    if not api_url:
        # Report the real cause instead of an "Invalid URL 'None'" error.
        return "Failed to get status: ENDPOINT_URL is not set"

    headers = {
        'accept': 'application/json',
        'Authorization': f'Bearer {os.getenv("BEARER")}'
    }
    try:
        response = requests.get(
            api_url, headers=headers, timeout=request_timeout_seconds)
        response.raise_for_status()  # raises for 4xx/5xx responses
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSON decode error is not a RequestException subclass.
        return f"Failed to get status: {str(e)}"

    # Extracting the status information; tolerate an unexpected payload
    # shape (non-dict body or non-dict "status") by using the defaults.
    status_info = data.get('status', {}) if isinstance(data, dict) else {}
    if not isinstance(status_info, dict):
        status_info = {}
    status = status_info.get('state', 'No status found')
    message = status_info.get('message', 'No message found')
    return f"Status: {status}\nMessage: {message}"
def add_text(history, text):
    """Append the user's message to the chat and lock the input box.

    Args:
        history: Existing chat history, or ``None`` for a fresh conversation.
        text: The message the user just submitted.

    Returns:
        A pair ``(new_history, textbox)``: the history extended with a
        ``(text, None)`` entry whose reply slot is still empty, and a
        cleared, non-interactive ``gr.Textbox``.
    """
    if history is None:
        history = []
    # Build a new list rather than mutating the incoming history in place.
    pending_turn = (text, None)
    extended_history = history + [pending_turn]
    locked_input = gr.Textbox(value="", interactive=False)
    return extended_history, locked_input
def bot(history, system_prompt=""):
    """Stream a retrieval-augmented reply for the latest user message.

    Retrieves documents for the query in the last history entry, renders the
    prompt from ``template``, drops documents from the end of the list until
    the prompt fits the token budget, then streams the output of
    ``generate`` into the reply slot of the last history entry.

    Args:
        history: Chat history; ``history[-1][0]`` is the user query and
            ``history[-1][1]`` is overwritten with the reply.
        system_prompt: Not used by this function; kept so the signature
            stays compatible with existing callers.

    Yields:
        ``(history, prompt_html)`` after every item produced by
        ``generate``, where ``prompt_html`` is the HTML rendering of the
        prompt that was sent.
    """
    top_k = 5
    max_prompt_tokens = 2048
    query = history[-1][0]

    logger.warning('Retrieving documents...')
    # Retrieve documents relevant to query
    document_start = perf_counter()
    documents = retriever(query, top_k=top_k)
    # Elapsed time is "now minus start"; the reversed subtraction produced a
    # negative duration in the log.
    document_time = perf_counter() - document_start
    logger.warning(
        'Finished Retrieving documents in %s seconds...',
        round(document_time, 2))

    # Function to count tokens
    def count_tokens(text):
        return len(tokenizer.encode(text))

    # Create Prompt
    prompt = template.render(documents=documents, query=query)

    # Shrink the prompt by dropping the last document until it fits. Stop
    # once no documents remain: pop() on an empty list would raise
    # IndexError when the query alone exceeds the budget.
    token_count = count_tokens(prompt)
    while token_count > max_prompt_tokens and len(documents) > 0:
        documents.pop()  # Remove the last document
        prompt = template.render(documents=documents, query=query)
        token_count = count_tokens(prompt)

    if token_count > max_prompt_tokens:
        logger.warning(
            'Prompt still has %s tokens (limit %s) with no documents left',
            token_count, max_prompt_tokens)

    prompt_html = template_html.render(documents=documents, query=query)
    logger.warning(prompt)

    history[-1][1] = ""
    # NOTE(review): each item replaces the reply rather than being appended,
    # so this assumes generate() yields the cumulative text so far -- confirm.
    for character in generate(prompt):
        history[-1][1] = character
        yield history, prompt_html
# Build the Gradio UI: an endpoint-status box, the chat window, an input row
# and an HTML panel showing the prompt that was sent to the model.
# NOTE(review): the nesting below was reconstructed from a flattened source;
# confirm which components belong inside the Tab and the Row.
with gr.Blocks() as demo:
    with gr.Tab("Application"):
        # The textbox value is the check_endpoint_status callable; with
        # every=1 it is presumably re-evaluated on a one-second interval --
        # confirm against the installed Gradio version.
        output = gr.Textbox(check_endpoint_status, label="Endpoint Status (send chat to wake up)", every=1)
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                           'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
            bubble_full_width=False,
            show_copy_button=True,
            show_share_button=True,
        )
        with gr.Row():
            txt = gr.Textbox(
                scale=3,
                show_label=False,
                placeholder="Enter text and press enter",
                container=False,
            )
            txt_btn = gr.Button(value="Submit text", scale=1)
        # Filled by bot() with the HTML rendering of the prompt.
        prompt_html = gr.HTML()
        # Button click: add_text appends the message and returns a
        # non-interactive textbox, then bot streams the reply.
        txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, [chatbot, prompt_html])
        # Turn it back on
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
        # Enter key: same add_text -> bot chain as the button above.
        txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, chatbot, [chatbot, prompt_html])
        # Turn it back on
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
        # Example questions attached to the input textbox.
        gr.Examples(['What is the capital of China, I think its Shanghai?',
                     'Why is the sky blue?',
                     'Who won the mens world cup in 2014?',], txt)
# Enable the queue before launching.
demo.queue()
demo.launch(debug=True)
|