import spaces  # necessary to run on ZeroGPU
from spaces.zero.client import _get_token

import time

print(f"Starting up: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# Standard library imports
import os
from pathlib import Path
from datetime import datetime
from itertools import chain

# Third-party imports
import numpy as np
import pandas as pd
import torch
import gradio as gr

print(f"Gradio version: {gr.__version__}")

import matplotlib.pyplot as plt
import tqdm
import colormaps
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
import opinionated  # for fonts

plt.style.use("opinionated_rc")

from sklearn.neighbors import NearestNeighbors


def is_running_in_hf_space():
    return "SPACE_ID" in os.environ


import datamapplot
import pyalex

# Local imports
from openalex_utils import (
    openalex_url_to_pyalex_query,
    get_field,
    process_records_to_df,
    openalex_url_to_filename,
)
from styles import DATAMAP_CUSTOM_CSS
from data_setup import (
    download_required_files,
    setup_basemap_data,
    setup_mapper,
    setup_embedding_model,
)
from network_utils import create_citation_graph, draw_citation_graph

# Configure OpenAlex
pyalex.config.email = "maximilian.noichl@uni-bamberg.de"

print(f"Imports completed: {time.strftime('%Y-%m-%d %H:%M:%S')}")

# Instead of a FastAPI setup, just use Gradio's own static file serving.
static_dir = Path("static")
static_dir.mkdir(exist_ok=True)  # Create the static directory if it doesn't exist
gr.set_static_paths(paths=["static/"])

# Resource configuration
REQUIRED_FILES = {
    "100k_filtered_OA_sample_cluster_and_positions_supervised.pkl":
        "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/100k_filtered_OA_sample_cluster_and_positions_supervised.pkl",
    "umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl":
        "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl",
}
BASEMAP_PATH = "100k_filtered_OA_sample_cluster_and_positions_supervised.pkl"
MAPPER_PARAMS_PATH = "umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl"
MODEL_NAME = "m7n/discipline-tuned_specter_2_024"

# Initialize models and data
start_time = time.time()
print("Initializing resources...")
download_required_files(REQUIRED_FILES)
basedata_df = setup_basemap_data(BASEMAP_PATH)
mapper = setup_mapper(MAPPER_PARAMS_PATH)
model = setup_embedding_model(MODEL_NAME)
print(f"Resources initialized in {time.time() - start_time:.2f} seconds")


# Setting up decorators for embedding on HF ZeroGPU:
def no_op_decorator(func):
    """A no-op (no-operation) decorator that simply returns the function."""
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


# Decide which decorator to use based on environment:
# decorator_to_use = spaces.GPU() if is_running_in_hf_space() else no_op_decorator
# @decorator_to_use
@spaces.GPU(duration=4 * 60)
def create_embeddings(texts_to_embedd):
    """Create embeddings for the input texts using the loaded model."""
    return model.encode(texts_to_embedd, show_progress_bar=True, batch_size=192)
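# Example usage (hypothetical, commented out so it doesn't run at import time):
# embedding a couple of documents directly. Outside a ZeroGPU Space the
# @spaces.GPU decorator should be a no-op, so this also works locally on CPU,
# just more slowly:
#
#   sample_texts = ["Mapping science with UMAP", "Citation networks in physics"]
#   sample_vecs = create_embeddings(sample_texts)
#   print(sample_vecs.shape)  # -> (2, embedding_dim) numpy array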
def predict(request: gr.Request, text_input, sample_size_slider, reduce_sample_checkbox,
            sample_reduction_method, plot_time_checkbox,
            locally_approximate_publication_date_checkbox,
            download_csv_checkbox, download_png_checkbox, citation_graph_checkbox,
            progress=gr.Progress()):
    """
    Main prediction pipeline that processes OpenAlex queries and creates visualizations.

    Args:
        request (gr.Request): Gradio request object
        text_input (str): One or more OpenAlex query URLs, separated by semicolons
        sample_size_slider (int): Maximum number of samples to process
        reduce_sample_checkbox (bool): Whether to reduce sample size
        sample_reduction_method (str): Method for sample reduction
            ("All", "n random samples", or "First n samples")
        plot_time_checkbox (bool): Whether to color points by publication date
        locally_approximate_publication_date_checkbox (bool): Whether to smooth
            publication dates over UMAP neighborhoods before plotting
        download_csv_checkbox (bool): Whether to export the records as CSV
        download_png_checkbox (bool): Whether to render a static PNG plot
        citation_graph_checkbox (bool): Whether to draw the citation graph underlay
        progress (gr.Progress): Gradio progress tracker

    Returns:
        list: [iframe HTML, html download button, csv download button,
               png download button, cancel button state]
    """
    # Get the authentication token
    token = _get_token(request)
    print(f"Token: {token}")
    print(f"Request: {request}")

    # Check if input is empty or whitespace
    print(f"Input: {text_input}")
    if not text_input or text_input.isspace():
        error_message = "Error: Please enter a valid OpenAlex URL in the 'OpenAlex-search URL' field"
        return [
            error_message,  # iframe HTML
            gr.DownloadButton(label="Download Interactive Visualization", value=None, visible=False),
            gr.DownloadButton(label="Download CSV Data", value=None, visible=False),
            gr.DownloadButton(label="Download Static Plot", value=None, visible=False),
            gr.Button(visible=False),  # cancel button state
        ]

    start_time = time.time()
    print('Starting data projection pipeline')
    progress(0.1, desc="Starting...")

    # Split input into multiple URLs if present
    urls = [url.strip() for url in text_input.split(';')]
    records = []
    total_query_length = 0

    # Use the first URL for the filename
    first_query, first_params = openalex_url_to_pyalex_query(urls[0])
    filename = openalex_url_to_filename(urls[0])
    print(f"Filename: {filename}")

    # Process each URL
    for i, url in enumerate(urls):
        query, params = openalex_url_to_pyalex_query(url)
        query_length = query.count()
        total_query_length += query_length
        print(f'Requesting {query_length} entries from query {i+1}/{len(urls)}...')

        target_size = (sample_size_slider
                       if reduce_sample_checkbox and sample_reduction_method == "First n samples"
                       else query_length)
        records_per_query = 0
        should_break = False

        for page in query.paginate(per_page=200, n_max=None):
            for record in page:
                records.append(record)
                records_per_query += 1
                progress(0.1 + (0.2 * len(records) / total_query_length),
                         desc=f"Getting data from query {i+1}/{len(urls)}...")
                if (reduce_sample_checkbox
                        and sample_reduction_method == "First n samples"
                        and records_per_query >= target_size):
                    should_break = True
                    break
            if should_break:
                break
        if should_break:
            break
    print(f"Query completed in {time.time() - start_time:.2f} seconds")

    # Process records
    processing_start = time.time()
    records_df = process_records_to_df(records)

    if reduce_sample_checkbox and sample_reduction_method != "All":
        sample_size = min(sample_size_slider, len(records_df))
        if sample_reduction_method == "n random samples":
            records_df = records_df.sample(sample_size)
        elif sample_reduction_method == "First n samples":
            records_df = records_df.iloc[:sample_size]
    print(f"Records processed in {time.time() - processing_start:.2f} seconds")

    # Create embeddings
    embedding_start = time.time()
    progress(0.3, desc="Embedding Data...")
    texts_to_embedd = [f"{title} {abstract}"
                       for title, abstract in zip(records_df['title'], records_df['abstract'])]
    embeddings = create_embeddings(texts_to_embedd)
    print(f"Embeddings created in {time.time() - embedding_start:.2f} seconds")
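    # Sanity check (assumption: the encoder returns exactly one vector per input
    # text, in input order, so embeddings[i] belongs to records_df.iloc[i] --
    # the row-wise assignment in the projection step below relies on this).
    assert len(embeddings) == len(records_df), "Embedding/record count mismatch"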
    # Project embeddings
    projection_start = time.time()
    progress(0.5, desc="Projecting into UMAP embedding...")
    umap_embeddings = mapper.transform(embeddings)
    records_df[['x', 'y']] = umap_embeddings
    print(f"Projection completed in {time.time() - projection_start:.2f} seconds")

    # Prepare visualization data
    viz_prep_start = time.time()
    progress(0.6, desc="Preparing visualization data...")

    basedata_df['color'] = '#ced4d211'

    if not plot_time_checkbox:
        records_df['color'] = '#5e2784'
    else:
        cmap = colormaps.haline
        if not locally_approximate_publication_date_checkbox:
            # Create color mapping based on publication years
            years = pd.to_numeric(records_df['publication_year'])
            norm = mcolors.Normalize(vmin=years.min(), vmax=years.max())
            records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in years]
        else:
            n_neighbors = 10  # Adjust this value to control smoothing
            nn = NearestNeighbors(n_neighbors=n_neighbors)
            nn.fit(umap_embeddings)
            distances, indices = nn.kneighbors(umap_embeddings)

            # Calculate the local average publication year for each point
            local_years = np.array([
                np.mean(records_df['publication_year'].iloc[idx])
                for idx in indices
            ])
            norm = mcolors.Normalize(vmin=local_years.min(), vmax=local_years.max())
            records_df['color'] = [mcolors.to_hex(cmap(norm(year))) for year in local_years]

    stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled")
    stacked_df['parsed_field'] = [get_field(row) for ix, row in stacked_df.iterrows()]
    extra_data = pd.DataFrame(stacked_df['doi'])
    print(f"Visualization data prepared in {time.time() - viz_prep_start:.2f} seconds")

    if citation_graph_checkbox:
        citation_graph_start = time.time()
        citation_graph = create_citation_graph(records_df)
        graph_file_name = f"{filename}_citation_graph.jpg"
        graph_file_path = static_dir / graph_file_name
        draw_citation_graph(citation_graph, path=graph_file_path, bundle_edges=True,
                            min_max_coordinates=[np.min(stacked_df['x']), np.max(stacked_df['x']),
                                                 np.min(stacked_df['y']), np.max(stacked_df['y'])])
        print(f"Citation graph created and saved in {time.time() - citation_graph_start:.2f} seconds")
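    # Note: the citation graph above is rendered over the full
    # [x_min, x_max, y_min, y_max] extent of stacked_df, so the saved image
    # can be overlaid 1:1 on the map -- as background_image in the interactive
    # plot below and via ax.imshow() with the same extent in the PNG path.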
    # Create and save the interactive plot
    plot_start = time.time()
    progress(0.7, desc="Creating interactive plot...")

    # A solid black colormap for the cluster labels
    black_cmap = mcolors.LinearSegmentedColormap.from_list('black', ['#000000', '#000000'])

    plot = datamapplot.create_interactive_plot(
        stacked_df[['x', 'y']].values,
        np.array(stacked_df['cluster_2_labels']),
        np.array(['Unlabelled' if pd.isna(x) else x for x in stacked_df['parsed_field']]),
        hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()],
        marker_color_array=stacked_df['color'],
        use_medoids=False,  # Switch back once efficient medoid calculation comes out!
        width=1000,
        height=1000,
        point_radius_min_pixels=1,
        text_outline_width=5,
        point_hover_color='#5e2784',
        point_radius_max_pixels=7,
        cmap=black_cmap,
        background_image=graph_file_name if citation_graph_checkbox else None,
        font_family="Roboto Condensed",
        font_weight=600,
        tooltip_font_weight=600,
        tooltip_font_family="Roboto Condensed",
        extra_point_data=extra_data,
        on_click="window.open(`{doi}`)",
        custom_css=DATAMAP_CUSTOM_CSS,
        initial_zoom_fraction=.8,
        enable_search=False,
        offline_mode=False,
    )

    # Save plot
    html_file_name = f"{filename}.html"
    html_file_path = static_dir / html_file_name
    plot.save(html_file_path)
    print(f"Plot created and saved in {time.time() - plot_start:.2f} seconds")

    # Save additional files if requested
    csv_file_path = static_dir / f"{filename}.csv"
    png_file_path = static_dir / f"{filename}.png"

    if download_csv_checkbox:
        # Export the relevant columns
        export_df = records_df[['title', 'abstract', 'doi', 'publication_year',
                                'x', 'y', 'id', 'primary_topic']].copy()
        export_df['parsed_field'] = [get_field(row) for ix, row in export_df.iterrows()]
        export_df['referenced_works'] = [', '.join(x) for x in records_df['referenced_works']]
        export_df.to_csv(csv_file_path, index=False)

    if download_png_checkbox:
        png_start_time = time.time()
        print("Starting PNG generation...")

        # Sample and prepare data
        sample_prep_start = time.time()
        sample_to_plot = basedata_df  # .sample(20000)
        labels1 = np.array(sample_to_plot['cluster_2_labels'])
        labels2 = np.array(['Unlabelled' if pd.isna(x) else x
                            for x in sample_to_plot['parsed_field']])

        ratio = 0.6
        mask = np.random.random(size=len(labels1)) < ratio
        combined_labels = np.where(mask, labels1, labels2)

        # Keep only the 50 most common labels
        unique_labels, counts = np.unique(combined_labels, return_counts=True)
        top_labels = set(unique_labels[np.argsort(counts)[-50:]])

        # Replace less common labels with 'Unlabelled'
        combined_labels = np.array(['Unlabelled' if label not in top_labels else label
                                    for label in combined_labels])
        colors_base = ['#536878' for _ in range(len(labels1))]
        print(f"Sample preparation completed in {time.time() - sample_prep_start:.2f} seconds")
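        # The 60/40 np.where() mix above blends fine-grained cluster labels
        # with broader field labels into a single label layer for the static
        # map; the mask is unseeded, so the exact blend differs between runs.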
        # Create the main static plot
        main_plot_start = time.time()
        fig, ax = datamapplot.create_plot(
            sample_to_plot[['x', 'y']].values,
            combined_labels,
            label_wrap_width=12,
            label_over_points=True,
            dynamic_label_size=True,
            use_medoids=False,  # Switch back once efficient medoid calculation comes out!
            point_size=2,
            marker_color_array=colors_base,
            force_matplotlib=True,
            max_font_size=12,
            min_font_size=4,
            min_font_weight=100,
            max_font_weight=300,
            font_family="Roboto Condensed",
            color_label_text=False,
            add_glow=False,
            highlight_labels=list(np.unique(labels1)),
            label_font_size=8,
            highlight_label_keywords={"fontsize": 12, "fontweight": "bold",
                                      "bbox": {"boxstyle": "circle", "pad": 0.75, "alpha": 0.}},
        )
        print(f"Main plot creation completed in {time.time() - main_plot_start:.2f} seconds")

        if citation_graph_checkbox:
            # Read and overlay the citation graph image
            graph_img = plt.imread(graph_file_path)
            ax.imshow(graph_img,
                      extent=[np.min(stacked_df['x']), np.max(stacked_df['x']),
                              np.min(stacked_df['y']), np.max(stacked_df['y'])],
                      alpha=0.9, aspect='auto')

        # Time-based visualization
        scatter_start = time.time()
        if plot_time_checkbox:
            if locally_approximate_publication_date_checkbox:
                scatter = plt.scatter(
                    umap_embeddings[:, 0],
                    umap_embeddings[:, 1],
                    c=local_years,
                    cmap=colormaps.haline,
                    alpha=0.8,
                    s=5,
                )
            else:
                years = pd.to_numeric(records_df['publication_year'])
                scatter = plt.scatter(
                    umap_embeddings[:, 0],
                    umap_embeddings[:, 1],
                    c=years,
                    cmap=colormaps.haline,
                    alpha=0.8,
                    s=5,
                )
            plt.colorbar(scatter, shrink=0.5, format='%d')
        else:
            scatter = plt.scatter(
                umap_embeddings[:, 0],
                umap_embeddings[:, 1],
                c=records_df['color'],
                alpha=0.8,
                s=5,
            )
        print(f"Scatter plot creation completed in {time.time() - scatter_start:.2f} seconds")

        # Save plot
        save_start = time.time()
        plt.axis('off')
        png_file_path = static_dir / f"{filename}.png"
        plt.savefig(png_file_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saving completed in {time.time() - save_start:.2f} seconds")
        print(f"Total PNG generation completed in {time.time() - png_start_time:.2f} seconds")

    progress(1.0, desc="Done!")
    print(f"Total pipeline completed in {time.time() - start_time:.2f} seconds")

    # Embed the saved HTML map in an iframe. NOTE: the src route is an
    # assumption -- it relies on Gradio serving the files registered via
    # gr.set_static_paths() above under /gradio_api/file=.
    iframe = f"""<iframe src="/gradio_api/file=static/{html_file_name}" width="100%" height="1000px" style="border:none;"></iframe>"""

    # Return iframe and download buttons with appropriate visibility
    return [
        iframe,
        gr.DownloadButton(label="Download Interactive Visualization", value=html_file_path,
                          visible=True, variant='secondary'),
        gr.DownloadButton(label="Download CSV Data", value=csv_file_path,
                          visible=download_csv_checkbox, variant='secondary'),
        gr.DownloadButton(label="Download Static Plot", value=png_file_path,
                          visible=download_png_checkbox, variant='secondary'),
        gr.Button(visible=False),  # Return hidden state for the cancel button
    ]


predict.zerogpu = True

theme = gr.themes.Monochrome(
    font=[gr.themes.GoogleFont("Roboto Condensed"), "ui-sans-serif", "system-ui", "sans-serif"],
    text_size="lg",
).set(
    button_secondary_background_fill="white",
    button_secondary_background_fill_hover="#f3f4f6",
    button_secondary_border_color="black",
    button_secondary_text_color="black",
    button_border_width="2px",
)
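# The .set() overrides above map onto the theme's CSS variables; further
# tweaks would follow the same pattern, e.g. (hypothetical):
#   theme = theme.set(button_primary_background_fill="#5e2784")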
# Gradio interface setup
with gr.Blocks(theme=theme, css="""
    .gradio-container a {
        color: black !important;
        text-decoration: none !important;  /* Force remove default underline */
        font-weight: bold;
        transition: color 0.2s ease-in-out, border-bottom-color 0.2s ease-in-out;
        display: inline-block;  /* Enable proper spacing for descenders */
        line-height: 1.1;  /* Adjust line height */
        padding-bottom: 2px;  /* Add space for descenders */
    }
    .gradio-container a:hover {
        color: #b23310 !important;
        border-bottom: 3px solid #b23310;  /* Wider underline, only on hover */
    }
""") as demo:
    gr.Markdown("""
    The visualization map will appear here after running a query