File size: 4,301 Bytes
14d74dd
 
 
 
afb895f
14d74dd
800ef3d
 
 
 
 
 
19fe157
851e1ef
2c88b04
851e1ef
 
0b8ee40
14d74dd
c67d420
800ef3d
 
 
 
 
 
14d74dd
800ef3d
 
 
 
e46996a
800ef3d
 
 
 
 
 
 
 
 
 
 
 
14d74dd
800ef3d
14d74dd
c897d84
d525093
800ef3d
851e1ef
d525093
851e1ef
c897d84
 
800ef3d
 
 
c897d84
 
 
 
 
851e1ef
 
 
c67d420
851e1ef
 
c67d420
 
 
 
 
14d74dd
 
851e1ef
 
c67d420
14d74dd
851e1ef
c67d420
 
 
c897d84
800ef3d
 
 
 
 
 
 
14d74dd
851e1ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces

# Vector store persisted on local disk under ./chroma. The collections are
# opened with get_collection (not get_or_create_collection), so they are
# expected to exist already.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")  # German corpus
collection_en = client.get_collection(name="phil_en")  # English corpus

# Authors offered in the "Authors" filter dropdown for each database.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
authors_list_en = [
    "Friedrich Nietzsche",
    "Joscha Bach",
]

@spaces.GPU
def get_embeddings(queries, task):
    """Embed each query with the Linq-Embed-Mistral sentence-transformer.

    Every query is wrapped in an instruction prompt of the form
    ``Instruct: <task>`` followed by a newline and ``Query: <query>``
    before being encoded.

    The model is constructed anew on every call inside this ``spaces.GPU``
    function. Per the app's own UI text, this load dominates search latency.

    Args:
        queries: Iterable of query strings.
        task: Instruction text inserted after ``Instruct:``.

    Returns:
        The result of ``SentenceTransformer.encode`` on the prompts. The
        caller consumes it as one embedding per query, in input order.
    """
    hf_token = os.getenv("HF_TOKEN")
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=hf_token,
    )
    instructed_prompts = []
    for text in queries:
        instructed_prompts.append(f"Instruct: {task}\nQuery: {text}")
    return encoder.encode(instructed_prompts)

def query_chroma(collection, embedding, authors, n_results=10):
    """Run a single-vector nearest-neighbour query against a Chroma collection.

    Args:
        collection: A Chroma collection (anything exposing a compatible
            ``query(**kwargs)`` method).
        embedding: One query embedding: either an object with ``tolist()``
            (e.g. a NumPy row) or a plain sequence of floats.
        authors: Optional iterable of author names. When non-empty, only
            documents whose ``author`` metadata is one of them are returned.
        n_results: Maximum number of hits to request (default 10).

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``, in the order Chroma returned
        them. Missing metadata fields (or a hit with no metadata at all)
        fall back to placeholder strings.
    """
    vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
    # ``None`` is Chroma's documented "no filter" default. An empty dict may
    # be rejected by its where-clause validation. ``$in`` expects a list.
    where = {"author": {"$in": list(authors)}} if authors else None
    results = collection.query(
        query_embeddings=[vector],
        n_results=n_results,
        where=where,
        include=["documents", "metadatas", "distances"],
    )

    # query() returns one inner list per query embedding; exactly one was sent.
    ids = results.get('ids', [[]])[0]
    metadatas = results.get('metadatas', [[]])[0]
    documents = results.get('documents', [[]])[0]
    distances = results.get('distances', [[]])[0]

    formatted_results = []
    for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
        # Chroma reports records stored without metadata as None.
        meta = metadata or {}
        formatted_results.append({
            "id": id_,
            "author": meta.get('author', 'Unknown author'),
            "book": meta.get('book', 'Unknown book'),
            "section": meta.get('section', 'Unknown section'),
            "title": meta.get('title', 'Untitled'),
            "text": document_text,
            "distance": distance,
        })
    return formatted_results

with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Delimit multiple queries with semicola; since there is a quota for each user (based on IP) it makes sense to query in batches. The search takes around 50 seconds because the embedding model needs to be loaded to a GPU each time (another reason to do multiple queries at once).")
    database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Authors", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Wie kann ich gesund leben und bedeutet Gesundheit für jeden das gleiche?; Was ist der Sinn des Lebens?; ...")
    btn = gr.Button("Search")
    # Holds the list of (query, hits) pairs written by perform_query and read
    # by display_accordion. It holds None until the first search completes.
    results = gr.State()

    def update_authors(database):
        """Swap the author choices to match the selected database.

        The current selection is cleared, because authors picked for one
        database are not necessarily in the other database's list.
        """
        return gr.update(
            choices=authors_list_de if database == "German" else authors_list_en,
            value=[],
        )

    database_inp.change(
        fn=update_authors,
        inputs=[database_inp],
        outputs=[author_inp],
    )

    def perform_query(queries, authors, database):
        """Embed each ';'-separated query and search the chosen collection.

        Args:
            queries: Raw textbox contents; multiple queries separated by ';'.
            authors: Selected author names (empty/None means no filter).
            database: "German" selects the German collection; anything else
                selects the English one.

        Returns:
            A list of ``(query, hits)`` tuples, where ``hits`` is the list
            returned by ``query_chroma``. The list is empty when the input
            contains no non-blank query.
        """
        task = "Given a question, retrieve passages that answer the question"
        # Trim whitespace and drop blank segments (e.g. from a trailing ';')
        # so empty strings are never embedded or searched.
        queries = [q.strip() for q in (queries or "").split(';') if q.strip()]
        if not queries:
            # Nothing to search for; skip the expensive GPU embedding call.
            return []
        embeddings = get_embeddings(queries, task)
        collection = collection_de if database == "German" else collection_en
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(queries, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp, database_inp],
        outputs=[results]
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one collapsed accordion per query, listing its hits."""
        # The State is None until a search has populated it.
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False):
                for result in res:
                    with gr.Column():
                        author = result.get('author', 'Unknown author')
                        book = result.get('book', 'Unknown book')
                        text = result.get('text')
                        markdown_contents = f"**{author}, {book}**\n\n{text}"
                        gr.Markdown(value=markdown_contents, elem_classes="custom-markdown")

# Start the Gradio server for the UI built in the ``gr.Blocks`` context.
# This is an unguarded module-level call, so it also runs if the module is
# imported rather than executed as a script.
demo.launch()