File size: 6,988 Bytes
50e3a95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
from typing import Tuple, Optional, List
from openai import OpenAI
from utils.db_utils import DatabaseUtils
from utils.embedding_utils import parallel_generate_embeddings

def create_embeddings_tab(openai_client: OpenAI, db_utils: DatabaseUtils, databases: List[str]) -> Tuple[gr.Tab, dict]:
    """Create the embeddings generation tab UI.

    The tab lets the user pick a database, collection and source field, then
    generates embeddings for documents that have the source field but not yet
    the embedding field, storing them in the chosen embedding field.

    Args:
        openai_client: OpenAI client instance, passed through to
            ``parallel_generate_embeddings``.
        db_utils: DatabaseUtils instance used to list collections/fields and
            to reach the underlying client (``db_utils.client``).
        databases: List of available databases shown in the first dropdown.

    Returns:
        Tuple[gr.Tab, dict]: The tab component and a dict mapping element
        names to its interface components.
    """
    def update_collections(db_name: Optional[str]) -> gr.Dropdown:
        """Refresh the collection choices when the selected database changes.

        Pre-selects the collection when the database contains exactly one.
        """
        if not db_name:
            # Database selection was cleared: nothing to list.
            return gr.Dropdown(choices=[], value=None)
        collections = db_utils.get_collections(db_name)
        # If there's only one collection, select it by default
        value = collections[0] if len(collections) == 1 else None
        return gr.Dropdown(choices=collections, value=value)

    def update_fields(db_name: Optional[str], collection_name: Optional[str]) -> gr.Dropdown:
        """Refresh the field choices when the selected collection changes.

        The current selection is cleared because a field picked for the
        previous collection is not necessarily present in the new one.
        """
        if db_name and collection_name:
            fields = db_utils.get_field_names(db_name, collection_name)
            return gr.Dropdown(choices=fields, value=None)
        return gr.Dropdown(choices=[], value=None)

    def generate_embeddings(
        db_name: Optional[str],
        collection_name: Optional[str],
        field_name: Optional[str],
        embedding_field: Optional[str],
        limit: Optional[float] = 10,
        progress=gr.Progress()
    ) -> Tuple[str, str]:
        """Generate embeddings for documents with progress tracking.

        Only documents that have ``field_name`` and lack ``embedding_field``
        are processed.

        Args:
            db_name: Database to read from.
            collection_name: Collection to process.
            field_name: Field whose content is embedded.
            embedding_field: Field the embeddings are written to.
            limit: Maximum number of documents to process; 0 (or less) means
                all. Gradio's Number component may deliver this as a float.
            progress: Gradio progress tracker injected by the framework.

        Returns:
            Tuple[str, str]: (result or instructions message, last progress
            line). Errors are reported in the message rather than raised so
            they appear in the UI.
        """
        # Validate the selections up front so the user gets a clear message
        # instead of a low-level driver error (e.g. ``client[None]``).
        if not db_name or not collection_name or not field_name:
            return "Please select a database, collection and field", ""
        if not embedding_field:
            return "Please enter an embedding field name", ""
        if embedding_field == field_name:
            # The query dict below would collapse to a single key, selecting
            # documents that LACK the source field, and the embeddings would
            # overwrite the source text.
            return "Embedding field must differ from the field to embed", ""
        if limit is None:
            return "Please enter a document limit (0 for all documents)", ""
        # pymongo's Cursor.limit() only accepts ints; gr.Number may return a float.
        limit = int(limit)

        try:
            db = db_utils.client[db_name]
            collection = db[collection_name]

            # Count documents that have the source field at all
            total_docs = collection.count_documents({field_name: {"$exists": True}})
            if total_docs == 0:
                return f"No documents found with field '{field_name}'", ""

            # Get total count of documents that need processing
            query = {
                field_name: {"$exists": True},
                embedding_field: {"$exists": False}  # Only get docs without embeddings
            }
            total_to_process = collection.count_documents(query)
            if total_to_process == 0:
                return "No documents found that need embeddings", ""

            # Apply limit if specified (0 or negative means "all")
            if limit > 0:
                total_to_process = min(total_to_process, limit)

            print(f"\nFound {total_to_process} documents that need embeddings...")

            # Progress tracking: keep the most recent line so it can be
            # returned to the Progress textbox once processing finishes.
            progress_text = ""
            def update_progress(prog: float, processed: int, total: int):
                # ``prog`` is formatted as a percentage and scaled to 0-1 for
                # Gradio's progress bar.
                nonlocal progress_text
                progress_text = f"Progress: {prog:.1f}% ({processed}/{total} documents)\n"
                print(progress_text)  # Terminal logging
                progress(prog/100, f"Processed {processed}/{total} documents")

            # Show initial progress
            update_progress(0, 0, total_to_process)

            # Create cursor for batch processing
            cursor = collection.find(query)
            if limit > 0:
                cursor = cursor.limit(limit)

            # Generate embeddings in parallel with cursor-based batching
            processed = parallel_generate_embeddings(
                collection=collection,
                cursor=cursor,
                field_name=field_name,
                embedding_field=embedding_field,
                openai_client=openai_client,
                total_docs=total_to_process,
                callback=update_progress
            )

            # Return completion message and final progress
            instructions = f"""
Successfully generated embeddings for {processed} documents using parallel processing!

To create the vector search index in MongoDB Atlas:
1. Go to your Atlas cluster
2. Click on 'Search' tab
3. Create an index named 'vector_index' with this configuration:
{{
  "fields": [
    {{
      "type": "vector",
      "path": "{embedding_field}",
      "numDimensions": 1536,
      "similarity": "dotProduct"
    }}
  ]
}}

You can now use the search tab with:
- Field to search: {field_name}
- Embedding field: {embedding_field}
"""
            return instructions, progress_text

        except Exception as e:
            # UI boundary: surface the error in the results box, but keep the
            # full traceback in the terminal for debugging.
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}", ""

    # Create the tab UI
    with gr.Tab("Generate Embeddings") as tab:
        with gr.Row():
            db_input = gr.Dropdown(
                choices=databases,
                label="Select Database",
                info="Available databases in Atlas cluster"
            )
            collection_input = gr.Dropdown(
                choices=[],
                label="Select Collection",
                info="Collections in selected database"
            )
        with gr.Row():
            field_input = gr.Dropdown(
                choices=[],
                label="Select Field for Embeddings",
                info="Fields available in collection"
            )
            embedding_field_input = gr.Textbox(
                label="Embedding Field Name",
                value="embedding",
                info="Field name where embeddings will be stored"
            )
            limit_input = gr.Number(
                label="Document Limit",
                value=10,
                minimum=0,
                precision=0,  # deliver an int; the limit is a document count
                info="Number of documents to process (0 for all documents)"
            )

        generate_btn = gr.Button("Generate Embeddings")
        generate_output = gr.Textbox(label="Results", lines=10)
        progress_output = gr.Textbox(label="Progress", lines=3)

        # Set up event handlers: database -> collections -> fields cascade
        db_input.change(
            fn=update_collections,
            inputs=[db_input],
            outputs=[collection_input]
        )

        collection_input.change(
            fn=update_fields,
            inputs=[db_input, collection_input],
            outputs=[field_input]
        )

        generate_btn.click(
            fn=generate_embeddings,
            inputs=[
                db_input,
                collection_input,
                field_input,
                embedding_field_input,
                limit_input
            ],
            outputs=[generate_output, progress_output]
        )

    # Return the tab and its interface elements
    interface = {
        'db_input': db_input,
        'collection_input': collection_input,
        'field_input': field_input,
        'embedding_field_input': embedding_field_input,
        'limit_input': limit_input,
        'generate_btn': generate_btn,
        'generate_output': generate_output,
        'progress_output': progress_output
    }

    return tab, interface