File size: 4,347 Bytes
2e98c79
 
 
 
3bbb5f4
2e98c79
331b253
08029e5
 
 
 
331b253
10f043b
7d6132f
4295f9e
7d6132f
 
2ec7158
2e98c79
2a5653d
331b253
 
 
 
 
 
2e98c79
331b253
 
 
 
2e8a9c7
331b253
 
 
 
 
 
 
 
 
 
 
 
2e98c79
331b253
2e98c79
92ed022
08029e5
dde2538
08029e5
 
 
92ed022
 
dde2538
 
331b253
dde2538
 
 
 
 
7d6132f
 
 
dde2538
7d6132f
dde2538
08029e5
2a5653d
 
 
 
 
2e98c79
 
7d6132f
dde2538
2a5653d
2e98c79
7d6132f
2a5653d
 
 
92ed022
331b253
 
 
 
 
 
 
2e98c79
7d6132f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces

# Open the local on-disk Chroma store under ./chroma and fetch the German
# collection that every search in this app is run against.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
# English collection, currently disabled:
#collection_en = client.get_collection(name="phil_en")

# Choices for the "Autoren" filter dropdown; the selected names are used to
# filter hits on the "author" metadata field.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
# English author list, currently disabled:
#authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]

@spaces.GPU
def get_embeddings(queries, task):
    """Encode a batch of queries using the instruction-prefixed prompt format.

    Each query is turned into a two-line prompt: an ``Instruct: <task>`` line
    followed by a ``Query: <query>`` line. The Linq-Embed-Mistral model is
    instantiated on every call, inside this ``spaces.GPU``-decorated function.
    The value of the ``HF_TOKEN`` environment variable (None if unset) is
    passed as the auth token.

    Args:
        queries: Iterable of query strings.
        task: Instruction text describing the retrieval task.

    Returns:
        The result of ``SentenceTransformer.encode`` on the prompt list --
        presumably one embedding row per query (callers iterate over it).
    """
    hf_token = os.getenv("HF_TOKEN")
    encoder = SentenceTransformer("Linq-AI-Research/Linq-Embed-Mistral", use_auth_token=hf_token)
    instructed_prompts = [f"Instruct: {task}\nQuery: {text}" for text in queries]
    return encoder.encode(instructed_prompts)

def query_chroma(collection, embedding, authors, n_results=10):
    """Search ``collection`` with a single embedding and return the hits as flat dicts.

    Args:
        collection: Chroma collection (or any object with a compatible ``query`` method).
        embedding: One query vector, e.g. a NumPy row from ``SentenceTransformer.encode``
            or a plain sequence of floats.
        authors: Author names to restrict the search to, matched against the
            ``author`` metadata field; empty or None searches every author.
        n_results: Maximum number of hits to request (default 10).

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``, ``title``,
        ``text`` and ``distance``, in the order returned by ``collection.query``.
        Missing metadata fields are filled with placeholder strings.
    """
    # Chroma wants plain Python floats: NumPy arrays expose tolist(), other
    # sequences are copied into a list.
    vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
    # None is Chroma's "no filter" default. An empty dict is not a valid where
    # clause and can be rejected by Chroma's where-validation.
    where = {"author": {"$in": list(authors)}} if authors else None

    results = collection.query(
        query_embeddings=[vector],
        n_results=n_results,
        where=where,
        include=["documents", "metadatas", "distances"],
    )

    def first_row(key):
        # Each field holds one inner list per query vector, and exactly one was sent.
        # A missing, None or empty field yields an empty row.
        return (results.get(key) or [[]])[0] or []

    formatted_results = []
    for id_, metadata, document_text, distance in zip(
        first_row("ids"), first_row("metadatas"), first_row("documents"), first_row("distances")
    ):
        # Metadata may be None for records stored without any metadata.
        metadata = metadata or {}
        formatted_results.append({
            "id": id_,
            "author": metadata.get('author', 'Unknown author'),
            "book": metadata.get('book', 'Unknown book'),
            "section": metadata.get('section', 'Unknown section'),
            "title": metadata.get('title', 'Untitled'),
            "text": document_text,
            "distance": distance,
        })

    return formatted_results

with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    gr.Markdown("Geben Sie ein, wonach Sie suchen möchten (Query), filtern Sie nach Autoren (ohne Auswahl werden alle durchsucht) und klicken Sie auf **Suchen**, um zu suchen. Trennen Sie mehrere Fragen durch Semikola; die Suche dauert, unabhängig von der Anzahl der Abfragen, etwa 40 Sekunden, da das Embeddingmodell jedes Mal auf eine GPU geladen werden muss.")
    #database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Autoren", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", placeholder="Wie kann ich gesund leben?; Wie kann ich mich besser konzentrieren?; Was ist der Sinn des Lebens?; ...")
    btn = gr.Button("Suchen")
    # Holds the list of (query, hits) pairs from the last search; None before the first one.
    results = gr.State()

    # Disabled database switch (German/English); re-enable together with collection_en.
    #def update_authors(database):
    #    return gr.update(choices=authors_list_de if database == "German" else authors_list_en)

    #database_inp.change(
    #    fn=lambda database: update_authors(database),
    #    inputs=[database_inp],
    #    outputs=[author_inp]
    #)

    def perform_query(queries, authors, database="German"):
        """Embed each ';'-separated query and search the collection for it.

        Args:
            queries: Raw textbox contents, with individual questions separated by ';'.
            authors: Selected author names (empty means all authors).
            database: Database to search. It is currently ignored because only the
                German collection is enabled. It has a default because the click
                handler is wired to only two inputs; without the default, every
                click failed with a missing-argument TypeError.

        Returns:
            A list of (query, hits) tuples, one per non-blank query, where hits
            is the list returned by query_chroma.
        """
        task = "Given a question, retrieve passages that answer the question"
        # Drop blank segments (e.g. from a trailing ';') so no empty query is embedded.
        queries = [query.strip() for query in (queries or "").split(';') if query.strip()]
        if not queries:
            # Nothing to search: skip loading the model on the GPU.
            return []
        embeddings = get_embeddings(queries, task)
        #collection = collection_de if database == "German" else collection_en
        collection = collection_de
        return [
            (query, query_chroma(collection, embedding, authors))
            for query, embedding in zip(queries, embeddings)
        ]

    btn.click(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results]
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one collapsed accordion per query, listing its hits as bordered Markdown cards."""
        # The render function can run before any search has happened, e.g. on page
        # load, while the State still holds None.
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False):
                for result in res:
                    with gr.Column():
                        author = result.get('author', 'Unknown author')
                        book = result.get('book', 'Unknown book')
                        # Avoid rendering the literal string "None" for a missing document text.
                        text = result.get('text') or ''
                        markdown_contents = f"**{author}, {book}**\n\n{text}"
                        gr.Markdown(value=markdown_contents, elem_classes="custom-markdown")

demo.launch()