Spaces:
Sleeping
Sleeping
File size: 4,280 Bytes
14d74dd afb895f 14d74dd 19fe157 851e1ef 2c88b04 851e1ef 0b8ee40 14d74dd 04bac47 e46996a 14d74dd c67d420 14d74dd c67d420 14d74dd c67d420 851e1ef 14d74dd c67d420 14d74dd e46996a 14d74dd 851e1ef c67d420 851e1ef c67d420 851e1ef c67d420 14d74dd 851e1ef c67d420 14d74dd 851e1ef c67d420 e46996a 851e1ef e46996a 14d74dd 851e1ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces
@spaces.GPU
def get_embeddings(queries, task):
    """Embed each query with the Linq-Embed-Mistral sentence-transformer.

    Every query is wrapped in the instruction template
    ``"Instruct: {task}\\nQuery: {query}"`` before encoding.

    Args:
        queries: Iterable of query strings.
        task: Instruction text placed in front of every query.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for the list of
        prompts (presumably one embedding per query, in input order —
        confirm against the sentence-transformers version in use).
    """
    # The model is instantiated on every call; the Hugging Face token is
    # read from the HF_TOKEN environment variable.
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )

    instructed = []
    for text in queries:
        instructed.append(f"Instruct: {task}\nQuery: {text}")

    return encoder.encode(instructed)
# Open the persistent Chroma store under ./chroma and fetch the German and
# English collections by name.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
collection_en = client.get_collection(name="phil_en")

# Author names offered in the "Authors" filter for each database.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
]
def query_chroma(collection, embedding, authors):
    """Run a nearest-neighbour query against a Chroma-style collection.

    Args:
        collection: Object exposing a ``query(...)`` method that accepts
            ``query_embeddings``, ``n_results``, ``where`` and ``include``.
        embedding: Query vector, either an array-like with ``tolist()``
            (e.g. a NumPy array) or a plain sequence of floats.
        authors: Author names to restrict the search to. A falsy value
            (``None`` or an empty list) disables the author filter.

    Returns:
        A list of at most 10 dicts with the keys ``id``, ``author``,
        ``book``, ``section``, ``title``, ``text`` and ``distance``.
        If anything goes wrong, a one-element list
        ``[{"error": <message>}]`` is returned instead of raising, so the
        UI can display the failure for this query.
    """
    try:
        # Pass None rather than an empty dict when there is nothing to filter
        # on: Chroma's where-validation expects exactly one operator, so an
        # empty dict can be rejected depending on the installed version.
        where_filter = {"author": {"$in": authors}} if authors else None

        # Accept both NumPy-style arrays and plain sequences.
        if hasattr(embedding, "tolist"):
            vector = embedding.tolist()
        else:
            vector = list(embedding)

        results = collection.query(
            query_embeddings=[vector],
            n_results=10,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        def first_batch(key):
            # Only one embedding is queried, so only the first result batch
            # matters. A missing key or a None value yields an empty batch.
            batches = results.get(key) or [[]]
            return batches[0] or []

        formatted_results = []
        for id_, metadata, document_text, distance in zip(
            first_batch("ids"),
            first_batch("metadatas"),
            first_batch("documents"),
            first_batch("distances"),
        ):
            # A document stored without metadata may come back as None;
            # fall back to the "Unknown ..." defaults instead of failing.
            metadata = metadata or {}
            formatted_results.append(
                {
                    "id": id_,
                    "author": metadata.get("author", "Unknown author"),
                    "book": metadata.get("book", "Unknown book"),
                    "section": metadata.get("section", "Unknown section"),
                    "title": metadata.get("title", "Untitled"),
                    "text": document_text,
                    "distance": distance,
                }
            )
        return formatted_results
    except Exception as e:
        # Deliberate best-effort boundary: report the failure as data so one
        # bad query does not abort the remaining ones.
        return [{"error": str(e)}]
def update_authors(database):
    """Return a Dropdown update with the author choices for *database*.

    Args:
        database: The selected database label; ``"German"`` selects the
            German author list, any other value the English one.

    Returns:
        A ``gr.update(...)`` payload that replaces the dropdown's choices
        and clears its current selection.
    """
    choices = authors_list_de if database == "German" else authors_list_en
    # Clear the selection too: authors picked for the previous database would
    # otherwise stay selected and keep filtering the other collection.
    return gr.update(choices=choices, value=[])
with gr.Blocks() as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search.")
    database_inp = gr.Dropdown(label="Database", choices=["English", "German"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
    btn = gr.Button("Search")
    # Holds the latest search results as a list of (query, results) pairs.
    # It starts as an empty list so the render function below always receives
    # an iterable, even before the first search.
    results = gr.State([])

    def perform_query(queries, authors, database):
        """Embed the semicolon-separated queries and search the chosen database.

        Args:
            queries: Raw textbox content; individual questions are separated
                by semicolons.
            authors: Selected author names (may be empty or None).
            database: ``"German"`` selects the German collection, any other
                value the English one.

        Returns:
            A list of ``(query, results)`` tuples, one per non-blank query,
            where ``results`` is the output of ``query_chroma``.
        """
        task = "Given a question, retrieve passages that answer the question"
        # Strip whitespace and drop blank fragments (empty input, trailing or
        # doubled semicolons) so no empty prompt is embedded and searched.
        queries = [q.strip() for q in (queries or "").split(";") if q.strip()]
        if not queries:
            return []
        embeddings = get_embeddings(queries, task)
        collection = collection_de if database == "German" else collection_en
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(queries, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp, database_inp],
        outputs=[results],
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one accordion per query containing its formatted results."""
        # Nothing to show before the first search or for an all-blank query.
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query):
                if not res:
                    markdown_contents = "No results found."
                elif "error" in res[0]:
                    markdown_contents = f"Error retrieving data: {res[0]['error']}"
                else:
                    # A blank line between entries keeps each bold heading in
                    # its own Markdown paragraph.
                    markdown_contents = "\n\n".join(
                        f"**{r['author']}, {r['book']}**\n\n{r['text']}" for r in res
                    )
                gr.Markdown(markdown_contents)

    database_inp.change(
        fn=update_authors,
        inputs=[database_inp],
        outputs=[author_inp],
    )
# Start the Gradio app.
demo.launch()