File size: 4,280 Bytes
14d74dd
 
 
 
afb895f
14d74dd
19fe157
851e1ef
2c88b04
851e1ef
 
0b8ee40
14d74dd
 
04bac47
e46996a
 
 
 
14d74dd
c67d420
14d74dd
 
c67d420
14d74dd
c67d420
851e1ef
14d74dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c67d420
14d74dd
e46996a
 
 
14d74dd
 
851e1ef
 
 
 
 
 
c67d420
851e1ef
 
 
c67d420
851e1ef
 
c67d420
 
 
 
 
14d74dd
 
851e1ef
 
c67d420
14d74dd
851e1ef
c67d420
 
 
 
 
 
 
 
 
 
 
 
 
 
e46996a
851e1ef
 
 
e46996a
14d74dd
851e1ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces

@spaces.GPU
def get_embeddings(queries, task):
    """Encode each query, prefixed with the task instruction, into an embedding.

    Args:
        queries: Iterable of query strings.
        task: Instruction text placed in front of every query.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for the prompts, in
        the same order as ``queries`` (presumably a NumPy array with one row
        per query -- callers in this file call ``.tolist()`` on each row).
    """
    # The model is instantiated on every call, inside the GPU-decorated function.
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )
    instructed = []
    for query in queries:
        instructed.append(f"Instruct: {task}\nQuery: {query}")
    return encoder.encode(instructed)

# Open the persistent Chroma store under ./chroma and fetch one collection
# per supported language.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
collection_en = client.get_collection(name="phil_en")

# Author names offered in the "Authors" dropdown for each database; the
# selected names are matched against the "author" metadata field when querying.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
]

def query_chroma(collection, embedding, authors):
    """Return the ten nearest passages to ``embedding`` found in ``collection``.

    Args:
        collection: Object exposing a Chroma-style ``query`` method.
        embedding: Query vector; either an object with ``tolist()`` (such as
            a NumPy array) or a plain sequence of numbers.
        authors: Author names to restrict the search to. An empty list or
            ``None`` disables the author filter.

    Returns:
        A list with one dict per hit, holding the keys ``id``, ``author``,
        ``book``, ``section``, ``title``, ``text`` and ``distance``. If the
        query fails, the list contains a single ``{"error": message}`` dict
        instead; this function never raises.
    """
    try:
        # Pass None rather than an empty dict when no author is selected:
        # Chroma's filter validation rejects an empty ``where`` in some releases.
        where_filter = {"author": {"$in": authors}} if authors else None
        vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)

        results = collection.query(
            query_embeddings=[vector],
            n_results=10,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        def first_batch(key):
            # Only one query vector is sent, so only the first batch matters.
            batches = results.get(key)
            return batches[0] if batches else []

        formatted_results = []
        for id_, metadata, document_text, distance in zip(
            first_batch("ids"),
            first_batch("metadatas"),
            first_batch("documents"),
            first_batch("distances"),
        ):
            # A record stored without metadata comes back as None.
            metadata = metadata or {}
            formatted_results.append(
                {
                    "id": id_,
                    "author": metadata.get("author", "Unknown author"),
                    "book": metadata.get("book", "Unknown book"),
                    "section": metadata.get("section", "Unknown section"),
                    "title": metadata.get("title", "Untitled"),
                    "text": document_text,
                    "distance": distance,
                }
            )

        return formatted_results
    except Exception as e:
        # Deliberate catch-all at the UI boundary: the caller renders the
        # error message instead of crashing the request.
        return [{"error": str(e)}]

def update_authors(database):
    """Build the Authors-dropdown update for the selected database.

    Args:
        database: Database name chosen in the UI; ``"German"`` selects the
            German author list, any other value the English one.

    Returns:
        A ``gr.update`` that swaps in the matching author choices and clears
        the current selection.
    """
    choices = authors_list_de if database == "German" else authors_list_en
    # Clear the selection as well: authors picked for the previous database
    # may not exist in the new one and would otherwise stay in the filter.
    return gr.update(choices=choices, value=[])



with gr.Blocks() as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search.")
    database_inp = gr.Dropdown(label="Database", choices=["English", "German"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Enter questions separated by semicolons...")
    btn = gr.Button("Search")
    # Latest search output as a list of (query, hits) pairs. It starts as an
    # empty list so the first render has something iterable.
    results = gr.State([])

    def perform_query(queries, authors, database):
        """Run one search per semicolon-separated query.

        Args:
            queries: Raw textbox content; individual queries are separated
                by ``;``. Surrounding whitespace and empty entries are dropped.
            authors: Author names selected in the dropdown (may be empty).
            database: ``"German"`` to search the German collection, anything
                else for the English one.

        Returns:
            A list of ``(query, hits)`` tuples, where ``hits`` is the list
            returned by ``query_chroma``. Empty when no query was entered.
        """
        task = "Given a question, retrieve passages that answer the question"
        stripped = (part.strip() for part in (queries or "").split(";"))
        cleaned = [part for part in stripped if part]
        if not cleaned:
            # Nothing to search for; skip the embedding call entirely.
            return []
        embeddings = get_embeddings(cleaned, task)
        collection = collection_de if database == "German" else collection_en
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(cleaned, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp, database_inp],
        outputs=[results]
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one accordion per query containing its formatted hits.

        Args:
            data: List of ``(query, hits)`` tuples as produced by
                ``perform_query``; may be empty or ``None`` before the first
                search.
        """
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query):
                if not res:
                    markdown_contents = "No results found."
                elif "error" in res[0]:
                    markdown_contents = f"Error retrieving data: {res[0]['error']}"
                else:
                    # A blank line between hits keeps each bold heading in
                    # its own Markdown paragraph.
                    markdown_contents = "\n\n".join(
                        f"**{r['author']}, {r['book']}**\n\n{r['text']}" for r in res
                    )
                gr.Markdown(markdown_contents)

    # Refresh the author choices whenever another database is selected.
    database_inp.change(
        fn=update_authors,
        inputs=[database_inp],
        outputs=[author_inp]
    )

demo.launch()