sorhwphuo / app.py
HonestAnnie's picture
alles deutsch jetzt
08029e5
raw
history blame
4.35 kB
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces
# Open the on-disk Chroma store that ships next to this file and grab the
# German collection; the English collection is currently disabled.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
#collection_en = client.get_collection(name="phil_en")

# Author names offered in the UI filter; they are matched against the
# "author" metadata field when querying the collection.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
#authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]
@spaces.GPU
def get_embeddings(queries, task):
    """Embed each query with the Linq-Embed-Mistral sentence-transformer.

    Every query is wrapped in the instruction prompt
    ``"Instruct: {task}\\nQuery: {query}"`` before encoding.

    Args:
        queries: Iterable of query strings.
        task: Instruction text placed in front of every query.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for the list of
        prompts (one embedding per query, in input order).
    """
    # The model is constructed on every call; the UI notice in this file
    # attributes the resulting latency to loading it onto a GPU each time.
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )

    instructed = []
    for text in queries:
        instructed.append(f"Instruct: {task}\nQuery: {text}")

    return encoder.encode(instructed)
def query_chroma(collection, embedding, authors):
    """Return the 10 nearest passages for one query embedding.

    Args:
        collection: Object exposing a Chroma-style ``query(...)`` method.
        embedding: Query vector; must provide ``tolist()``.
        authors: Author names to restrict the search to. An empty list or
            ``None`` searches all authors.

    Returns:
        A list of dicts with the keys ``id``, ``author``, ``book``,
        ``section``, ``title``, ``text`` and ``distance``, in the order the
        collection returned them.
    """
    results = collection.query(
        query_embeddings=[embedding.tolist()],
        n_results=10,
        # Pass None (not an empty dict) when there is nothing to filter on:
        # None is the documented "no filter" value, whereas an empty where
        # clause is rejected by chromadb's filter validation in newer releases.
        where={"author": {"$in": authors}} if authors else None,
        include=["documents", "metadatas", "distances"],
    )

    def first_batch(key):
        # Results are batched per query embedding; we sent exactly one.
        # A missing key or a None value both degrade to "no hits".
        batches = results.get(key) or [[]]
        return batches[0]

    formatted_results = []
    for id_, metadata, document_text, distance in zip(
        first_batch("ids"),
        first_batch("metadatas"),
        first_batch("documents"),
        first_batch("distances"),
    ):
        # A record stored without metadata comes back as None.
        metadata = metadata or {}
        formatted_results.append(
            {
                "id": id_,
                "author": metadata.get("author", "Unknown author"),
                "book": metadata.get("book", "Unknown book"),
                "section": metadata.get("section", "Unknown section"),
                "title": metadata.get("title", "Untitled"),
                "text": document_text,
                "distance": distance,
            }
        )
    return formatted_results
# Search UI: a query box plus author filter; results are kept in a State
# component and rendered as one collapsible accordion per query.
with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Geben Sie ein, wonach Sie suchen möchten (Query), filtern Sie nach Autoren (ohne Auswahl werden alle durchsucht) und klicken Sie auf **Suchen**, um zu suchen. Trennen Sie mehrere Fragen durch Semikola; die Suche dauert, unabhängig von der Anzahl der Abfragen, etwa 40 Sekunden, da das Embeddingmodell jedes Mal auf eine GPU geladen werden muss.")
    # Database (language) selector is disabled; only the German collection is used.
    #database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Autoren", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Wie kann ich gesund leben?; Wie kann ich mich besser konzentrieren?; Was ist der Sinn des Lebens?; ...")
    btn = gr.Button("Suchen")
    # Holds the latest list of (query, hits) pairs; no initial value is given,
    # so it is None until the first search completes.
    results = gr.State()

    #def update_authors(database):
    #    return gr.update(choices=authors_list_de if database == "German" else authors_list_en)

    #database_inp.change(
    #    fn=lambda database: update_authors(database),
    #    inputs=[database_inp],
    #    outputs=[author_inp]
    #)

    def perform_query(queries, authors, database="German"):
        """Embed the semicolon-separated queries and search the collection.

        Args:
            queries: Raw textbox content; individual questions are separated
                by ``;``.
            authors: Author names selected in the dropdown (may be empty).
            database: Currently unused. It defaults to ``"German"`` because
                the click handler only supplies ``queries`` and ``authors``
                while the database selector is disabled; without a default
                every click fails with a missing-argument TypeError.

        Returns:
            A list of ``(query, hits)`` tuples, one per non-blank query, where
            ``hits`` is the list returned by ``query_chroma``.
        """
        task = "Given a question, retrieve passages that answer the question"
        # Drop blank segments (e.g. from a trailing ';' or an empty textbox)
        # so they are neither embedded nor shown as empty accordions.
        stripped = (part.strip() for part in (queries or "").split(";"))
        queries = [part for part in stripped if part]
        if not queries:
            # Nothing to search for: skip the embedding call entirely.
            return []

        embeddings = get_embeddings(queries, task)
        #collection = collection_de if database == "German" else collection_en
        collection = collection_de
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(queries, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results],
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one collapsed accordion per query with its hits inside."""
        # `data` is None until a search has stored results in the State.
        for query, res in data or []:
            with gr.Accordion(query, open=False):
                for result in res:
                    with gr.Column():
                        author = result.get('author', 'Unknown author')
                        book = result.get('book', 'Unknown book')
                        text = result.get('text')
                        markdown_contents = f"**{author}, {book}**\n\n{text}"
                        gr.Markdown(value=markdown_contents, elem_classes="custom-markdown")

demo.launch()