import os

import chromadb
import gradio as gr
import spaces
from sentence_transformers import SentenceTransformer

client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
#collection_en = client.get_collection(name="phil_en")

authors_list_de = ["Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius", "Friedrich Nietzsche", "Epiktet", "Ernst Jünger", "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt", "Erich Fromm", "Albert Camus"]
#authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]
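
# Embed one or more queries using the "Instruct: ... / Query: ..." prompt format
# expected by the Linq-Embed-Mistral retrieval model.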
@spaces.GPU  # request GPU time for the embedding call (assumes a ZeroGPU Space; the `spaces` import is otherwise unused)
def get_embeddings(queries, task):
    model = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=os.getenv("HF_TOKEN"))
    prompts = [f"Instruct: {task}\nQuery: {query}" for query in queries]
    query_embeddings = model.encode(prompts)
    return query_embeddings
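
# Query the Chroma collection with a single query embedding, optionally filtered
# by author, and flatten the nested response into a list of result dicts.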
def query_chroma(collection, embedding, authors):
    results = collection.query(
        query_embeddings=[embedding.tolist()],
        n_results=10,
        # Chroma rejects an empty where-dict, so pass None when no author filter is selected
        where={"author": {"$in": authors}} if authors else None,
        include=["documents", "metadatas", "distances"]
    )
    ids = results.get('ids', [[]])[0]
    metadatas = results.get('metadatas', [[]])[0]
    documents = results.get('documents', [[]])[0]
    distances = results.get('distances', [[]])[0]

    formatted_results = []
    for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
        result_dict = {
            "id": id_,
            "author": metadata.get('author', 'Unknown author'),
            "book": metadata.get('book', 'Unknown book'),
            "section": metadata.get('section', 'Unknown section'),
            "title": metadata.get('title', 'Untitled'),
            "text": document_text,
            "distance": distance
        }
        formatted_results.append(result_dict)
    return formatted_results
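
# Soft indigo/slate theme for the Gradio UI.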
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    spacing_size="lg",
    radius_size="lg",
    text_size="lg",
    font=["Helvetica", "sans-serif"],
    font_mono=["Courier", "monospace"],
).set(
    body_text_color="*neutral_800",
    block_background_fill="*neutral_50",
    block_border_width="0px",
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    button_primary_text_color="white",
    input_background_fill="white",
    input_border_color="*neutral_200",
    input_border_width="1px",
    checkbox_background_color_selected="*primary_600",
    checkbox_border_color_selected="*primary_600",
)
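
# Card styling applied to each rendered result block.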
custom_css = """
.custom-markdown {
    border: 1px solid var(--neutral-200);
    padding: 15px;
    border-radius: var(--radius-lg);
    background-color: var(--color-background-primary);
    margin-bottom: 15px;
}
.custom-markdown p {
    margin-bottom: 10px;
    line-height: 1.6;
}
"""
with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.Markdown("Geben Sie ein, wonach Sie suchen möchten (Query), filtern Sie nach Autoren (ohne Auswahl werden alle durchsucht) und klicken Sie auf **Suchen**, um zu suchen. Trennen Sie mehrere Fragen durch Semikola; die Suche dauert, unabhängig von der Anzahl der Abfragen, etwa 40 Sekunden, da das Embeddingmodell jedes Mal auf eine GPU geladen werden muss.")
    #database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Autoren", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Wie kann ich gesund leben?; Wie kann ich mich besser konzentrieren?; Was ist der Sinn des Lebens?; ...")
    btn = gr.Button("Suchen")
    loading_indicator = gr.Markdown(visible=False)
    results = gr.State()

    #def update_authors(database):
    #    return gr.update(choices=authors_list_de if database == "German" else authors_list_en)

    #database_inp.change(
    #    fn=lambda database: update_authors(database),
    #    inputs=[database_inp],
    #    outputs=[author_inp]
    #)
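
    # Split the input on semicolons, embed all queries in one batch, and run one
    # Chroma lookup per query; the second return value hides the loading indicator again.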
    def perform_query(queries, authors):
        task = "Given a question, retrieve passages that answer the question"
        queries = [query.strip() for query in queries.split(';')]
        embeddings = get_embeddings(queries, task)
        collection = collection_de
        results_data = []
        for query, embedding in zip(queries, embeddings):
            res = query_chroma(collection, embedding, authors)
            results_data.append((query, res))
        return results_data, gr.update(value="", visible=False)
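
    # Show a progress message immediately on click, then run the actual search.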
    btn.click(
        fn=lambda: gr.update(value="Searching... This may take up to 40 seconds.", visible=True),
        inputs=None,
        outputs=[loading_indicator],
        queue=False
    ).then(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results, loading_indicator]
    )
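
    # Re-render the results whenever the `results` state changes: one collapsible
    # accordion per query, each listing the retrieved passages.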
    @gr.render(inputs=results)
    def display_accordion(data):
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False):
                for result in res:
                    with gr.Column():
                        author = result.get('author', 'Unknown author')
                        book = result.get('book', 'Unknown book')
                        text = result.get('text')
                        markdown_contents = f"**{author}, {book}**\n\n{text}"
                        gr.Markdown(value=markdown_contents, elem_classes="custom-markdown")

demo.launch()