Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import openai
|
3 |
+
import tiktoken
|
4 |
+
import numpy as np
|
5 |
+
import ast
|
6 |
+
import pandas as pd
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import gradio as gr
|
9 |
+
|
# --- Configuration -----------------------------------------------------------
# Load environment variables from a .env file when python-dotenv is installed.
# The original called load_dotenv() without ever importing it (NameError);
# import it locally and degrade gracefully if the package is absent.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# Get API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

openai.api_key = OPENAI_API_KEY
# `OpenAI` was never imported in the original (NameError); reach the client
# class through the already-imported `openai` module instead.
client = openai.OpenAI(api_key=OPENAI_API_KEY)

# Tokenizer used to truncate over-long text before embedding.
# text-embedding-3-* models use the cl100k_base encoding; p50k_base (the
# original choice) belongs to older codex/davinci models and miscounts tokens.
tokenizer = tiktoken.get_encoding('cl100k_base')
def get_embedding(text, model='text-embedding-3-small', max_tokens=7000):
    """Return the OpenAI embedding vector for `text`.

    The text is tokenized first and truncated to at most `max_tokens`
    tokens so the request stays within the model's context limit.
    """
    encoded = tokenizer.encode(text)
    if len(encoded) > max_tokens:
        # Too long for one request: keep only the leading max_tokens tokens.
        text = tokenizer.decode(encoded[:max_tokens])

    response = client.embeddings.create(input=[text], model=model)
    return response.data[0].embedding
# Pre-computed page content and embeddings scraped from the UC Davis Health
# site; the code below reads its 'embedding', 'text', and 'url' columns.
data = pd.read_csv("ucdavis_health_embeddings.csv")
# Handle NaN values and convert the 'embedding' column from strings to lists of floats
def safe_literal_eval(x):
    """Parse a stringified Python literal (e.g. "[0.1, 0.2]") into a value.

    Returns [] for anything that cannot be parsed: malformed strings, and
    non-string cells such as the float NaN pandas yields for empty fields.
    TypeError is caught in addition to the original (ValueError, SyntaxError)
    so unexpected cell types never crash the CSV load.
    """
    try:
        return ast.literal_eval(x)
    except (ValueError, SyntaxError, TypeError):
        return []
# Parse each stored embedding string into a Python list.
data['embedding'] = data['embedding'].apply(safe_literal_eval)

# Coerce every entry to a list of floats, then drop rows whose embedding
# failed to parse (empty lists) so similarity math never sees bad data.
data['embedding'] = data['embedding'].apply(
    lambda emb: [float(v) for v in emb] if isinstance(emb, list) else []
)
data = data[data['embedding'].apply(len) > 0]
def query(question):
    """Answer `question` via retrieval-augmented generation over the site data.

    Embeds the question, ranks every page embedding by dot product
    (OpenAI embeddings are unit-normalized, so this is cosine similarity),
    feeds the top four pages to gpt-3.5-turbo as context, and returns
    (answer_text, links_joined, similarity_scores, link_list).
    """
    question_embedding = get_embedding(question)

    def similarity(page_embedding):
        return np.dot(page_embedding, question_embedding)

    # Rank once, most similar first (the original sorted the series twice).
    ranked = data['embedding'].apply(similarity).sort_values(ascending=False)
    top_four = ranked.index[0:4]
    similarity_scores = ranked.iloc[0:4]

    context = " ".join(data.loc[top_four]['text'])
    links_series = data.loc[top_four]['url']
    links = "\n \n".join(links_series)
    link_list = links_series.tolist()

    chat_completion = client.chat.completions.create(
        messages=[
            # Fixed typo from the original prompt: "UC Davos" -> "UC Davis".
            {"role": "system", "content": "You are a helpful assistant tasked to respond to users of UC Davis Health who are seeking information about their services"},
            {"role": "user", "content": question},
            {"role": "assistant", "content": f"Use this information from the UC Davis Health website and answer the user's question: {context}. Please stick to this context while answering the question. Include all important information relevant to what the user is seeking, also tell them things they should be mindful of while following instructions. Don't miss any details about timings or weekdays."}
        ],
        model="gpt-3.5-turbo"
    )

    return chat_completion.choices[0].message.content, links, similarity_scores.tolist(), link_list
def plot_bar_chart(similarity_scores, links_series):
    """Save a horizontal bar chart of the top links' similarity scores.

    Scores are sorted ascending so barh (which draws bottom-up) puts the best
    match on top; labels count down ("Link 4" ... "Link 1") so "Link 1" marks
    the most similar page, matching the order `query` returns links in.
    Returns the path of the saved PNG ('bar_chart.png').
    """
    # Sort the similarity scores and links together (ascending order).
    sorted_pairs = sorted(zip(similarity_scores, links_series))
    sorted_scores, sorted_links = zip(*sorted_pairs)

    # Create labels as "Link 1", "Link 2", etc., counting down to match barh.
    link_labels = [f"Link {i+1}" for i in range(len(sorted_links)-1, -1, -1)]

    plt.figure(figsize=(12, 8))  # Larger figure for readability
    bars = plt.barh(link_labels, sorted_scores, color='skyblue', edgecolor='black')
    plt.xlabel('Similarity Score')
    plt.ylabel('Links')
    plt.title('Similarity Scores Bar Chart')
    plt.xlim(0, 1)  # Set x-axis scale from 0 to 1
    plt.grid(True, axis='x')

    # Annotate each bar with its numeric score.
    for bar, score in zip(bars, sorted_scores):
        plt.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                 f'{score:.2f}', va='center', ha='left')

    plt.tight_layout()
    plt.savefig('bar_chart.png')
    # Close the figure: the original leaked one matplotlib figure per Gradio
    # request, growing memory on a long-running server.
    plt.close()
    return 'bar_chart.png'
# Gradio entry point gluing retrieval, generation, and the chart together.
def gradio_query(question):
    """Answer `question` and render the similarity chart for the Gradio UI."""
    answer, links, scores, urls = query(question)
    chart_path = plot_bar_chart(scores, urls)
    return answer, links, chart_path
# Build the Gradio UI: one question box in; answer text, links, and the
# similarity bar chart out.
_outputs = [
    gr.Textbox(label="Answer"),
    gr.Textbox(label="For more information, visit these links"),
    gr.Image(type="filepath", label="Similarity Scores Bar Chart", elem_id="bar_chart"),
]

interface = gr.Interface(
    fn=gradio_query,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here..."),
    outputs=_outputs,
    title="UC Davis Health Query Assistant",
    description="Ask your questions about UC Davis Health services and get relevant information from their website.",
    css=".gradio-container #bar_chart img {width: 200%; height: auto;}",
)

# Launch the interface with a public share link.
interface.launch(share=True)