Spaces:
Running
Running
File size: 4,349 Bytes
2e98c79 3bbb5f4 2e98c79 331b253 dde2538 331b253 dde2538 331b253 10f043b 7d6132f 4295f9e 7d6132f 2ec7158 2e98c79 2a5653d 331b253 2e98c79 331b253 2e8a9c7 331b253 2e98c79 331b253 2e98c79 92ed022 4699e8e dde2538 7d6132f 92ed022 dde2538 331b253 dde2538 7d6132f dde2538 7d6132f dde2538 2a5653d 2e98c79 7d6132f dde2538 2a5653d 2e98c79 7d6132f 2a5653d 92ed022 331b253 2e98c79 7d6132f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces
# Open the on-disk Chroma store that ships next to the app.
client = chromadb.PersistentClient(path="./chroma")

# Only the English collection is active. The German collection ("phil_de")
# and its author list are kept below, commented out, for re-enabling later.
# collection_de = client.get_collection(name="phil_de")
collection_en = client.get_collection(name="phil_en")

# authors_list_de = [
#     "Ludwig Wittgenstein", "Sigmund Freud", "Marcus Aurelius",
#     "Friedrich Nietzsche", "Epiktet", "Ernst Jünger",
#     "Georg Christoph Lichtenberg", "Balthasar Gracian", "Hannah Arendt",
#     "Erich Fromm", "Albert Camus",
# ]

# Author names offered as choices in the "Authors" filter dropdown.
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
]
@spaces.GPU
def get_embeddings(queries, task):
    """Encode queries with the Linq-Embed-Mistral sentence-transformer.

    Each query is wrapped in an instruction prompt of the form
    ``Instruct: <task>`` / ``Query: <query>`` (two lines) before encoding.

    The model is instantiated on every call; nothing is cached here. The
    ``@spaces.GPU`` decorator presumably attaches a ZeroGPU device for the
    duration of the call (TODO confirm against the ``spaces`` package docs).

    Args:
        queries: Iterable of query strings.
        task: Instruction text describing the retrieval task.

    Returns:
        The value of ``SentenceTransformer.encode`` for the prompt list. By
        default, this is a numpy array with one embedding row per query.
    """
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )
    instructed = []
    for query in queries:
        instructed.append(f"Instruct: {task}\nQuery: {query}")
    return encoder.encode(instructed)
def query_chroma(collection, embedding, authors, n_results=10):
    """Run a nearest-neighbour query against a Chroma collection.

    Args:
        collection: Chroma collection to search. Anything exposing a
            compatible ``query(**kwargs)`` method works.
        embedding: Query vector exposing ``tolist()``, e.g. a numpy array as
            returned by ``SentenceTransformer.encode``.
        authors: Optional collection of author names. When non-empty, only
            documents whose ``author`` metadata is one of them are returned.
        n_results: Maximum number of hits to request (default 10).

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``, in the order Chroma returned
        them. Missing metadata fields fall back to placeholder strings.
    """
    # ``None`` means "no filter". An empty ``{}`` is rejected by the where-
    # validation of newer Chroma releases.
    where = {"author": {"$in": list(authors)}} if authors else None
    results = collection.query(
        query_embeddings=[embedding.tolist()],
        n_results=n_results,
        where=where,
        include=["documents", "metadatas", "distances"],
    )

    def _first_row(key):
        # Chroma nests results per query embedding. We sent exactly one, so
        # take row 0. Tolerate a missing key, a None value or an empty list.
        return (results.get(key) or [[]])[0]

    formatted_results = []
    for id_, metadata, document_text, distance in zip(
        _first_row("ids"),
        _first_row("metadatas"),
        _first_row("documents"),
        _first_row("distances"),
    ):
        # Chroma returns None for documents stored without metadata.
        metadata = metadata or {}
        formatted_results.append(
            {
                "id": id_,
                "author": metadata.get("author", "Unknown author"),
                "book": metadata.get("book", "Unknown book"),
                "section": metadata.get("section", "Unknown section"),
                "title": metadata.get("title", "Untitled"),
                "text": document_text,
                "distance": distance,
            }
        )
    return formatted_results
with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Delimit multiple queries with semicola; since there is a quota for each user (based on IP) it makes sense to query in batches. The search takes around 40 seconds, regardless of the number of queries, because the embedding model needs to be loaded to a GPU each time.")
    # database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_en, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="How can I live a healthy life?; How can I improve my ability to focus?; What is the meaning of life?; ...")
    btn = gr.Button("Search")
    # Holds the list of (query, results) tuples produced by perform_query.
    results = gr.State()

    # Database switching is disabled while only the English collection exists:
    # def update_authors(database):
    #     return gr.update(choices=authors_list_de if database == "German" else authors_list_en)
    # database_inp.change(
    #     fn=lambda database: update_authors(database),
    #     inputs=[database_inp],
    #     outputs=[author_inp]
    # )

    def perform_query(queries, authors, database="English"):
        """Embed each ';'-separated query and fetch matching passages.

        Args:
            queries: Raw textbox contents. Individual queries are separated
                by ';'.
            authors: Author names selected in the dropdown (empty/None means
                all authors).
            database: Database to search. It is currently ignored because only
                the English collection is wired up. It defaults to "English"
                because the click handler passes only two inputs; without a
                default, every click raised TypeError.

        Returns:
            A list of ``(query, results)`` tuples, one per non-empty query.
            ``results`` is the list returned by ``query_chroma``.
        """
        task = "Given a question, retrieve passages that answer the question"
        # Drop blank segments (e.g. from a trailing ';') so no GPU quota is
        # spent embedding empty queries.
        queries = [query.strip() for query in (queries or "").split(';') if query.strip()]
        if not queries:
            return []
        embeddings = get_embeddings(queries, task)
        # collection = collection_de if database == "German" else collection_en
        collection = collection_en
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(queries, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results]
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one collapsed accordion per query, listing its passages."""
        # The State starts as None and gr.render also fires on page load, so
        # there may be nothing to show yet.
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False):
                for result in res:
                    with gr.Column():
                        author = result.get('author', 'Unknown author')
                        book = result.get('book', 'Unknown book')
                        text = result.get('text') or ""
                        markdown_contents = f"**{author}, {book}**\n\n{text}"
                        gr.Markdown(value=markdown_contents, elem_classes="custom-markdown")

demo.launch()