File size: 6,760 Bytes
14d74dd
 
 
 
afb895f
14d74dd
800ef3d
8f4bc97
 
 
 
800ef3d
19fe157
851e1ef
2c88b04
851e1ef
 
0b8ee40
14d74dd
c67d420
800ef3d
 
f156dfd
800ef3d
 
 
14d74dd
800ef3d
 
 
 
e46996a
800ef3d
 
 
 
27ae4d9
 
 
 
800ef3d
 
 
 
14d74dd
800ef3d
14d74dd
5ab1db6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f41789b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ab1db6
 
f41789b
5ab1db6
 
 
 
 
 
 
 
f41789b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ab1db6
 
 
27ae4d9
da3c7f4
8f4bc97
27ae4d9
8f4bc97
f41789b
c897d84
 
da3c7f4
 
800ef3d
da3c7f4
 
 
 
 
851e1ef
b17764a
27ae4d9
da3c7f4
851e1ef
8f4bc97
c67d420
 
 
 
b17764a
14d74dd
 
f41789b
b17764a
 
 
 
851e1ef
da3c7f4
b17764a
14d74dd
851e1ef
c67d420
 
 
f41789b
800ef3d
 
27ae4d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
800ef3d
14d74dd
f41789b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import gradio as gr
import chromadb
from sentence_transformers import SentenceTransformer
import spaces

# Open the on-disk Chroma store in ./chroma and fetch the German collection
# that every search in this app runs against.
client = chromadb.PersistentClient(path="./chroma")
collection_de = client.get_collection(name="phil_de")
#collection_en = client.get_collection(name="phil_en")

# Author names offered in the filter dropdown; they are matched against the
# "author" metadata field when querying the collection.
authors_list_de = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]
#authors_list_en = ["Friedrich Nietzsche", "Joscha Bach"]

@spaces.GPU
def get_embeddings(queries, task):
    """Embed each query with the Linq-Embed-Mistral sentence-transformer.

    Every query is wrapped in the instruction template
    ``"Instruct: <task>\\nQuery: <query>"`` before encoding. The model is
    instantiated inside this call, authenticating with the ``HF_TOKEN``
    environment variable.

    Args:
        queries: Iterable of query strings.
        task: Instruction text placed in front of every query.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for the list of
        prompts (one embedding per query, in input order).
    """
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )

    instructed = []
    for query in queries:
        instructed.append(f"Instruct: {task}\nQuery: {query}")

    return encoder.encode(instructed)

def query_chroma(collection, embedding, authors, n_results=20):
    """Query a Chroma-style collection and flatten the hits into dicts.

    Args:
        collection: Object exposing ``query(query_embeddings=..., n_results=...,
            include=..., where=...)`` and returning a mapping with the keys
            ``ids``, ``metadatas``, ``documents`` and ``distances``.
        embedding: Query vector. Either an object with ``tolist()`` (e.g. a
            NumPy array) or a plain sequence of numbers.
        authors: Author names to restrict the search to. When empty or
            ``None`` no metadata filter is sent at all.
        n_results: Maximum number of hits to request (default 20).

    Returns:
        A list of dicts, in the order the collection returned the hits, each
        with the keys ``id``, ``author``, ``book``, ``section``, ``title``,
        ``text`` and ``distance``. Missing metadata fields become ``''``.
    """
    vector = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)

    query_kwargs = {
        "query_embeddings": [vector],
        "n_results": n_results,
        "include": ["documents", "metadatas", "distances"],
    }
    # Only send a filter when there is something to filter on: an empty
    # ``where={}`` is not a valid filter expression and may be rejected.
    if authors:
        query_kwargs["where"] = {"author": {"$in": authors}}

    results = collection.query(**query_kwargs)

    def first_batch(key):
        # Results are nested one list per query embedding; we sent exactly one.
        # A missing key, ``None`` or an empty outer list all mean "no hits".
        batches = results.get(key)
        return batches[0] if batches else []

    ids = first_batch('ids')
    metadatas = first_batch('metadatas')
    documents = first_batch('documents')
    distances = first_batch('distances')

    formatted_results = []
    for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
        # A hit stored without metadata may come back as None.
        metadata = metadata or {}
        formatted_results.append({
            "id": id_,
            "author": metadata.get('author', ''),
            "book": metadata.get('book', ''),
            "section": metadata.get('section', ''),
            "title": metadata.get('title', ''),
            "text": document_text,
            "distance": distance,
        })

    return formatted_results


# Colour, border and control overrides applied on top of the Soft base theme.
_theme_overrides = {
    "body_text_color": "*neutral_800",
    "block_background_fill": "*neutral_50",
    "block_border_width": "0px",
    "button_primary_background_fill": "*primary_600",
    "button_primary_background_fill_hover": "*primary_700",
    "button_primary_text_color": "white",
    "input_background_fill": "white",
    "input_border_color": "*neutral_200",
    "input_border_width": "1px",
    "checkbox_background_color_selected": "*primary_600",
    "checkbox_border_color_selected": "*primary_600",
}

# Soft theme with indigo/slate hues and large spacing, radius and text sizes,
# passed to gr.Blocks below.
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="slate",
    neutral_hue="slate",
    spacing_size="lg",
    radius_size="lg",
    text_size="lg",
    font=["Helvetica", "sans-serif"],
    font_mono=["Courier", "monospace"],
).set(**_theme_overrides)

# Extra CSS handed to gr.Blocks(css=...). It strips padding/margins/borders
# and focus outlines from the gradio-app wrapper, forces it to full width,
# styles the ".custom-markdown" result cards, and tightens spacing (including
# the ".accordion" class used on result accordions) on screens <= 768px wide.
custom_css = """
/* Remove outer padding, margins, and borders */
gradio-app,
gradio-app > div,
gradio-app .gradio-container {
    padding: 0 !important;
    margin: 0 !important;
    border: none !important;
}

/* Remove any potential outlines */
gradio-app:focus,
gradio-app > div:focus,
gradio-app .gradio-container:focus {
    outline: none !important;
}

/* Ensure full width */
gradio-app {
    width: 100% !important;
    display: block !important;
}

.custom-markdown { 
    border: 1px solid var(--neutral-200); 
    padding: 10px; 
    border-radius: var(--radius-lg);
    background-color: var(--color-background-primary);
    margin-bottom: 15px;
}
.custom-markdown p {
    margin-bottom: 10px;
    line-height: 1.6;
}

@media (max-width: 768px) {
    gradio-app, 
    gradio-app > div,
    gradio-app .gradio-container {
        padding-left: 1px !important;
        padding-right: 1px !important;
    }
    .custom-markdown {
        padding: 5px;
    }
    .accordion {
        margin-left: -10px;
        margin-right: -10px;
    }
}
"""

def _split_queries(raw_queries):
    """Split semicolon-separated input into stripped, non-empty queries."""
    return [part.strip() for part in (raw_queries or "").split(';') if part.strip()]


def _result_markdown(result):
    """Build the Markdown for one hit: bold "author, book, section, title" header, then text."""
    header_parts = []
    for key in ('author', 'book', 'section', 'title'):
        value = result.get(key, '')
        if value is None:
            continue
        value = str(value)
        # "Unknown" is treated as a placeholder and left out of the header.
        if value and value != "Unknown":
            header_parts.append(value)

    text = str(result.get('text', ''))
    if not header_parts:
        # Without this guard an empty header would render as a literal "****".
        return text
    header = ", ".join(header_parts)
    return f"**{header}**\n\n{text}"


with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.Markdown("Geben Sie ein, wonach Sie suchen möchten (Query), trennen Sie mehrere Suchanfragen durch Semikola; filtern Sie nach Autoren (ohne Auswahl werden alle durchsucht) und klicken Sie auf **Suchen**, um zu suchen.")
    #database_inp = gr.Dropdown(label="Database", choices=["German", "English"], value="German")
    author_inp = gr.Dropdown(label="Autoren", choices=authors_list_de, multiselect=True)
    inp = gr.Textbox(label="Query", lines=3, placeholder="Wie kann ich gesund leben?; Wie kann ich mich besser konzentrieren?; Was ist der Sinn des Lebens?; ...")
    btn = gr.Button("Suchen")
    loading_indicator = gr.Markdown(visible=False, elem_id="loading-indicator")
    # Holds the latest search results as a list of (query, hits) tuples;
    # it is None until the first search has run.
    results = gr.State()

    #def update_authors(database):
    #    return gr.update(choices=authors_list_de if database == "German" else authors_list_en)

    #database_inp.change(
    #    fn=lambda database: update_authors(database),
    #    inputs=[database_inp],
    #    outputs=[author_inp]
    #)

    def perform_query(queries, authors):
        """Embed every query in the text box and search the German collection.

        Args:
            queries: Raw text box content; individual queries are separated
                by semicolons.
            authors: Author names selected in the dropdown (may be empty).

        Returns:
            A ``(results, indicator_text)`` pair: ``results`` is a list of
            ``(query, hits)`` tuples and ``indicator_text`` is the new value
            for the loading indicator (always the empty string).
        """
        task = "Suche den zur Frage passenden Text"
        # Drop blank fragments (empty box, trailing ';') so they are neither
        # embedded nor shown as unlabeled accordions.
        query_list = _split_queries(queries)
        if not query_list:
            return [], ""

        embeddings = get_embeddings(query_list, task)
        results_data = [
            (query, query_chroma(collection_de, embedding, authors))
            for query, embedding in zip(query_list, embeddings)
        ]
        return results_data, ""

    btn.click(
        # Reveal the (empty) loading indicator immediately, then run the search.
        fn=lambda: gr.update(value="", visible=True),
        inputs=None,
        outputs=[loading_indicator],
        queue=False
    ).then(
        perform_query,
        inputs=[inp, author_inp],
        outputs=[results, loading_indicator]
    )

    @gr.render(inputs=[results])
    def display_accordion(data):
        """Render one collapsed accordion per query with its hits as Markdown cards."""
        # The state is None before the first search; render nothing then
        # instead of failing on iteration.
        if not data:
            return
        for query, res in data:
            with gr.Accordion(query, open=False, elem_classes="accordion"):
                for result in res:
                    with gr.Column():
                        gr.Markdown(value=_result_markdown(result), elem_classes="custom-markdown")

# Start serving the interface built above; inline=False is passed straight
# through to Gradio's launch().
demo.launch(inline=False)