File size: 4,130 Bytes
dd2978a
35db041
 
dd2978a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a4ebff
35db041
 
7a4ebff
 
dd2978a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from pathlib import Path
from typing import Any, Dict, List

import gradio as gr
from huggingface_hub import snapshot_download
from ragatouille import RAGPretrainedModel
from toolz import unique

# Top-level variables
# Local directory the prebuilt ColBERT index is downloaded into and loaded from.
INDEX_PATH = (
    Path(".ragatouille") / "colbert" / "indexes" / "my_index_with_ids_and_metadata"
)
# Hugging Face Hub *dataset* repository that hosts the index files.
REPO_ID = "davanstrien/search-index"

# Query issued once at start-up to warm up the freshly loaded index.
INITIAL_QUERY = "hello world"
# Default number of results (search default and the slider's initial value).
DEFAULT_K = 10


def initialize_index():
    """Fetch the prebuilt index from the Hub and return a loaded model.

    The ``REPO_ID`` dataset snapshot is downloaded into ``INDEX_PATH``
    (the directory is created first if needed), loaded with
    ``RAGPretrainedModel.from_index``, and then queried once with
    ``INITIAL_QUERY`` to warm the index up before real traffic arrives.

    Returns:
        The ``RAGPretrainedModel`` built from the downloaded index.
    """
    target_dir = INDEX_PATH
    target_dir.mkdir(exist_ok=True, parents=True)
    snapshot_download(REPO_ID, local_dir=target_dir, repo_type="dataset")
    model = RAGPretrainedModel.from_index(target_dir)
    model.search(INITIAL_QUERY)  # warm-up only; the hits are discarded
    return model


def format_results_as_markdown(
    results: List[Dict[str, Any]], max_preview_chars: int = 1000
) -> str:
    """Render search hits as a Markdown document.

    Each hit becomes a section with its rank, score, a link to the dataset
    page on the Hub, the passage id, and a content preview. Content longer
    than ``max_preview_chars`` is truncated (suffixed with ``...``) and the
    full text is appended inside a collapsible ``<details>`` element.
    Sections are separated by a horizontal rule.

    Args:
        results: Hit dicts; each must provide the keys ``content``,
            ``score``, ``rank``, ``document_id`` and ``passage_id``.
        max_preview_chars: Maximum number of content characters shown
            before truncating (expected to be non-negative).

    Returns:
        The Markdown string; empty if ``results`` is empty.
    """
    # Collect fragments and join once instead of repeated ``str +=``,
    # which copies the growing string on every append.
    parts: List[str] = []
    for result in results:
        content = result["content"]
        score = result["score"]
        rank = result["rank"]
        document_id = result["document_id"]
        passage_id = result["passage_id"]
        link = f"https://huggingface.co/datasets/{document_id}"

        parts.append(f"### Result {rank}\n")
        parts.append(f"**Score:** {score}\n\n")
        parts.append(f"**Document ID:** [{document_id}]({link})\n\n")
        parts.append(f"**Passage ID:** {passage_id}\n\n")

        # Decide truncation once and reuse it for both preview and details.
        is_truncated = len(content) > max_preview_chars
        preview = f"{content[:max_preview_chars]}..." if is_truncated else content
        parts.append(f"{preview}\n\n")

        if is_truncated:
            parts.append("<details>\n")
            parts.append("<summary>Click to expand full content</summary>\n\n")
            parts.append(f"{content}\n\n")
            parts.append("</details>\n\n")

        parts.append("---\n\n")

    return "".join(parts)


def search_with_ragatouille(query, k=DEFAULT_K, make_unique=False):
    """Search the global ``RAG`` index and return the hits as Markdown.

    Args:
        query: Free-text query typed by the user.
        k: Number of passages to retrieve. It is coerced to ``int`` because
            a UI slider value can arrive as a float.
        make_unique: If true, keep only the first hit per ``document_id``
            (via ``make_results_unique``). De-duplication happens after
            retrieval, so fewer than ``k`` results may be shown.

    Returns:
        Markdown produced by ``format_results_as_markdown``, or a short
        prompt message when ``query`` is empty or only whitespace.
    """
    if not query or not query.strip():
        # Don't send a blank query to the model; ask the user for one instead.
        return "Please enter a search query."
    results = RAG.search(query, k=int(k))
    if make_unique:
        results = make_results_unique(results)
    return format_results_as_markdown(results)


def make_results_unique(results: List[Dict[str, Any]]):
    """Collapse repeated hits from the same dataset.

    Hits are compared by ``result["document_id"]``; the first hit seen for
    each id is kept and the incoming order is preserved (``toolz.unique``
    semantics).

    Returns:
        A new list holding the retained hit dicts.
    """

    def by_document(result):
        return result["document_id"]

    return list(unique(results, key=by_document))


def create_ragatouille_interface():
    """Build the Gradio Blocks UI for searching dataset cards.

    The layout is a title, an explanatory notes panel, a query textbox, a
    results-count slider (1-100, initial value ``DEFAULT_K``), a "show each
    dataset only once" checkbox and a Search button. Clicking the button or
    pressing Enter in the query box calls ``search_with_ragatouille`` and
    renders its Markdown output in the results panel below the button.

    Returns:
        The ``gr.Blocks`` app (this function does not launch it).
    """
    with gr.Blocks() as ragatouille_demo:
        gr.Markdown("### RAGatouille Dataset Search")
        # Blank lines keep these as separate Markdown blocks; without them the
        # text following the bullet list is parsed as a lazy continuation of
        # the last bullet item.
        gr.Markdown(
            """This interface allows you to search inside dataset cards on the Hub using the [answerai-colbert-small-v1](https://huggingface.co/answerdotai/answerai-colbert-small-v1) ColBERT model via [RAGatouille](https://github.com/AnswerDotAI/RAGatouille). Please be aware that this is an early prototype and may not work as expected!

            ## Notes:

            **Not all datasets are indexed yet!**

            For a dataset to be indexed:
            - It must have a dataset card on the Hub. You can find documentation on how to write a good dataset card [here](https://huggingface.co/docs/hub/datasets-cards).
            - The dataset must have at least 1 like and 1 download
            - The card must be a minimum length (to weed out low quality cards)

            **At the moment the index is refreshed when I decide to do it, so it may not be up to date.** If there is sufficient interest I will implement a daily refresh (give this repo a like if you'd like this feature!)

            Feel free to open a discussion to give feedback or request features &#129303;
            """
        )
        with gr.Column():
            query = gr.Textbox(label="Search query", placeholder="medieval handwriting")
        with gr.Row():
            k = gr.Slider(1, 100, value=DEFAULT_K, step=1, label="Number of Results")
            make_unique = gr.Checkbox(False, label="Show each dataset only once?")
        search_button = gr.Button("Search")
        # Created right after the button (same position as before) and named
        # so that both triggers below write into the same panel.
        results = gr.Markdown(label="Results")
        search_inputs = [query, k, make_unique]
        search_button.click(
            search_with_ragatouille,
            inputs=search_inputs,
            outputs=results,
        )
        # Pressing Enter in the query box runs the same search as the button.
        query.submit(
            search_with_ragatouille,
            inputs=search_inputs,
            outputs=results,
        )
    return ragatouille_demo


# Initialize RAG globally.
# NOTE: this runs at import time. initialize_index() downloads the index
# snapshot from the Hub and runs a warm-up search, so importing this module
# needs network access and can be slow. search_with_ragatouille() reads this
# module-level RAG object when handling a query.
RAG = initialize_index()


def main():
    """Build the RAGatouille search UI and start serving it with Gradio."""
    create_ragatouille_interface().launch()


if __name__ == "__main__":
    main()