File size: 34,242 Bytes
b1b5dc2
 
 
 
 
5460b58
2f63946
b1b5dc2
 
1d3ba67
 
 
b1b5dc2
 
 
 
 
 
 
 
387f245
 
 
b1b5dc2
 
 
 
 
 
1431c60
b1b5dc2
 
 
 
 
 
41c382c
887a63e
 
 
 
d491a19
 
b1b5dc2
 
 
60c1b0a
 
 
 
b1b5dc2
60c1b0a
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fbb97c
f2b9a34
 
2b75953
f2b9a34
 
66a3d8f
f2b9a34
60fbdb7
b1b5dc2
387f245
 
 
 
 
 
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c1b0a
887a63e
 
 
 
 
 
 
 
 
 
 
 
60c1b0a
 
887a63e
 
 
 
 
 
 
60c1b0a
 
 
 
 
 
3d53dd3
 
 
 
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d53dd3
 
 
 
 
 
887a63e
 
 
 
 
 
 
 
 
300b8a0
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1431c60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1b5dc2
1431c60
 
 
 
 
 
 
 
 
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
887a63e
 
 
 
 
 
 
 
 
 
 
 
 
 
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37c745
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37c745
b1b5dc2
 
 
 
 
 
 
 
 
1431c60
 
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37c745
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d37c745
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1c0285
 
 
 
 
 
b1b5dc2
 
 
 
 
 
 
 
 
 
 
e1c0285
b1b5dc2
 
 
 
 
 
 
 
 
e1c0285
b1b5dc2
 
 
 
 
 
 
 
e1c0285
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5460b58
b1b5dc2
 
 
 
 
 
 
d37c745
b1b5dc2
887a63e
b1b5dc2
 
 
5460b58
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
887a63e
 
 
 
b1b5dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60fbdb7
2282ec1
b1b5dc2
 
 
387f245
60c1b0a
7809ddd
60c1b0a
7809ddd
387f245
 
 
b76e10d
15ce78e
b76e10d
15ce78e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
import time
print(f"Starting up: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# Standard library imports
import os
from pathlib import Path
from datetime import datetime
from itertools import chain

import base64
import json

# Third-party imports
import numpy as np
import pandas as pd
import torch
import gradio as gr

print(f"Gradio version: {gr.__version__}")

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import matplotlib.pyplot as plt
import tqdm
import colormaps
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize

import random

import opinionated # for fonts
plt.style.use("opinionated_rc")

from sklearn.neighbors import NearestNeighbors


def is_running_in_hf_zero_gpu():
    """Log and return the raw ``SPACES_ZERO_GPU`` environment variable.

    Returns:
        The variable's string value, or ``None`` when it is not set. The value
        is also printed so it shows up in the log.
    """
    zero_gpu_flag = os.environ.get("SPACES_ZERO_GPU")
    print(zero_gpu_flag)
    return zero_gpu_flag


# Emit the flag once at import time.
is_running_in_hf_zero_gpu()

def is_running_in_hf_space():
    """Return True when the ``SPACE_ID`` environment variable is set.

    Only the presence of the variable matters; an empty value still counts.
    """
    return os.environ.get("SPACE_ID") is not None

if is_running_in_hf_space():
    import spaces # necessary to run on Zero.
    from spaces.zero.client import _get_token

#if is_running_in_hf_space():
#import spaces # necessary to run on Zero.
#print(f"Spaces version: {spaces.__version__}")

import datamapplot
import pyalex

# Local imports
from openalex_utils import (
    openalex_url_to_pyalex_query, 
    get_field,
    process_records_to_df,
    openalex_url_to_filename
)
from styles import DATAMAP_CUSTOM_CSS
from data_setup import (
    download_required_files,
    setup_basemap_data,
    setup_mapper,
    setup_embedding_model,
    
)

from network_utils import create_citation_graph, draw_citation_graph




# Configure OpenAlex
# Contact email handed to pyalex's global config.
# NOTE(review): the literal below looks like a scraped/obfuscated placeholder
# rather than a real address — verify the intended contact email.
pyalex.config.email = "[email protected]"

print(f"Imports completed: {time.strftime('%Y-%m-%d %H:%M:%S')}")



# Create a static directory to store the dynamic HTML files
# (predict() also writes its CSV/PNG/citation-graph outputs here).
static_dir = Path("./static")
static_dir.mkdir(parents=True, exist_ok=True)

# Tell Gradio which absolute paths are allowed to be served
os.environ["GRADIO_ALLOWED_PATHS"] = str(static_dir.resolve())
print("os.environ['GRADIO_ALLOWED_PATHS'] =", os.environ["GRADIO_ALLOWED_PATHS"])


# Create FastAPI app
app = FastAPI()

# Mount the static directory
# NOTE: directory="static" is relative to the working directory and matches
# Path("./static") above; it relies on the mkdir having already run.
app.mount("/static", StaticFiles(directory="static"), name="static")





# Resource configuration: file names of the precomputed base map and UMAP
# mapper parameters, and the embedding model identifier.
BASEMAP_PATH = "100k_filtered_OA_sample_cluster_and_positions_supervised.pkl"
MAPPER_PARAMS_PATH = "umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl"
MODEL_NAME = "m7n/discipline-tuned_specter_2_024"

# File name -> source URL for everything download_required_files must fetch.
# The keys reuse the path constants above so the two cannot drift apart.
REQUIRED_FILES = {
    BASEMAP_PATH:
        "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/100k_filtered_OA_sample_cluster_and_positions_supervised.pkl",
    MAPPER_PARAMS_PATH:
        "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl",
}

# Initialize models and data: download first, then load base map, mapper and
# embedding model (in that order), timing the whole step.
start_time = time.time()
print("Initializing resources...")

download_required_files(REQUIRED_FILES)
basedata_df = setup_basemap_data(BASEMAP_PATH)
mapper = setup_mapper(MAPPER_PARAMS_PATH)
model = setup_embedding_model(MODEL_NAME)

elapsed_init = time.time() - start_time
print(f"Resources initialized in {elapsed_init:.2f} seconds")



# Setting up decorators for embedding on HF-Zero:
def no_op_decorator(func):
    """Return a transparent wrapper around *func* (a local stand-in for ``spaces.GPU``).

    The wrapper forwards all positional and keyword arguments unchanged and
    returns ``func``'s result. ``functools.wraps`` copies ``__name__``,
    ``__doc__``, ``__module__`` etc. onto the wrapper and sets ``__wrapped__``,
    so the decorated function keeps its identity for logging and for
    ``inspect.signature``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Do nothing special
        return func(*args, **kwargs)
    return wrapper

# # Decide which decorator to use based on environment
# decorator_to_use = spaces.GPU() if is_running_in_hf_space() else no_op_decorator
# #duration=120


# Embedding entry points.
# On a HF Space, four otherwise-identical functions are defined, each wrapped
# in spaces.GPU with a different `duration`; predict() chooses one based on the
# number of texts and the caller's user type. Off-Spaces, a single undecorated
# create_embeddings() is defined instead. Note that each branch defines only
# its own names: create_embeddings does not exist on Spaces, and the
# create_embeddings_* variants do not exist locally.
if is_running_in_hf_space():
    @spaces.GPU(duration=30)
    def create_embeddings_30(texts_to_embedd):
        """Embed texts with the global `model` under spaces.GPU(duration=30)."""
        return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
    
    @spaces.GPU(duration=59)
    def create_embeddings_59(texts_to_embedd):
        """Embed texts with the global `model` under spaces.GPU(duration=59)."""
        return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
    
    @spaces.GPU(duration=120)
    def create_embeddings_120(texts_to_embedd):
        """Embed texts with the global `model` under spaces.GPU(duration=120)."""
        return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
    
    @spaces.GPU(duration=299)
    def create_embeddings_299(texts_to_embedd):
        """Embed texts with the global `model` under spaces.GPU(duration=299)."""
        return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
    

else:
    def create_embeddings(texts_to_embedd):
        """Embed texts with the global `model` (no GPU decorator; local runs)."""
        return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
    
    
    
    
    
    

def _error_outputs(message):
    """Build the five Gradio outputs predict() returns when it aborts early.

    The message is shown in the HTML panel; all download buttons and the
    cancel button are hidden.
    """
    return [
        message,  # iframe HTML
        gr.DownloadButton(label="Download Interactive Visualization", value='html_file_path', visible=False),  # html download
        gr.DownloadButton(label="Download CSV Data", value='csv_file_path', visible=False),  # csv download
        gr.DownloadButton(label="Download Static Plot", value='png_file_path', visible=False),  # png download
        gr.Button(visible=False)  # cancel button state
    ]


def _get_user_type(request):
    """Classify the caller of a HF Space request as "anonymous", "registered" or "pro".

    Reads the token via spaces' ``_get_token`` and decodes (without verifying)
    the middle JWT segment to inspect its ``user`` field. A missing or
    malformed token falls back to "anonymous", the most restrictive tier.
    """
    try:
        token = _get_token(request)
        payload = token.split('.')[1]
        # JWT segments are unpadded base64url; restore padding before decoding.
        payload += '=' * (-len(payload) % 4)
        claims = json.loads(base64.urlsafe_b64decode(payload).decode())
        user = claims['user']
    except (AttributeError, IndexError, KeyError, TypeError, ValueError) as err:
        print(f"Could not determine user from token ({err!r}); treating as anonymous.")
        return "anonymous"
    if user is None:
        return "anonymous"
    if '[pro]' in user:
        return "pro"
    return "registered"


def _local_mean_years(points, years, n_neighbors=10):
    """Average each point's publication year over its nearest neighbours.

    Args:
        points: 2-D coordinates, one row per record, used for the kNN search.
        years: publication years aligned positionally with ``points``.
        n_neighbors: neighbourhood size. Because the query points are the
            fitted points, each point is its own first neighbour. The value is
            clamped to the number of points so small result sets don't make
            sklearn raise.

    Returns:
        np.ndarray of per-point mean years (NaN years are skipped), one value
        per row of ``points``.
    """
    year_values = np.asarray(pd.to_numeric(years), dtype=float)
    k = max(1, min(n_neighbors, len(points)))
    nn = NearestNeighbors(n_neighbors=k)
    nn.fit(points)
    _, indices = nn.kneighbors(points)
    return np.nanmean(year_values[indices], axis=1)


def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox, 
           sample_reduction_method, plot_time_checkbox, 
           locally_approximate_publication_date_checkbox, 
           download_csv_checkbox, download_png_checkbox, citation_graph_checkbox, 
           progress=gr.Progress()):
    """
    Main prediction pipeline: fetch OpenAlex records, embed them, project them
    onto the base map and render the interactive (and optional static) plots.

    Args:
        request (gr.Request): Gradio request object. On HF Spaces it is used to
            read the caller's token and pick a GPU time budget.
        text_input (str): One or more OpenAlex search URLs, separated by ';'.
        sample_size_slider (int): Maximum number of records to keep.
        reduce_sample_checkbox (bool): Whether to reduce the sample size.
        sample_reduction_method (str): "All", "First n samples" or
            "n random samples".
        plot_time_checkbox (bool): Whether to color points by publication year.
        locally_approximate_publication_date_checkbox (bool): Use the mean
            publication year of each point's neighbours instead of its own
            year. Also adds an ``approximate_publication_year`` column to the
            CSV export.
        download_csv_checkbox (bool): Whether to write a CSV export.
        download_png_checkbox (bool): Whether to render a static PNG plot.
        citation_graph_checkbox (bool): Whether to draw the citation graph as a
            background layer.
        progress (gr.Progress): Gradio progress tracker.

    Returns:
        list: [iframe HTML (or an error message), html DownloadButton,
        csv DownloadButton, png DownloadButton, hidden cancel Button].
    """
    # Classify the caller (only meaningful on HF Spaces).
    if is_running_in_hf_space():
        user_type = _get_user_type(request)
        print(f"User type: {user_type}")
    else:
        user_type = None

    # Check if input is empty or whitespace
    print(f"Input: {text_input}")
    empty_input_message = "Error: Please enter a valid OpenAlex URL in the 'OpenAlex-search URL'-field"
    if not text_input or text_input.isspace():
        return _error_outputs(empty_input_message)

    # Multiple queries may be separated by ';'. Ignore empty segments, such as
    # the one produced by a trailing ';'.
    urls = [url.strip() for url in text_input.split(';') if url.strip()]
    if not urls:
        return _error_outputs(empty_input_message)

    # Gradio sliders may deliver floats; pandas slicing/sampling needs an int.
    sample_size_slider = int(sample_size_slider)
    first_n_mode = reduce_sample_checkbox and sample_reduction_method == "First n samples"

    start_time = time.time()
    print('Starting data projection pipeline')
    progress(0.1, desc="Starting...")

    # Use first URL for filename
    filename = openalex_url_to_filename(urls[0])
    print(f"Filename: {filename}")

    records = []
    total_query_length = 0
    should_break = False
    max_retries = 5
    base_wait_time = 1  # Starting wait time in seconds
    exponent = 1.5  # Exponential backoff factor

    # Process each URL
    for i, url in enumerate(urls):
        query, _ = openalex_url_to_pyalex_query(url)
        query_length = query.count()
        total_query_length += query_length
        print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')

        target_size = sample_size_slider if first_n_mode else query_length
        records_per_query = 0

        # NOTE(review): errors raised by paginate() itself while fetching the
        # next page surface from this `for` and are not covered by the retry below.
        for page in query.paginate(per_page=200, n_max=None):
            # Retry processing of the current page. Records are staged in
            # page_records and committed once per page, so a retried attempt
            # cannot append the same record twice.
            page_records = []
            for retry_attempt in range(max_retries):
                page_records = []
                try:
                    for record in page:
                        page_records.append(record)
                        fetched = len(records) + len(page_records)
                        progress(0.1 + (0.2 * fetched / max(total_query_length, 1)),
                                 desc=f"Getting data from query {i+1}/{len(urls)}...")

                        if first_n_mode and records_per_query + len(page_records) >= target_size:
                            should_break = True
                            break
                    # If we get here without an exception, break the retry loop
                    break
                except Exception as e:
                    print(f"Error processing page: {e}")
                    if retry_attempt < max_retries - 1:
                        wait_time = base_wait_time * (exponent ** retry_attempt) + random.random()
                        print(f"Retrying in {wait_time:.2f} seconds (attempt {retry_attempt + 1}/{max_retries})...")
                        time.sleep(wait_time)
                    else:
                        print(f"Maximum retries reached. Continuing with next page.")

            # Commit what the last attempt produced (complete on success,
            # partial if every retry failed).
            records.extend(page_records)
            records_per_query += len(page_records)

            if should_break:
                break
        if should_break:
            break
    print(f"Query completed in {time.time() - start_time:.2f} seconds")

    if not records:
        return _error_outputs("Error: The OpenAlex query returned no records. Please check the URL.")

    # Process records
    processing_start = time.time()
    records_df = process_records_to_df(records)

    if reduce_sample_checkbox and sample_reduction_method != "All":
        sample_size = min(sample_size_slider, len(records_df))
        if sample_reduction_method == "n random samples":
            records_df = records_df.sample(sample_size)
        elif sample_reduction_method == "First n samples":
            records_df = records_df.iloc[:sample_size].copy()
    print(f"Records processed in {time.time() - processing_start:.2f} seconds")

    if records_df.empty:
        return _error_outputs("Error: No usable records were found for this query.")

    # Create embeddings
    embedding_start = time.time()
    progress(0.3, desc="Embedding Data...")
    texts_to_embedd = [f"{title} {abstract}" for title, abstract
                       in zip(records_df['title'], records_df['abstract'])]

    if is_running_in_hf_space():
        # Pick the smallest GPU time budget that fits the workload; anonymous
        # users are capped at the 59-second variant.
        if len(texts_to_embedd) < 2000:
            embeddings = create_embeddings_30(texts_to_embedd)
        elif len(texts_to_embedd) < 4000 or user_type == "anonymous":
            embeddings = create_embeddings_59(texts_to_embedd)
        elif len(texts_to_embedd) < 8000:
            embeddings = create_embeddings_120(texts_to_embedd)
        else:
            embeddings = create_embeddings_299(texts_to_embedd)
    else:
        embeddings = create_embeddings(texts_to_embedd)

    print(f"Embeddings created in {time.time() - embedding_start:.2f} seconds")

    # Project embeddings
    projection_start = time.time()
    progress(0.5, desc="Project into UMAP-embedding...")
    umap_embeddings = mapper.transform(embeddings)
    records_df[['x','y']] = umap_embeddings
    print(f"Projection completed in {time.time() - projection_start:.2f} seconds")

    # Prepare visualization data
    viz_prep_start = time.time()
    progress(0.6, desc="Preparing visualization data...")

    basedata_df['color'] = '#ced4d211'

    # Computed lazily; also needed by the CSV export even when time plotting is off.
    local_years = None
    if not plot_time_checkbox:
        records_df['color'] = '#5e2784'
    else:
        cmap = colormaps.haline
        if locally_approximate_publication_date_checkbox:
            # Smooth the time signal: each point takes the mean year of its
            # nearest neighbours in the projected space.
            local_years = _local_mean_years(umap_embeddings, records_df['publication_year'])
            year_values = local_years
        else:
            year_values = pd.to_numeric(records_df['publication_year'])
        norm = mcolors.Normalize(vmin=year_values.min(), vmax=year_values.max())
        records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in year_values]

    stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled")
    stacked_df['parsed_field'] = [get_field(row) for ix, row in stacked_df.iterrows()]
    extra_data = pd.DataFrame(stacked_df['doi'])
    # Bounding box of the whole map. The citation-graph image and the static
    # PNG overlay both use it so they line up with the points.
    map_extent = [np.min(stacked_df['x']), np.max(stacked_df['x']),
                  np.min(stacked_df['y']), np.max(stacked_df['y'])]
    print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds")

    graph_file_name = None
    graph_file_path = None
    if citation_graph_checkbox:
        citation_graph_start = time.time()
        citation_graph = create_citation_graph(records_df)
        graph_file_name = f"{filename}_citation_graph.jpg"
        graph_file_path = static_dir / graph_file_name
        draw_citation_graph(citation_graph, path=graph_file_path, bundle_edges=True,
                            min_max_coordinates=map_extent)
        print(f"Citation graph created and saved in {time.time() - citation_graph_start:.2f} seconds")

    # Create and save plot
    plot_start = time.time()
    progress(0.7, desc="Creating interactive plot...")
    # Create a solid black colormap
    black_cmap = mcolors.LinearSegmentedColormap.from_list('black', ['#000000', '#000000'])

    plot = datamapplot.create_interactive_plot(
        stacked_df[['x','y']].values,
        np.array(stacked_df['cluster_2_labels']),
        np.array(['Unlabelled' if pd.isna(x) else x for x in stacked_df['parsed_field']]),
        hover_text=[str(title) for title in stacked_df['title']],
        marker_color_array=stacked_df['color'],
        use_medoids=True, # Switch back once efficient mediod caclulation comes out!
        width=1000,
        height=1000,
        point_radius_min_pixels=1,
        text_outline_width=5,
        point_hover_color='#5e2784',
        point_radius_max_pixels=7,
        cmap=black_cmap,
        background_image=graph_file_name,  # None unless a citation graph was drawn
        font_family="Roboto Condensed",
        font_weight=600,
        tooltip_font_weight=600,
        tooltip_font_family="Roboto Condensed",
        extra_point_data=extra_data,
        on_click="window.open(`{doi}`)",
        custom_css=DATAMAP_CUSTOM_CSS,
        initial_zoom_fraction=.8,
        enable_search=False,
        offline_mode=False
    )

    # Save plot
    html_file_name = f"{filename}.html"
    html_file_path = static_dir / html_file_name
    plot.save(html_file_path)
    print(f"Plot created and saved in {time.time() - plot_start:.2f} seconds")

    # Save additional files if requested
    csv_file_path = static_dir / f"{filename}.csv"
    png_file_path = static_dir / f"{filename}.png"

    if download_csv_checkbox:
        # Export relevant columns (copy so the added columns don't write into a view)
        export_df = records_df[['title', 'abstract', 'doi', 'publication_year', 'x', 'y','id','primary_topic']].copy()
        export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
        export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']]
        if locally_approximate_publication_date_checkbox:
            if local_years is None:
                # Not computed above because time plotting was off.
                local_years = _local_mean_years(umap_embeddings, records_df['publication_year'])
            export_df['approximate_publication_year'] = local_years
        export_df.to_csv(csv_file_path, index=False)

    if download_png_checkbox:
        png_start_time = time.time()
        print("Starting PNG generation...")

        # Sample and prepare data
        sample_prep_start = time.time()
        sample_to_plot = basedata_df  # full base map, no subsampling
        labels1 = np.array(sample_to_plot['cluster_2_labels'])
        labels2 = np.array(['Unlabelled' if pd.isna(x) else x for x in sample_to_plot['parsed_field']])

        # Randomly mix the two label sets: about 60% of points carry their
        # cluster label, the rest their parsed field.
        ratio = 0.6
        mask = np.random.random(size=len(labels1)) < ratio
        combined_labels = np.where(mask, labels1, labels2)

        # Keep only the 80 most common labels; everything else becomes 'Unlabelled'.
        unique_labels, counts = np.unique(combined_labels, return_counts=True)
        top_labels = set(unique_labels[np.argsort(counts)[-80:]])
        combined_labels = np.array(['Unlabelled' if label not in top_labels else label for label in combined_labels])
        colors_base = ['#536878' for _ in range(len(labels1))]
        print(f"Sample preparation completed in {time.time() - sample_prep_start:.2f} seconds")

        # Create main plot
        print(labels1)
        print(labels2)
        print(sample_to_plot[['x','y']].values)
        print(combined_labels)

        main_plot_start = time.time()
        fig, ax = datamapplot.create_plot(
            sample_to_plot[['x','y']].values,
            combined_labels,
            label_wrap_width=12,
            label_over_points=True,
            dynamic_label_size=True,
            use_medoids=True, # Switch back once efficient mediod caclulation comes out!
            point_size=2,
            marker_color_array=colors_base,
            force_matplotlib=True,
            max_font_size=12,
            min_font_size=4,
            min_font_weight=100,
            max_font_weight=300,
            font_family="Roboto Condensed",
            color_label_text=False, add_glow=False,
            highlight_labels=list(np.unique(labels1)),
            label_font_size=8,
            highlight_label_keywords={"fontsize": 12, "fontweight": "bold", "bbox":{"boxstyle":"circle", "pad":0.75,'alpha':0.}},
        )
        print(f"Main plot creation completed in {time.time() - main_plot_start:.2f} seconds")

        if citation_graph_checkbox:
            # Read and add the graph image over the same extent it was drawn for
            graph_img = plt.imread(graph_file_path)
            ax.imshow(graph_img, extent=map_extent, alpha=0.9, aspect='auto')

        # Smaller markers for larger result sets
        if len(records_df) > 50_000:
            point_size = .5
        elif len(records_df) > 10_000:
            point_size = 1
        else:
            point_size = 5

        # Time-based visualization
        scatter_start = time.time()
        if plot_time_checkbox:
            if locally_approximate_publication_date_checkbox:
                scatter = plt.scatter(
                    umap_embeddings[:,0],
                    umap_embeddings[:,1],
                    c=local_years,
                    cmap=colormaps.haline,
                    alpha=0.8,
                    s=point_size
                )
            else:
                years = pd.to_numeric(records_df['publication_year'])
                scatter = plt.scatter(
                    umap_embeddings[:,0],
                    umap_embeddings[:,1],
                    c=years,
                    cmap=colormaps.haline,
                    alpha=0.8,
                    s=point_size
                )
            plt.colorbar(scatter, shrink=0.5, format='%d')
        else:
            scatter = plt.scatter(
                umap_embeddings[:,0],
                umap_embeddings[:,1],
                c=records_df['color'],
                alpha=0.8,
                s=point_size
            )
        print(f"Scatter plot creation completed in {time.time() - scatter_start:.2f} seconds")

        # Save plot
        save_start = time.time()
        plt.axis('off')
        plt.savefig(png_file_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saving completed in {time.time() - save_start:.2f} seconds")

        print(f"Total PNG generation completed in {time.time() - png_start_time:.2f} seconds")

    progress(1.0, desc="Done!")
    print(f"Total pipeline completed in {time.time() - start_time:.2f} seconds")
    iframe = f"""<iframe src="{html_file_path}" width="100%" height="1000px"></iframe>"""

    # Return iframe and download buttons with appropriate visibility
    return [
        iframe,
        gr.DownloadButton(label="Download Interactive Visualization", value=html_file_path, visible=True, variant='secondary'),
        gr.DownloadButton(label="Download CSV Data", value=csv_file_path, visible=download_csv_checkbox, variant='secondary'),
        gr.DownloadButton(label="Download Static Plot", value=png_file_path, visible=download_png_checkbox, variant='secondary'),
        gr.Button(visible=False)  # Return hidden state for cancel button
    ]

predict.zerogpu = True



# Monochrome Gradio theme: Roboto Condensed first, with system sans-serif
# fallbacks, large text, and white, black-outlined secondary buttons.
_theme_fonts = [
    gr.themes.GoogleFont("Roboto Condensed"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
_secondary_button_overrides = dict(
    button_secondary_background_fill="white",
    button_secondary_background_fill_hover="#f3f4f6",
    button_secondary_border_color="black",
    button_secondary_text_color="black",
    button_border_width="2px",
)
theme = gr.themes.Monochrome(font=_theme_fonts, text_size="lg")
theme = theme.set(**_secondary_button_overrides)


# Gradio interface setup
with gr.Blocks(theme=theme, css="""
    .gradio-container a {
        color: black !important;
        text-decoration: none !important;  /* Force remove default underline */
        font-weight: bold;
        transition: color 0.2s ease-in-out, border-bottom-color 0.2s ease-in-out;
        display: inline-block;  /* Enable proper spacing for descenders */
        line-height: 1.1;  /* Adjust line height */
        padding-bottom: 2px;  /* Add space for descenders */
    }
    .gradio-container a:hover {
        color: #b23310 !important;
        border-bottom: 3px solid #b23310;  /* Wider underline, only on hover */
    }
""") as demo:
    gr.Markdown("""
    <div style="max-width: 100%; margin: 0 auto;">
    <br>
    
    # OpenAlex Mapper
    
    OpenAlex Mapper is a way of projecting search queries from the amazing OpenAlex database on a background map of randomly sampled papers from OpenAlex, which allows you to easily investigate interdisciplinary connections. OpenAlex Mapper was developed by [Maximilian Noichl](https://maxnoichl.eu) and [Andrea Loettgers](https://unige.academia.edu/AndreaLoettgers) at the [Possible Life project](http://www.possiblelife.eu/).

    To use OpenAlex Mapper, first head over to [OpenAlex](https://openalex.org/) and search for something that interests you. For example, you could search for all the papers that make use of the [Kuramoto model](https://openalex.org/works?page=1&filter=default.search%3A%22Kuramoto%20Model%22), for all the papers that were published by researchers at [Utrecht University in 2019](https://openalex.org/works?page=1&filter=authorships.institutions.lineage%3Ai193662353,publication_year%3A2019), or for all the papers that cite Wittgenstein's [Philosophical Investigations](https://openalex.org/works?page=1&filter=cites%3Aw4251395411). Then you copy the URL to that search query into the OpenAlex search URL box below and click "Run Query." It will download all of these records from OpenAlex and embed them on our interactive map. As the embedding step is a little expensive, computationally, it's often a good idea to play around with smaller samples, before running a larger analysis (see below for a note on sample size and gpu-limits). After a little time, that map will appear and be available for you to interact with and download. You can find more explanations in the FAQs below.
    </div>
    
    """)
    

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                run_btn = gr.Button("Run Query", variant='primary')
                cancel_btn = gr.Button("Cancel", visible=False, variant='secondary')
            
            # Create separate download buttons
            # Download buttons start hidden (visible=False). They are listed among
            # predict's outputs in the run wiring, presumably so predict can reveal
            # them once the corresponding files exist — confirm against predict.
            html_download = gr.DownloadButton("Download Interactive Visualization", visible=False, variant='secondary')
            csv_download = gr.DownloadButton("Download CSV Data", visible=False, variant='secondary')
            png_download = gr.DownloadButton("Download Static Plot", visible=False, variant='secondary')

            # Query input; its value is the first argument handed to predict.
            text_input = gr.Textbox(label="OpenAlex-search URL",
                                    info="Enter the URL to an OpenAlex-search.")
            
            gr.Markdown("### Sample Settings")
            # Forwarded to predict together with the method and size below; how
            # the checkbox and the "All" option interact is decided inside predict
            # (not visible here).
            reduce_sample_checkbox = gr.Checkbox(
                label="Reduce Sample Size",
                value=True,
                info="Reduce sample size."
            )
            # Strategy for choosing the retained subset. Selecting "All" hides the
            # size slider via update_slider_visibility.
            sample_reduction_method = gr.Dropdown(
                ["All", "First n samples", "n random samples"],
                label="Sample Selection Method",
                value="First n samples",
                info="How to choose the samples to keep."
            )
            # Target subset size: 500-20,000 in steps of 10, default 1000. It starts
            # visible because the dropdown's default is not "All".
            sample_size_slider = gr.Slider(
                label="Sample Size",
                minimum=500,
                maximum=20000,
                step=10,
                value=1000,
                info="How many samples to keep.",
                visible=True
            )
            
            gr.Markdown("### Plot Settings")
            # Both plot flags are passed straight through to predict.
            plot_time_checkbox = gr.Checkbox(
                label="Plot Time",
                value=True,
                info="Colour points by their publication date."
            )
            locally_approximate_publication_date_checkbox = gr.Checkbox(
                label="Locally Approximate Publication Date",
                value=True,
                info="Colour points by the average publication date in their area."
            )
            
            gr.Markdown("### Download Options")
            # Opt-in export flags (both off by default), forwarded to predict.
            download_csv_checkbox = gr.Checkbox(
                label="Generate CSV Export",
                value=False,
                info="Export the data as CSV file"
            )
            download_png_checkbox = gr.Checkbox(
                label="Generate Static PNG Plot",
                value=False,
                info="Export a static PNG visualization. This will make things slower!"
            )
            
            gr.Markdown("### Citation graph")
            # Opt-in (off by default); forwarded to predict as its last input.
            citation_graph_checkbox = gr.Checkbox(
                label="Add Citation Graph",
                value=False,
                info="Adds a citation graph of the sample to the plot."
            )
            
            
            
        # Output column. The placeholder HTML below is replaced by predict's first
        # output (`html`) after a query runs.
        with gr.Column(scale=2):
            html = gr.HTML(
                value='<div style="width: 100%; height: 1000px; display: flex; justify-content: center; align-items: center; border: 1px solid #ccc; background-color: #f8f9fa;"><p style="font-size: 1.2em; color: #666;">The visualization map will appear here after running a query</p></div>',
                label="", 
                show_label=False
            )
    gr.Markdown("""
    <div style="max-width: 100%; margin: 0 auto;">
    
    # FAQs
    
    ## Who made this?

    This project was developed by [Maximilian Noichl](https://maxnoichl.eu) (Utrecht University), in cooperation with Andrea Loettger and Tarja Knuuttila at the [Possible Life project](http://www.possiblelife.eu/), at the University of Vienna. If this project is useful in any way for your research, we would appreciate citation of **...**

    This project received funding from the European Research Council under the European Union's Horizon 2020 research and innovation programme (LIFEMODE project, grant agreement No. 818772).

    ## How does it work?

    The base map for this project is developed by randomly downloading 250,000 articles from OpenAlex, then embedding their abstracts using our [fine-tuned](https://huggingface.co/m7n/discipline-tuned_specter_2_024) version of the [specter-2](https://huggingface.co/allenai/specter2_aug2023refresh_base) language model, running these embeddings through [UMAP](https://umap-learn.readthedocs.io/en/latest/) to give us a two-dimensional representation, and displaying that in an interactive window using [datamapplot](https://datamapplot.readthedocs.io/en/latest/index.html). After the data for your query is downloaded from OpenAlex, it then undergoes the exact same process, but the pre-trained UMAP model from earlier is used to project your new data points onto this original map, showing where they would show up if they were included in the original sample. For more details, you can take a look at the method section of this paper: **...**
    
    ## I'm getting an "out of GPU credits" error.
    
    Running the embedding process requires an expensive A100 GPU. To provide this, we make use of HuggingFace's ZeroGPU service. As an anonymous user, this entitles you to one minute of GPU runtime, which is enough for several small queries of around a thousand records every day. If you create a free account on HuggingFace, this should increase to five minutes of runtime, allowing you to run successful queries of up to 10,000 records at a time. If you need more, there's always the option to either buy a HuggingFace Pro subscription for roughly ten dollars a month (entitling you to 25 minutes of runtime every day) or get in touch with us to run the pipeline outside of the HuggingFace environment.
    
    ## I want to add multiple queries at once!

    That can be a good idea, e.g. if you're interested in a specific paper as well as all the papers that cite it. Just add the queries to the query box and separate them with a ";" without any spaces in between!

    ## I think I found a mistake in the map.

    There are various considerations to take into account when working with this map:

    1. The language model we use is fine-tuned to separate disciplines from each other, but of course, disciplines are weird, partially subjective social categories, so what the model has learned might not always correspond perfectly to what you would expect to see.

    2. When compressing a very high-dimensional space into a low-dimensional one, there will be trade-offs. For example, we see this big ring structure of the sciences on the map, but in the middle of the map there is an overly stretched string of bioinformatics that runs from computer science at the bottom up to the life sciences clusters at the top. This is one of the areas where the UMAP algorithm had trouble compressing our high-dimensional dataset into a low-dimensional space. For more information on how to read a UMAP plot, I recommend looking into ["Understanding UMAP"](https://pair-code.github.io/understanding-umap/) by Andy Coenen & Adam Pearce.
    
    3. Finally, the labels we're using for the regions of this plot are created from OpenAlex's own labels of sub-disciplines. They give a rough indication of the papers that could be expected in this broad area of the map, but they are not necessarily the perfect label for the articles that are precisely below them. They are just located at the median point of a usually much larger, much broader, and fuzzier category, so they should always be taken with quite a big grain of salt.
    
    </div>
    """)

    def update_slider_visibility(method):
        """Return a Slider update hiding the size slider when every sample is kept.

        The sample size is only meaningful for the "First n samples" and
        "n random samples" strategies, so the slider is shown for anything
        other than "All".
        """
        keeps_everything = method == "All"
        return gr.Slider(visible=not keeps_everything)

    # Re-evaluate the slider's visibility whenever the selection method changes.
    sample_reduction_method.change(
        outputs=[sample_size_slider],
        inputs=[sample_reduction_method],
        fn=update_slider_visibility,
    )
    
    def show_cancel_button():
        """Return a Button update that makes the cancel button visible."""
        return gr.Button(visible=True)

    def hide_cancel_button():
        """Return a Button update that hides the cancel button."""
        return gr.Button(visible=False)

    # NOTE(review): tags each event handler with a `zerogpu` attribute; whatever
    # reads this attribute is not visible in this section — confirm it is still
    # required. Targets are assigned left to right: show, hide, predict.
    show_cancel_button.zerogpu = hide_cancel_button.zerogpu = predict.zerogpu = True

    # Clicking "run" first reveals the cancel button (unqueued), then hands the
    # query and every option to predict. `run_event` is the handle of that
    # predict step, so the cancel button below can abort it.
    run_event = run_btn.click(
        queue=False,
        outputs=cancel_btn,
        fn=show_cancel_button,
    ).then(
        outputs=[html, html_download, csv_download, png_download, cancel_btn],
        fn=predict,
        inputs=[
            text_input,
            sample_size_slider,
            reduce_sample_checkbox,
            sample_reduction_method,
            plot_time_checkbox,
            locally_approximate_publication_date_checkbox,
            download_csv_checkbox,
            download_png_checkbox,
            citation_graph_checkbox,
        ],
    )

    # Clicking "cancel" aborts the running predict step and hides itself.
    # queue=False is what makes the button disappear immediately.
    cancel_btn.click(
        cancels=[run_event],
        fn=hide_cancel_button,
        outputs=cancel_btn,
        queue=False,
    )


# demo.static_dirs = {
#     "static": str(static_dir)
# }


# Mount and run app
# app = gr.mount_gradio_app(app, demo, path="/",ssr_mode=False)

# app.zerogpu = True  # Add this line


# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", server_port=7860, share=True,allowed_paths=["/static"])
    
# Mount the Gradio Blocks UI onto the FastAPI app at the root path.
# Server-side rendering is disabled for now in every environment (HF Space and
# local alike), so no environment check is needed here.
app = gr.mount_gradio_app(app, demo, path="/", ssr_mode=False)

# Start a single uvicorn server for the FastAPI app (Gradio UI mounted at "/").
if __name__ == "__main__":
    if not is_running_in_hf_space():
        # Local run: serve the already-constructed app object directly.
        uvicorn.run(app, host="0.0.0.0", port=7860)
    else:
        # NOTE(review): this env var is set after `app` was mounted with an
        # explicit ssr_mode=False, so it may have no effect on SSR at all;
        # whether gradio lets GRADIO_SSR_MODE override an explicit argument
        # needs confirming before relying on it.
        os.environ["GRADIO_SSR_MODE"] = "True"
        # NOTE(review): the import string "app:app" makes uvicorn import the
        # module named `app` afresh rather than reuse this __main__ module, so
        # module-level setup presumably runs twice — confirm this is intended.
        uvicorn.run("app:app", host="0.0.0.0", port=7860)