|
import time

print(f"Starting up: {time.strftime('%Y-%m-%d %H:%M:%S')}")

import spaces  # imported early, before torch, so the Hugging Face ZeroGPU hooks are in place

from pathlib import Path

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import uvicorn
import gradio as gr

from sentence_transformers import SentenceTransformer

gr.set_static_paths(paths=["static/"])

app = FastAPI()

# Directory for the generated interactive plots, served under /static.
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)
app.mount("/static", StaticFiles(directory=static_dir), name="static")

import datamapplot
import numpy as np
import requests
import pandas as pd
from itertools import chain
from urllib.parse import urlparse, parse_qs

import pyalex
from pyalex import Works

pyalex.config.email = "[email protected]"

import torch
import pickle
import umap
from numba.typed import List

print(f"Imports are done: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
def openalex_url_to_pyalex_query(url):
    """
    Convert an OpenAlex search URL to a pyalex query.

    Args:
        url (str): The OpenAlex search URL.

    Returns:
        tuple: (Works object, dict of extra parameters such as page or seed)
    """
    parsed_url = urlparse(url)
    query_params = parse_qs(parsed_url.query)

    query = Works()

    # Filters arrive as a single comma-separated "filter" query parameter.
    if 'filter' in query_params:
        filters = query_params['filter'][0].split(',')
        for f in filters:
            if ':' in f:
                key, value = f.split(':', 1)
                if key == 'default.search':
                    query = query.search(value)
                else:
                    query = query.filter(**{key: value})

    # Sort keys prefixed with '-' mean descending order.
    if 'sort' in query_params:
        sort_params = query_params['sort'][0].split(',')
        for s in sort_params:
            if s.startswith('-'):
                query = query.sort(**{s[1:]: 'desc'})
            else:
                query = query.sort(**{s: 'asc'})

    params = {}
    for key in ['page', 'per-page', 'sample', 'seed']:
        if key in query_params:
            params[key] = query_params[key][0]

    return query, params
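
# Hypothetical usage sketch (illustrative URL, encoded the way the OpenAlex
# web UI encodes filters):
#
#   query, params = openalex_url_to_pyalex_query(
#       "https://openalex.org/works?filter=default.search:climate%20change,publication_year:2020"
#   )
#   # query  -> Works().search("climate change").filter(publication_year="2020")
#   # params -> {} (no page / per-page / sample / seed in the URL)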
|
|
def invert_abstract(inv_index):
    """Rebuild an abstract string from OpenAlex's inverted index."""
    if inv_index is not None:
        l_inv = [(w, p) for w, pos in inv_index.items() for p in pos]
        return " ".join(map(lambda x: x[0], sorted(l_inv, key=lambda x: x[1])))
    else:
        return ' '
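
# Worked example: the inverted index maps each word to its positions in the
# abstract, so sorting the (word, position) pairs restores the text:
#
#   invert_abstract({"deep": [0], "learning": [1, 3], "for": [2]})
#   # -> "deep learning for learning"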
|
|
|
def get_pub(x):
    """Extract a displayable venue name from a work's primary_location."""
    try:
        source = x['source']['display_name']
        if source not in ['parsed_publication', 'Deleted Journal']:
            return source
        else:
            return ' '
    except (TypeError, KeyError):
        return ' '
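
# Example behaviour (hypothetical primary_location dicts):
#
#   get_pub({'source': {'display_name': 'Nature'}})  # -> 'Nature'
#   get_pub({'source': None})                        # -> ' ' (TypeError is caught)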
|
|
print(f"Setting up language model: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"Using device: {device}") |
|
|
|
|
|
|
|
|
|
model = SentenceTransformer("m7n/discipline-tuned_specter_2_024") |
|
|
|
|
|
@spaces.GPU(duration=60) |
|
def create_embeddings(texts_to_embedd): |
|
|
|
embeddings = model.encode(texts_to_embedd,show_progress_bar=True,batch_size=32) |
|
|
|
return embeddings |
|
|
|
|
|
|
|
|
|
print(f"Language model is set up: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
def predict(text_input, sample_size_slider, reduce_sample_checkbox, sample_reduction_method, progress=gr.Progress()):
    print('getting data to project')
    progress(0, desc="Starting...")

    query, params = openalex_url_to_pyalex_query(text_input)
    query_length = query.count()
    print(f'Requesting {query_length} entries...')

    records = []
    for i, record in enumerate(chain(*query.paginate(per_page=200))):
        records.append(record)
        progress(0.3 * i / query_length, desc="Getting queried data...")

    records_df = pd.DataFrame(records)
    records_df['abstract'] = [invert_abstract(t) for t in records_df['abstract_inverted_index']]
    records_df['parsed_publication'] = [get_pub(x) for x in records_df['primary_location']]

    records_df['parsed_publication'] = records_df['parsed_publication'].fillna(' ')
    records_df['abstract'] = records_df['abstract'].fillna(' ')
    records_df['title'] = records_df['title'].fillna(' ')

    if reduce_sample_checkbox:
        sample_size = min(sample_size_slider, len(records_df))
        if sample_reduction_method == "Random":
            records_df = records_df.sample(sample_size)
        elif sample_reduction_method == "Order of Results":
            records_df = records_df.iloc[:sample_size]

    print(records_df)
|
    progress(0.3, desc="Embedding data...")
    # SPECTER-2-style input: title [SEP] venue [SEP] abstract.
    texts_to_embedd = [title + tokenizer.sep_token + publication + tokenizer.sep_token + abstract
                       for title, publication, abstract
                       in zip(records_df['title'], records_df['parsed_publication'], records_df['abstract'])]

    embeddings = create_embeddings(texts_to_embedd)
    print(embeddings)

    progress(0.5, desc="Projecting into UMAP embedding...")
    umap_embeddings = mapper.transform(embeddings)
    records_df[['x', 'y']] = umap_embeddings

    # Basemap points are pale grey; newly queried points are orange.
    basedata_df['color'] = '#ced4d211'
    records_df['color'] = '#f98e31'

    progress(0.6, desc="Setting up data...")
    stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled")
    stacked_df = stacked_df.reset_index(drop=True)
    print(stacked_df)

    extra_data = pd.DataFrame(stacked_df['doi'])

    # Unix-timestamped file name; strftime('%s') is platform-dependent, so use time.time().
    file_name = f"{int(time.time())}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Plotting...")
|
custom_css = """ |
|
|
|
|
|
#title-container { |
|
background: #edededaa; |
|
border-radius: 2px; |
|
|
|
box-shadow: 2px 3px 10px #aaaaaa00; |
|
} |
|
|
|
|
|
|
|
#search-container { |
|
position: fixed !important; |
|
top: 20px !important; |
|
right: 20px !important; |
|
left: auto !important; |
|
width: 200px !important; |
|
z-index: 9999 !important; |
|
} |
|
|
|
#search { |
|
// padding: 8px 8px !important; |
|
// border: none !important; |
|
// border-radius: 20px !important; |
|
background-color: #ffffffaa !important; |
|
font-family: 'Roboto Condensed', sans-serif !important; |
|
font-size: 14px; |
|
// box-shadow: 0 0px 0px #aaaaaa00 !important; |
|
} |
|
|
|
|
|
|
|
""" |
|
|
|
|
|
    plot = datamapplot.create_interactive_plot(
        stacked_df[['x', 'y']].values,
        np.array(stacked_df['cluster_1_labels']),
        np.array(stacked_df['cluster_2_labels']),
        np.array(stacked_df['cluster_3_labels']),
        hover_text=[str(row['title']) for ix, row in stacked_df.iterrows()],
        marker_color_array=stacked_df['color'],
        use_medoids=True,
        width=1000,
        height=1000,
        point_radius_min_pixels=1,
        text_outline_width=5,
        point_hover_color='#5e2784',
        point_radius_max_pixels=7,
        color_label_text=False,
        font_family="Roboto Condensed",
        font_weight=700,
        tooltip_font_weight=600,
        tooltip_font_family="Roboto Condensed",
        extra_point_data=extra_data,
        on_click="window.open(`{doi}`)",
        custom_css=custom_css,
        initial_zoom_fraction=.8,
        enable_search=True)

    progress(0.9, desc="Saving plot...")
    plot.save(file_path)

    progress(1.0, desc="Done!")
    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe
|
|
with gr.Blocks() as block:
    gr.Markdown("""
    ## Mapping OpenAlex-Queries

    Enter the URL of an OpenAlex search below. It will take a few minutes, but the results will then be projected onto a map of the OA database as a whole.
    """)

    with gr.Column():
        text_input = gr.Textbox(label="OpenAlex-search URL")
        with gr.Row():
            reduce_sample_checkbox = gr.Checkbox(label="Reduce Sample Size", value=True, info="Subsample the query results.")
            sample_size_slider = gr.Slider(label="Sample Size", minimum=10, maximum=20000, step=10, value=1000, info="How many samples to keep.")
            sample_reduction_method = gr.Dropdown(["Order of Results", "Random"], label="Sampling Method", value="Order of Results", info="How to choose the samples to keep.")

    new_btn = gr.Button("Run Query", variant='primary')
    markdown = gr.Markdown(label="")
    html = gr.HTML(label="HTML preview", show_label=True)

    new_btn.click(fn=predict,
                  inputs=[text_input, sample_size_slider, reduce_sample_checkbox, sample_reduction_method],
                  outputs=[markdown, html])
|
|
def setup_basemap_data():
    """Load the pre-computed 100k-work basemap (positions and cluster labels)."""
    print(f"Getting basemap data: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    with open('100k_filtered_OA_sample_cluster_and_positions_supervised.pkl', 'rb') as f:
        basedata_df = pickle.load(f)
    print(basedata_df)
    return basedata_df


def setup_mapper():
    """Rebuild the pre-trained UMAP mapper from its pickled parameters and attributes."""
    print(f"Getting Mapper: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    with open('umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl', 'rb') as f:
        params_new = pickle.load(f)
    print("setting up mapper...")
    mapper = umap.UMAP()

    umap_params = {k: v for k, v in params_new.get('umap_params', {}).items() if k != 'target_backend'}
    mapper.set_params(**umap_params)

    # Restore fitted attributes; the embedding itself is handled separately below.
    for attr, value in params_new.get('umap_attributes', {}).items():
        if attr != 'embedding_':
            setattr(mapper, attr, value)

    if 'embedding_' in params_new.get('umap_attributes', {}):
        mapper.embedding_ = List(params_new['umap_attributes']['embedding_'])

    return mapper
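
# Assumed layout of the pickled mapper parameters, inferred from setup_mapper():
#
#   {
#       'umap_params':     {...},  # kwargs for umap.UMAP, minus 'target_backend'
#       'umap_attributes': {...},  # fitted attributes, including 'embedding_'
#   }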
|
|
url = "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/100k_filtered_OA_sample_cluster_and_positions_supervised.pkl" |
|
response = requests.get(url) |
|
with open("100k_filtered_OA_sample_cluster_and_positions_supervised.pkl", "wb") as f: |
|
f.write(response.content) |
|
|
|
url = "https://huggingface.co/datasets/m7n/intermediate_sci_pickle/resolve/main/umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl" |
|
response = requests.get(url) |
|
with open("umap_mapper_250k_random_OA_discipline_tuned_specter_2_params.pkl", "wb") as f: |
|
f.write(response.content) |
|
|
basedata_df = setup_basemap_data()
mapper = setup_mapper()
print(f"Setup done, starting up app: {time.strftime('%Y-%m-%d %H:%M:%S')}")

app = gr.mount_gradio_app(app, block, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
|