|
import os |
|
# Runtime dependency bootstrap. Every command is handed to the shell in the
# order listed: gradio is removed and reinstalled first, then the pinned
# versions of the plotting / UMAP stack are installed.
for _pip_command in (
    "pip uninstall -y gradio",
    "pip install --upgrade gradio",
    "pip install datamapplot==0.3.0",
    "pip install numba==0.59.1",
    "pip install umap-learn==0.5.6",
    "pip install pynndescent==0.5.12",
):
    os.system(_pip_command)
|
|
|
|
|
|
|
import spaces |
|
|
|
|
|
from pathlib import Path |
|
from fastapi import FastAPI |
|
from fastapi.staticfiles import StaticFiles |
|
import uvicorn |
|
import gradio as gr |
|
from datetime import datetime |
|
import sys |
|
|
|
|
|
|
|
# Register the static/ directory with Gradio.
gr.set_static_paths(paths=["static/"])

# Folder that receives the generated plot HTML files; it is created on
# startup when it does not exist yet.
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)

# FastAPI application that exposes that folder under the /static route.
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory=static_dir),
    name="static",
)
|
|
|
|
|
|
|
|
|
|
|
import datamapplot |
|
import numpy as np |
|
import requests |
|
import io |
|
import pandas as pd |
|
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders |
|
from itertools import chain |
|
from compress_pickle import load, dump |
|
|
|
|
|
|
|
from transformers import AutoTokenizer |
|
from adapters import AutoAdapterModel |
|
import torch |
|
from tqdm import tqdm |
|
from numba.typed import List |
|
import pickle |
|
import pynndescent |
|
import umap |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def query_records(search_term, progress):
    """Fetch the OpenAlex works matching *search_term* as a DataFrame.

    Args:
        search_term: Search string handed to ``Works().search``.
        progress: Callable invoked as ``progress(fraction, desc=...)`` after
            every fetched record. This step reports fractions up to 0.1.

    Returns:
        A DataFrame with one row per fetched record. In addition to the
        fields of the records it has an ``abstract`` column (text rebuilt
        from ``abstract_inverted_index``) and a ``parsed_publication``
        column (display name of the primary location's source). Missing
        values in ``abstract``, ``parsed_publication`` and ``title`` are
        replaced by ``' '``.

    Raises:
        gr.Error: If the search yields no records.
    """

    def invert_abstract(inv_index):
        """Rebuild the abstract text from a ``{word: [positions]}`` mapping."""
        if inv_index is None:
            return ' '
        positioned = [
            (word, position)
            for word, positions in inv_index.items()
            for position in positions
        ]
        positioned.sort(key=lambda pair: pair[1])
        return " ".join(word for word, _ in positioned)

    def get_pub(location):
        """Return the source display name of *location*, or ``' '``."""
        try:
            source = location['source']['display_name']
        except (KeyError, TypeError):
            # The location or its source is missing (None) or lacks the key.
            return ' '
        if source in ('parsed_publication', 'Deleted Journal'):
            return ' '
        return source

    query = Works().search([search_term])
    # The count comes from a separate query object; `query` itself is only
    # used for pagination.
    query_length = Works().search([search_term]).count()

    records = []
    # chain.from_iterable pulls one page at a time. Unpacking the paginator
    # with `*` would exhaust it before the loop body runs, so the progress
    # callback would only fire once everything had been downloaded.
    pages = query.paginate(per_page=200)
    for i, record in enumerate(chain.from_iterable(pages)):
        records.append(record)
        # max(..., 1) guards against a reported count of zero.
        achieved_progress = min(0.1, (i + 1) / max(query_length, 1) * 0.1)
        progress(achieved_progress, desc="Getting queried data...")

    if not records:
        # An empty frame has none of the columns accessed below; fail with a
        # readable message instead of a KeyError.
        raise gr.Error(f"No OpenAlex records found for '{search_term}'.")

    records_df = pd.DataFrame(records)
    records_df['abstract'] = [
        invert_abstract(inv_index)
        for inv_index in records_df['abstract_inverted_index']
    ]
    records_df['parsed_publication'] = [
        get_pub(location) for location in records_df['primary_location']
    ]

    for column in ('parsed_publication', 'abstract', 'title'):
        records_df[column] = records_df[column].fillna(' ')

    return records_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pick the compute device: the MPS backend when it is available, CUDA
# otherwise.
# NOTE(review): there is no CPU fallback here — confirm that is intended.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cuda")
print(f"Using device: {device}")

# Tokenizer and base model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(
    'allenai/specter2_aug2023refresh_base'
)
model = AutoAdapterModel.from_pretrained(
    'allenai/specter2_aug2023refresh_base'
)
|
|
|
|
|
@spaces.GPU(duration=60)
def create_embeddings(texts_to_embedd):
    """Embed texts with the module-level model and its "proximity" adapter.

    Args:
        texts_to_embedd: List of strings to embed.

    Returns:
        A numpy array with one row per input text, taken from the last
        hidden state at the first token position of each sequence.
    """
    print(len(texts_to_embedd))

    # NOTE(review): the adapter is loaded on every call; presumably this
    # could be done once — confirm against the adapters library.
    model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
    model.set_active_adapters("proximity")

    model.to(device)

    def batch_generator(data, batch_size):
        """Yield consecutive batches of data."""
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]

    def release_cached_memory(target_device):
        """Release cached allocator memory of the backend actually in use."""
        # The MPS call is only meaningful on an MPS device; on CUDA the
        # CUDA allocator cache is released instead.
        if target_device.type == "mps":
            torch.mps.empty_cache()
        elif target_device.type == "cuda":
            torch.cuda.empty_cache()

    def encode_texts(texts, device, batch_size=16):
        """Process texts in batches and return their embeddings."""
        model.eval()
        all_embeddings = []
        with torch.no_grad():
            batches = tqdm(batch_generator(texts, batch_size))
            for batch_number, batch in enumerate(batches):
                inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
                outputs = model(**inputs)
                # Keep only the first-token vector and move it off the device.
                all_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())

                # Free cached memory at batch index 100, 200, ...
                if batch_number > 0 and batch_number % 100 == 0:
                    release_cached_memory(device)

        return torch.cat(all_embeddings, dim=0)

    # The concatenated tensor is already on the CPU, so it converts directly.
    embeddings = encode_texts(texts_to_embedd, device, batch_size=32).numpy()

    return embeddings
|
|
|
|
|
|
|
|
|
|
|
# Extra CSS handed to datamapplot for the generated page.
# NOTE: `//` is not CSS comment syntax; the lines carrying it are kept
# exactly as written so the rendered page does not change.
_PLOT_CUSTOM_CSS = """

#title-container {
    background: #edededaa;
    border-radius: 2px;

    box-shadow: 2px 3px 10px #aaaaaa00;
}


#search-container {
    position: fixed !important;
    top: 20px !important;
    right: 20px !important;
    left: auto !important;
    width: 200px !important;
    z-index: 9999 !important;
}

#search {
    // padding: 8px 8px !important;
    // border: none !important;
    // border-radius: 20px !important;
    background-color: #ffffffaa !important;
    font-family: 'Roboto Condensed', sans-serif !important;
    font-size: 14px;
    // box-shadow: 0 0px 0px #aaaaaa00 !important;
}


"""


def predict(text_input, sample_size_slider, reduce_sample_checkbox, progress=gr.Progress()):
    """Run a query, project the results onto the base map and save the plot.

    The queried records are embedded, transformed with the module-level
    ``mapper``, stacked under the module-level ``basedata_df`` and rendered
    with datamapplot into an HTML file inside ``static_dir``.

    Args:
        text_input: Search string passed to ``query_records``.
        sample_size_slider: Number of records to keep when sampling is on.
        reduce_sample_checkbox: When truthy, a random sample of the records
            is used instead of all of them.
        progress: Progress callback, called as ``progress(fraction, desc=...)``.

    Returns:
        A ``(link, iframe)`` tuple of HTML strings that both point at the
        saved plot under ``/static/``.
    """
    records_df = query_records(text_input, progress=progress)
    if reduce_sample_checkbox:
        # DataFrame.sample raises when asked for more rows than exist, so
        # the requested size is capped at the number of fetched records.
        sample_size = min(int(sample_size_slider), len(records_df))
        records_df = records_df.sample(sample_size)
    print(records_df)

    progress(0.3, desc="Embedding Data...")
    separator = tokenizer.sep_token
    texts_to_embedd = [
        separator.join((title, publication, abstract))
        for title, publication, abstract in zip(
            records_df['title'],
            records_df['parsed_publication'],
            records_df['abstract'],
        )
    ]

    embeddings = create_embeddings(texts_to_embedd)
    print(embeddings)

    progress(0.5, desc="Project into UMAP-embedding...")
    umap_embeddings = mapper.transform(embeddings)
    records_df[['x', 'y']] = umap_embeddings

    # Base-map points get a faint grey, queried records a solid orange.
    # The assignment on the shared frame writes the same constant each call.
    basedata_df['color'] = '#ced4d211'
    records_df['color'] = '#f98e31'

    progress(0.6, desc="Set up data...")

    stacked_df = pd.concat([basedata_df, records_df], axis=0, ignore_index=True)
    stacked_df = stacked_df.fillna("Unlabelled").reset_index(drop=True)
    print(stacked_df)

    extra_data = pd.DataFrame(stacked_df['doi'])

    # Name the file after the current epoch second. timestamp() is used
    # because the '%s' strftime directive is platform-specific and, on a
    # naive UTC datetime, is offset by the local UTC difference.
    file_name = f"{int(datetime.now().timestamp())}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Plotting...")

    plot = datamapplot.create_interactive_plot(
        stacked_df[['x', 'y']].values,
        np.array(stacked_df['cluster_1_labels']),
        np.array(stacked_df['cluster_2_labels']),
        np.array(stacked_df['cluster_3_labels']),
        # Iterating the column directly avoids building a Series per row.
        hover_text=[str(title) for title in stacked_df['title']],
        marker_color_array=stacked_df['color'],
        use_medoids=True,
        width=1000,
        height=1000,
        point_radius_min_pixels=1,
        text_outline_width=5,
        point_hover_color='#5e2784',
        point_radius_max_pixels=7,
        color_label_text=False,
        font_family="Roboto Condensed",
        font_weight=700,
        tooltip_font_weight=600,
        tooltip_font_family="Roboto Condensed",
        extra_point_data=extra_data,
        on_click="window.open(`{doi}`)",
        custom_css=_PLOT_CUSTOM_CSS,
        initial_zoom_fraction=.8,
        enable_search=True)

    progress(0.9, desc="Saving plot...")
    plot.save(file_path)

    progress(1.0, desc="Done!")
    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as block:
    # Introductory text shown at the top of the page.
    gr.Markdown("""
    ## Mapping OpenAlex-Queries

    This is a tool to further interdisciplinary research – you are a neuroscientist who has used ..., What have the ... been doing with them.
    You're a philosopher of science who wonders where the concept of a fitness landscape has appeared...
    """)

    with gr.Column():
        # Query controls.
        text_input = gr.Textbox(label="OpenAlex Fulltext-Search")
        sample_size_slider = gr.Slider(
            label="Sample Size",
            minimum=10,
            maximum=20000,
            step=10,
            value=1000,
        )
        reduce_sample_checkbox = gr.Checkbox(
            label="Reduce Sample Size",
            value=True,
        )
        new_btn = gr.Button("Run Query")

        # Output widgets: a link to the saved plot and an inline preview.
        markdown = gr.Markdown(label="")
        html = gr.HTML(label="HTML preview", show_label=True)

    # Clicking the button runs `predict` on the three controls and fills
    # the two output widgets with its return values.
    new_btn.click(
        fn=predict,
        inputs=[text_input, sample_size_slider, reduce_sample_checkbox],
        outputs=[markdown, html],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_basemap_data():
    """Load the stored base-map data from disk, print it and return it.

    Returns:
        The object stored in the compressed file; presumably a DataFrame,
        given how the result is used elsewhere — confirm.
    """
    print("getting basemap data...")
    basemap = load("100k_filtered_OA_sample_cluster_and_positions.bz")
    print(basemap)
    return basemap
|
|
|
|
|
|
|
def setup_mapper():
    """Rebuild a UMAP mapper from its pickled parameters and attributes.

    The pickle is expected to hold a dict with optional ``'umap_params'``
    (constructor parameters) and ``'umap_attributes'`` (fitted attributes)
    entries.

    Returns:
        A ``umap.UMAP`` instance with the stored parameters set and the
        stored attributes attached.
    """
    print("getting mapper...")

    # NOTE: pickle executes arbitrary code while loading; this presumably
    # reads a trusted file shipped with the app — confirm.
    # The context manager closes the file handle once the data is loaded.
    with open('umap_mapper_300k_random_OA_specter_2_params.pkl', 'rb') as params_file:
        params_new = pickle.load(params_file)

    print("setting up mapper...")
    mapper = umap.UMAP()

    # 'target_backend' is dropped before the parameters are applied.
    umap_params = {
        key: value
        for key, value in params_new.get('umap_params', {}).items()
        if key != 'target_backend'
    }
    mapper.set_params(**umap_params)

    umap_attributes = params_new.get('umap_attributes', {})
    for attr, value in umap_attributes.items():
        if attr != 'embedding_':
            setattr(mapper, attr, value)

    # The embedding is attached separately, wrapped in a numba typed List.
    if 'embedding_' in umap_attributes:
        mapper.embedding_ = List(umap_attributes['embedding_'])

    return mapper
|
|
|
|
|
|
|
|
|
# Load the shared resources once at import time; `predict` reads both of
# them as module-level globals.
basedata_df = setup_basemap_data()
mapper = setup_mapper()

# Attach the Gradio UI to the FastAPI application at the root path.
app = gr.mount_gradio_app(app, block, path="/")

if __name__ == "__main__":
    # Serve the combined application on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|
|
|
|
|
|
|
|
|
|
|