File size: 3,930 Bytes
2e98c79
 
 
 
 
 
3bbb5f4
2e98c79
10f043b
2ec7158
4295f9e
2ec7158
 
 
2e98c79
 
ff486ce
2e98c79
 
 
 
 
 
 
 
 
104f6dd
 
2e98c79
104f6dd
2e98c79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ec7158
2e98c79
 
 
2ec7158
 
2e98c79
2ec7158
 
 
def34f2
2e98c79
 
2ec7158
def34f2
2e98c79
 
def34f2
2e98c79
def34f2
2e98c79
 
 
 
 
2d7464e
2e98c79
 
 
 
 
 
 
 
 
 
 
 
 
def34f2
2e98c79
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import gradio as gr
import chromadb

from sentence_transformers import SentenceTransformer

import spaces

@spaces.GPU
def get_embeddings(query, task):
    """Encode *query* with Linq-Embed-Mistral, prefixed by an instruction.

    The text sent to the encoder has the form
    ``"Instruct: <task>\\nQuery: <query>"``. The Hugging Face token is read from
    the ``HF_TOKEN`` environment variable.

    Args:
        query: The user's search text.
        task: One-sentence instruction describing the retrieval task.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for a one-element
        list — presumably a 2-D numpy array with a single row (callers index
        ``[0]`` and call ``.tolist()``); confirm if the encoder changes.
    """
    # NOTE(review): the model is constructed on every call, which is costly for
    # a 7B-parameter encoder. Whether caching helps depends on how
    # ``spaces.GPU`` runs this function — confirm before optimizing.
    encoder = SentenceTransformer(
        "Linq-AI-Research/Linq-Embed-Mistral",
        use_auth_token=os.getenv("HF_TOKEN"),
    )
    instructed_text = f"Instruct: {task}\nQuery: {query}"
    return encoder.encode([instructed_text])

# Open the on-disk Chroma store in ./chroma and look up the existing
# "phil_de" collection that query_chroma searches.
client = chromadb.PersistentClient(path="./chroma")
collection = client.get_collection(name="phil_de")

# Choices for the author filter dropdown. Each entry is matched against the
# "author" metadata field of the stored documents, so spellings must agree
# with the data in the collection.
authors_list = [
    "Ludwig Wittgenstein",
    "Sigmund Freud",
    "Marcus Aurelius",
    "Friedrich Nietzsche",
    "Epiktet",
    "Ernst Jünger",
    "Georg Christoph Lichtenberg",
    "Balthasar Gracian",
    "Hannah Arendt",
    "Erich Fromm",
    "Albert Camus",
]

def query_chroma(embeddings, authors, num_results=10):
    """Search the module-level Chroma ``collection`` for the nearest passages.

    Args:
        embeddings: Output of ``get_embeddings``; only the first row is used.
            Assumed to be a numpy array (``.tolist()`` is called on the row)
            — TODO confirm if the encoder changes.
        authors: Optional list of author names. When non-empty, hits are
            restricted to documents whose ``author`` metadata is in the list;
            when empty or ``None``, no filter is applied.
        num_results: Maximum number of hits. Coerced to ``int`` because the
            Gradio number input may deliver a float and Chroma expects an
            integer ``n_results``; ``None`` falls back to 10.

    Returns:
        A list of dicts with keys ``id``, ``author``, ``book``, ``section``,
        ``title``, ``text`` and ``distance``, in the order Chroma returned
        them. If anything raises, a dict ``{"error": <message>}`` is returned
        instead (callers check for this shape).
    """
    try:
        # None is Chroma's "no filter" value; an empty dict is rejected as an
        # invalid where-clause by some Chroma versions.
        where_filter = {"author": {"$in": authors}} if authors else None
        n_results = 10 if num_results is None else int(num_results)

        query_vector = embeddings[0].tolist()

        results = collection.query(
            query_embeddings=[query_vector],
            n_results=n_results,
            where=where_filter,
            include=["documents", "metadatas", "distances"],
        )

        # Chroma returns one inner list per query embedding; we sent one.
        ids = results.get('ids', [[]])[0]
        metadatas = results.get('metadatas', [[]])[0]
        documents = results.get('documents', [[]])[0]
        distances = results.get('distances', [[]])[0]

        formatted_results = []
        for id_, metadata, document_text, distance in zip(ids, metadatas, documents, distances):
            # Documents stored without metadata come back as None.
            metadata = metadata or {}
            formatted_results.append({
                "id": id_,
                "author": metadata.get('author', 'Unknown author'),
                "book": metadata.get('book', 'Unknown book'),
                "section": metadata.get('section', 'Unknown section'),
                "title": metadata.get('title', 'Untitled'),
                "text": document_text,
                "distance": distance,
            })

        return formatted_results
    except Exception as e:
        # Deliberate best-effort boundary: the UI renders this message.
        return {"error": str(e)}


# Main function
def perform_query(query, authors, num_results):
    """Embed *query*, search Chroma, and build one update per result slot.

    Args:
        query: Free-text query from the search box.
        authors: Author names selected in the dropdown; forwarded to
            ``query_chroma`` (empty means no author filter).
        num_results: Requested number of hits; forwarded to ``query_chroma``.

    Returns:
        Exactly ``max_textboxes`` ``gr.update`` objects, one per Markdown
        output component. Slots holding a result are made visible with its
        content; the remaining slots are hidden. If the search failed, the
        first slot shows the error and all others are hidden.
    """
    task = "Given a question, retrieve passages that answer the question"
    embeddings = get_embeddings(query, task)
    results = query_chroma(embeddings, authors, num_results)

    # query_chroma signals failure by returning {"error": ...} instead of a list.
    # The number of updates must equal the number of output components.
    if isinstance(results, dict) and "error" in results:
        error_update = gr.update(visible=True, value=f"Error: {results['error']}")
        return [error_update] + [gr.update(visible=False)] * (max_textboxes - 1)

    # Never emit more updates than there are output components.
    shown = results[:max_textboxes]
    updates = [
        gr.update(visible=True, value=f"**{res['author']}, {res['book']}**\n\n{res['text']}")
        for res in shown
    ]
    updates += [gr.update(visible=False)] * (max_textboxes - len(shown))

    return updates

# Gradio interface
# Number of pre-built, initially hidden Markdown result slots. perform_query
# must return exactly this many updates; it also caps the results input.
max_textboxes = 30

with gr.Blocks(css=".custom-markdown { border: 1px solid #ccc; padding: 10px; border-radius: 5px; }") as demo:
    # NOTE(review): these instructions mention a **Flag** button, but no
    # flagging control is defined in this layout — confirm intent.
    gr.Markdown("Enter your query, filter authors (default is all), click **Search** to search. Click **Flag** if a result is relevant to the query and interesting to you.")
    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="query", placeholder="Enter thought...")
            author_inp = gr.Dropdown(label="authors", choices=authors_list, multiselect=True)
            # precision=0 makes Gradio round the value and pass an int, which
            # is forwarded to Chroma as n_results.
            num_results_inp = gr.Number(
                label="number of results",
                value=10,
                step=1,
                minimum=1,
                maximum=max_textboxes,
                precision=0,
            )
            btn = gr.Button("Search")

    # One hidden Markdown slot per possible result; perform_query toggles
    # visibility and fills in the content.
    components = []
    for _ in range(max_textboxes):
        with gr.Column():
            components.append(gr.Markdown(visible=False, elem_classes="custom-markdown"))

    btn.click(
        fn=perform_query,
        inputs=[inp, author_inp, num_results_inp],
        outputs=components,
    )


demo.launch()