|
import os |
|
# Force a fresh, up-to-date gradio install before gradio is imported further
# down in this file.
#
# pip is invoked through the running interpreter (``python -m pip``) with an
# argument list, so the reinstall always targets this environment and no shell
# is involved. A failing pip call is reported but does not abort start-up,
# which keeps the previous best-effort behaviour.
import subprocess
import sys

for _pip_args in (
    ["uninstall", "-y", "gradio"],
    ["install", "--upgrade", "gradio"],
):
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", *_pip_args],
        check=False,
    )
    if _pip_result.returncode != 0:
        print(
            f"pip {' '.join(_pip_args)} exited with status "
            f"{_pip_result.returncode}"
        )
|
|
|
|
|
|
|
import spaces |
|
|
|
|
|
from pathlib import Path |
|
from fastapi import FastAPI |
|
from fastapi.staticfiles import StaticFiles |
|
import uvicorn |
|
import gradio as gr |
|
from datetime import datetime |
|
import sys |
|
|
|
|
|
|
|
# Register "static/" as a Gradio static path.
gr.set_static_paths(paths=["static/"])

# FastAPI application object; the static mount below is attached to it.
app = FastAPI()

# The directory has to exist before it is mounted, so create it first
# (including parents; an already existing directory is fine).
static_dir = Path('./static')
static_dir.mkdir(parents=True, exist_ok=True)

# Expose the directory's contents under the "/static" URL prefix.
app.mount(
    "/static",
    StaticFiles(directory=static_dir),
    name="static",
)
|
|
|
|
|
|
|
|
|
|
|
import datamapplot |
|
import numpy as np |
|
import requests |
|
import io |
|
import pandas as pd |
|
from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders |
|
from itertools import chain |
|
from compress_pickle import load, dump |
|
|
|
|
|
|
|
from transformers import AutoTokenizer |
|
from adapters import AutoAdapterModel |
|
import torch |
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def query_records(search_term):
    """Fetch the works whose abstract matches *search_term* as a DataFrame.

    Runs ``Works().search_filter(abstract=search_term)``, pages through the
    results 200 at a time and builds a DataFrame from the records, adding two
    derived columns:

    * ``abstract`` -- plain text rebuilt from ``abstract_inverted_index``
      (a single space when the index is ``None``);
    * ``parsed_publication`` -- ``primary_location['source']['display_name']``,
      or a single space when that path is missing or holds one of the
      placeholder names ``'parsed_publication'`` / ``'Deleted Journal'``.

    Args:
        search_term: Text passed to the abstract search filter.

    Returns:
        A DataFrame with one row per record. When the query matches nothing,
        an empty DataFrame that still has the ``title``,
        ``abstract_inverted_index``, ``primary_location``, ``abstract`` and
        ``parsed_publication`` columns is returned, so column access by
        callers does not raise ``KeyError``.
    """

    def invert_abstract(inv_index):
        """Rebuild running text from a ``{word: [positions]}`` mapping."""
        if inv_index is None:
            return ' '
        positioned_words = [
            (word, position)
            for word, positions in inv_index.items()
            for position in positions
        ]
        # Stable sort on the position only, so the words come out in order.
        positioned_words.sort(key=lambda pair: pair[1])
        return " ".join(word for word, _ in positioned_words)

    def get_pub(x):
        """Return the source display name of location record *x*, or ' '."""
        try:
            source = x['source']['display_name']
        except (KeyError, TypeError):
            # KeyError: a key is absent. TypeError: the location or its
            # 'source' entry is not subscriptable by key (e.g. it is None).
            return ' '
        if source in ('parsed_publication', 'Deleted Journal'):
            return ' '
        return source

    query = Works().search_filter(abstract=search_term)

    # paginate() yields pages (lists of records); flatten them into one list.
    records = list(chain.from_iterable(query.paginate(per_page=200)))

    if not records:
        # pd.DataFrame([]) has no columns at all, so the column lookups below
        # would raise KeyError; hand back an empty frame with the columns
        # this module reads instead.
        return pd.DataFrame(
            columns=[
                'title',
                'abstract_inverted_index',
                'primary_location',
                'abstract',
                'parsed_publication',
            ]
        )

    records_df = pd.DataFrame(records)
    records_df['abstract'] = [
        invert_abstract(inv_index)
        for inv_index in records_df['abstract_inverted_index']
    ]
    records_df['parsed_publication'] = [
        get_pub(location) for location in records_df['primary_location']
    ]

    return records_df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Select the compute device: "mps" when the MPS backend reports itself as
# available, otherwise "cuda".
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cuda")
print(f"Using device: {device}")

# Tokenizer and adapter-capable model are both loaded from the same
# checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(
    'allenai/specter2_aug2023refresh_base'
)
model = AutoAdapterModel.from_pretrained(
    'allenai/specter2_aug2023refresh_base'
)
|
|
|
|
|
@spaces.GPU(duration=120)
def create_embeddings(texts_to_embedd):
    """Embed texts with the module-level model and its "proximity" adapter.

    The adapter ``allenai/specter2_aug2023refresh`` is loaded under the name
    "proximity" and activated, the model is moved to the module-level
    ``device``, and the texts are encoded in batches of 32.

    Args:
        texts_to_embedd: Sequence of strings to encode; it must support
            ``len()`` and slicing.

    Returns:
        A NumPy array with one row per input text, taken from the first token
        position of the model's last hidden state.
    """
    print(len(texts_to_embedd))

    model.load_adapter("allenai/specter2_aug2023refresh", source="hf", load_as="proximity", set_active=True)
    model.set_active_adapters("proximity")

    model.to(device)

    def batch_generator(data, batch_size):
        """Yield consecutive batches of data."""
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]

    def release_cached_memory(target):
        """Empty the allocator cache of the backend *target* belongs to.

        ``torch.mps.empty_cache()`` is only meaningful on an MPS device, so
        the call is dispatched on the device type rather than made
        unconditionally.
        """
        backend = torch.device(target).type
        if backend == "mps":
            torch.mps.empty_cache()
        elif backend == "cuda":
            torch.cuda.empty_cache()

    def encode_texts(texts, device, batch_size=16):
        """Process texts in batches and return their embeddings."""
        model.eval()
        all_embeddings = []
        with torch.no_grad():
            batches = tqdm(batch_generator(texts, batch_size))
            for batch_index, batch in enumerate(batches):
                inputs = tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(device)
                outputs = model(**inputs)

                # Keep only the first token position of each sequence and
                # move it off the accelerator straight away.
                embeddings = outputs.last_hidden_state[:, 0, :]
                all_embeddings.append(embeddings.cpu())

                # Release cached accelerator memory after every 100 batches.
                if batch_index and batch_index % 100 == 0:
                    release_cached_memory(device)

        return torch.cat(all_embeddings, dim=0)

    embeddings = encode_texts(texts_to_embedd, device, batch_size=32).cpu().numpy()

    return embeddings
|
|
|
|
|
|
|
|
|
|
|
def predict(text_input, progress=gr.Progress()):
    """Run a query, embed the hits and save an interactive basemap plot.

    The query results and their embeddings are computed and printed, but the
    plot itself is built only from the module-level ``basedata_df``; the
    results are not drawn onto it here.

    Args:
        text_input: Search term forwarded to ``query_records``.
        progress: Gradio progress tracker used to report the plotting stages.

    Returns:
        A ``(link, iframe)`` tuple of HTML snippets: an anchor pointing at
        the saved plot under ``/static/`` and an iframe embedding it.
    """

    def as_text(value):
        """Return *value* if it is a string, otherwise an empty string."""
        return value if isinstance(value, str) else ''

    records_df = query_records(text_input)
    print(records_df)

    # Build "title<sep>publication<sep>abstract" per record. Any field that
    # is not a string (for example a missing title) becomes '' so that the
    # concatenation cannot raise TypeError.
    separator = tokenizer.sep_token
    texts_to_embedd = [
        as_text(title) + separator + as_text(publication) + separator + as_text(abstract)
        for title, publication, abstract in zip(
            records_df['title'],
            records_df['parsed_publication'],
            records_df['abstract'],
        )
    ]

    # create_embeddings concatenates per-batch tensors, which fails when
    # there are no batches at all, so skip it for an empty result set.
    embeddings = create_embeddings(texts_to_embedd) if texts_to_embedd else None
    print(embeddings)

    # Name the output after the current Unix time in whole seconds.
    # (strftime('%s') is a platform-specific extension and, applied to a
    # naive utcnow(), is skewed by the local UTC offset.)
    file_name = f"{int(datetime.now().timestamp())}.html"
    file_path = static_dir / file_name
    print(file_path)

    progress(0.7, desc="Loading hover data...")

    # One hover string per basemap row; iterating the columns directly avoids
    # building a Series for every row.
    hover_text = [
        str(ix) + ', ' + str(publication) + str(title)
        for ix, publication, title in zip(
            basedata_df.index,
            basedata_df['parsed_publication'],
            basedata_df['title'],
        )
    ]

    plot = datamapplot.create_interactive_plot(
        basedata_df[['x', 'y']].values,
        np.array(basedata_df['cluster_1_labels']),
        hover_text=hover_text,
        font_family="Roboto Condensed",
    )

    progress(0.9, desc="Saving plot...")
    plot.save(file_path)

    progress(1.0, desc="Done!")
    iframe = f"""<iframe src="/static/{file_name}" width="100%" height="500px"></iframe>"""
    link = f'<a href="/static/{file_name}" target="_blank">{file_name}</a>'
    return link, iframe
|
|
|
# UI definition: an introductory note, then a two-column row with the query
# controls on the left and the HTML preview on the right.
with gr.Blocks() as block:
    gr.Markdown("""
## Gradio + FastAPI + Static Server
This is a demo of how to use Gradio with FastAPI and a static server.
The Gradio app generates dynamic HTML files and stores them in a static directory. FastAPI serves the static files.
""")

    with gr.Row():
        # Left column: text input, link output, trigger button.
        with gr.Column():
            text_input = gr.Textbox(label="Name")
            markdown = gr.Markdown(label="Output Box")
            new_btn = gr.Button("New")
        # Right column: preview of the generated HTML.
        with gr.Column():
            html = gr.HTML(label="HTML preview", show_label=True)

    # Clicking the button runs predict(); its two return values fill the
    # Markdown and HTML components, in that order.
    new_btn.click(
        fn=predict,
        inputs=[text_input],
        outputs=[markdown, html],
    )
|
|
|
|
|
|
|
|
|
def setup_basemap_data():
    """Download the basemap sample and return the object loaded from it.

    The archive is fetched over HTTPS, written to
    ``static/100k_filtered_OA_sample_cluster_and_positions.bz`` (creating the
    ``static`` directory if necessary) and then read back with
    ``compress_pickle.load``.

    Returns:
        Whatever object the archive unpickles to. The rest of this module
        uses it as a DataFrame -- presumably that is what the file contains;
        verify against the data source.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For connection problems or a timeout.
    """
    print("getting basemap data...")

    bz_file_name = "100k_filtered_OA_sample_cluster_and_positions.bz"
    response = requests.get(
        "https://www.maxnoichl.eu/full/oa_project_on_scimap_background_data/100k_filtered_OA_sample_cluster_and_positions.bz",
        # Fail instead of hanging forever when the server stops responding.
        timeout=60,
    )
    # Do not write an HTTP error page to disk and try to unpickle it.
    response.raise_for_status()

    # Local name chosen so it does not shadow a module-level ``static_dir``.
    target_dir = Path("static")
    target_dir.mkdir(exist_ok=True)
    bz_file_path = target_dir / bz_file_name
    bz_file_path.write_bytes(response.content)

    # NOTE(security): load() unpickles the downloaded bytes, which can execute
    # arbitrary code. This is acceptable only because the URL is fixed and the
    # host is trusted; never point this at user-supplied locations.
    basedata_df = load(bz_file_path)

    return basedata_df
|
|
|
# Load the basemap once at import time; predict() reads this module-level
# object on every request.
basedata_df = setup_basemap_data()

# Attach the Gradio UI to the FastAPI app at the root path.
app = gr.mount_gradio_app(app, block, path="/")

if __name__ == "__main__":
    # Serve on all interfaces, port 7860, when run as a script.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )
|
|
|
|
|
|
|
|
|
|
|
|