Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pickle | |
import numpy as np | |
import glob | |
import tqdm | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoTokenizer, AutoModel | |
from peft import PeftModel | |
from tevatron.retriever.searcher import FaissFlatSearcher | |
import logging | |
import os | |
import json | |
import spaces | |
import ir_datasets | |
import pytrec_eval | |
from huggingface_hub import login | |
# Logging: INFO-level root configuration plus a logger scoped to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log in to the Hugging Face Hub. HF_TOKEN is read with a plain mapping lookup,
# so a missing variable raises KeyError while the module is being imported.
login(token=os.environ['HF_TOKEN'])

# Model identifiers: CUR_MODEL is applied on top of BASE_MODEL in load_model().
CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"
BASE_MODEL = "meta-llama/Llama-2-7b-hf"

# Filled in by load_model(); None means "not loaded yet".
tokenizer = None
model = None

# Per-dataset caches, all keyed by dataset name.
retrievers = dict()       # dataset -> searcher over the corpus embeddings
corpus_lookups = dict()   # dataset -> list mapping index position to document id
queries = dict()          # dataset -> list of query texts
q_lookups = dict()        # dataset -> {query_id: query text}
qrels = dict()            # dataset -> {query_id: {doc_id: relevance}}

datasets = ["scifact"]  # others are too large for the Space unfortunately :(
current_dataset = "scifact"
def pool(last_hidden_states, attention_mask):
    """Return the hidden state at the last attended position of each sequence.

    Padded positions are zeroed first, then each row is indexed at
    ``attention_mask.sum(dim=1) - 1``. That index is the final real token only
    when padding is on the right.

    Args:
        last_hidden_states: Tensor indexed as [batch, position, ...]
            (presumably (batch, seq_len, hidden) model output).
        attention_mask: Tensor indexed as [batch, position]; non-zero marks
            real tokens.

    Returns:
        Tensor with one vector per batch row.
    """
    is_padding = ~attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(is_padding, 0.0)
    # A row with an all-zero mask yields index -1, i.e. the (zeroed) last slot.
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(zeroed.shape[0], device=zeroed.device)
    return zeroed[rows, last_positions]
def create_batch_dict(tokenizer, input_texts, max_length=512):
    """Tokenize texts, append EOS to each, and pad them into one tensor batch.

    Args:
        tokenizer: Callable tokenizer exposing ``eos_token_id`` and ``pad``
            (a Hugging Face tokenizer in this file).
        input_texts: Texts to encode.
        max_length: Upper bound on tokens per text including the appended EOS.

    Returns:
        The result of ``tokenizer.pad``: PyTorch tensors padded to a multiple
        of 8, including an attention mask.
    """
    # Truncate to max_length - 1 so there is always room for the EOS token
    # added below.
    encoded = tokenizer(
        input_texts,
        truncation=True,
        max_length=max_length - 1,
        padding=False,
        return_attention_mask=False,
        return_token_type_ids=False,
    )

    with_eos = []
    for token_ids in encoded['input_ids']:
        with_eos.append(token_ids + [tokenizer.eos_token_id])
    encoded['input_ids'] = with_eos

    # Padding (and the attention mask) is produced only now, after EOS has
    # been appended, so the mask covers the EOS position as well.
    return tokenizer.pad(
        encoded,
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors="pt",
    )
def load_model():
    """Populate the module-level ``tokenizer`` and ``model``.

    The tokenizer comes from BASE_MODEL and reuses EOS as its padding token,
    padding on the right. The model is BASE_MODEL with the CUR_MODEL PEFT
    adapter applied, passed through ``merge_and_unload()`` and put in eval
    mode.
    """
    global tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    # The tokenizer gets no dedicated pad token here; EOS doubles as padding.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    # pool() reads the last attended position, which requires right padding.
    tokenizer.padding_side = "right"

    backbone = AutoModel.from_pretrained(BASE_MODEL)
    model = PeftModel.from_pretrained(backbone, CUR_MODEL)
    model = model.merge_and_unload()
    model.eval()
def load_corpus_embeddings(dataset_name):
    """Build the search index and document lookup for one dataset.

    Reads every shard matching ``{dataset_name}/corpus_emb.*.pkl``, adds its
    embeddings to a ``FaissFlatSearcher`` and concatenates the shard lookups
    in the same order, so position ``i`` in the index corresponds to
    ``corpus_lookups[dataset_name][i]``.

    Args:
        dataset_name: Directory name (relative to the working directory) that
            holds the pickled shards; also the key used in the global caches.

    Raises:
        FileNotFoundError: If no shard file matches the pattern.
    """
    global retrievers, corpus_lookups
    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
    # glob returns files in arbitrary order; sort so the index layout is
    # reproducible across runs.
    index_files = sorted(glob.glob(corpus_path))
    if not index_files:
        raise FileNotFoundError(
            f"No corpus embedding shards found matching '{corpus_path}'."
        )
    logger.info(f'Loading {len(index_files)} files into index for {dataset_name}.')

    # The first shard is passed to the constructor and then added explicitly
    # below like every other shard (presumably the constructor only uses it
    # to size the index -- confirm against tevatron).
    first_reps, first_lookup = pickle_load(index_files[0])
    searcher = FaissFlatSearcher(first_reps)
    lookup = []

    def _iter_shards():
        # Stream shards one at a time instead of holding all of them in
        # memory before indexing starts.
        yield first_reps, first_lookup
        for shard_path in index_files[1:]:
            yield pickle_load(shard_path)

    for p_reps, p_lookup in tqdm.tqdm(_iter_shards(), desc=f'Loading shards into index for {dataset_name}', total=len(index_files)):
        searcher.add(p_reps)
        lookup += p_lookup

    # Publish only once everything loaded, so a failure part-way through does
    # not leave a half-built index that looks "already loaded" to callers.
    retrievers[dataset_name] = searcher
    corpus_lookups[dataset_name] = lookup
def pickle_load(path):
    """Load one pickled ``(reps, lookup)`` pair from *path*.

    Args:
        path: Path of the pickle file to read.

    Returns:
        A tuple ``(reps, lookup)`` where ``reps`` has been converted with
        ``np.array`` and ``lookup`` is returned as stored.
    """
    # NOTE: unpickling executes code embedded in the file; only point this at
    # shard files from a trusted source.
    with open(path, 'rb') as shard_file:
        payload = pickle.load(shard_file)
    reps, lookup = payload
    return np.array(reps), lookup
def load_queries(dataset_name):
    """Load queries and relevance judgements for a BEIR dataset.

    Fills three global caches under ``dataset_name``: the list of query texts,
    the ``query_id -> text`` mapping, and the nested
    ``query_id -> {doc_id: relevance}`` judgements.

    Args:
        dataset_name: BEIR dataset name; lower-cased to form the ir_datasets
            id. Only the exact name "scifact" gets the "/test" suffix.
    """
    global queries, q_lookups, qrels

    split_suffix = "/test" if dataset_name == "scifact" else ""
    dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + split_suffix)

    # Register fresh containers first, then fill them through local aliases.
    texts = queries[dataset_name] = []
    text_by_id = q_lookups[dataset_name] = {}
    judgements = qrels[dataset_name] = {}

    for query in dataset.queries_iter():
        texts.append(query.text)
        text_by_id[query.query_id] = query.text

    for qrel in dataset.qrels_iter():
        judgements.setdefault(qrel.query_id, {})[qrel.doc_id] = qrel.relevance
def encode_queries(dataset_name, postfix):
    """Encode all cached queries of a dataset into L2-normalised embeddings.

    Every query is formatted as ``"query: <text> <postfix>"`` (stripped),
    encoded in batches of 64 on the GPU under autocast, pooled with ``pool``
    and normalised. The model is moved to the GPU for the duration of the
    call and moved back to the CPU afterwards, even when encoding fails.

    Args:
        dataset_name: Key into the global ``queries`` cache.
        postfix: Instruction text appended to every query (may be empty).

    Returns:
        A float32 NumPy array with one row per query.

    Raises:
        ValueError: If the dataset has no queries to encode.
    """
    global queries, tokenizer, model
    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[dataset_name]]
    if not input_texts:
        # Fail before touching the GPU; np.concatenate would otherwise raise
        # a less helpful ValueError at the very end.
        raise ValueError(f"No queries loaded for dataset '{dataset_name}'.")

    encoded_embeds = []
    batch_size = 64
    model = model.cuda()
    try:
        # Inference only: no_grad avoids building an autograd graph and keeps
        # .numpy() valid regardless of the parameters' requires_grad flags.
        with torch.no_grad():
            for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc="Encoding queries"):
                batch_input_texts = input_texts[start_idx: start_idx + batch_size]
                batch_dict = create_batch_dict(tokenizer, batch_input_texts)
                batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}
                # torch.autocast(device_type="cuda") is the non-deprecated
                # spelling of torch.cuda.amp.autocast().
                with torch.autocast(device_type="cuda"):
                    outputs = model(**batch_dict)
                    embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'])
                    embeds = F.normalize(embeds, p=2, dim=-1)
                # The embeddings are searched with a FAISS-backed searcher,
                # which expects float32; make that explicit instead of relying
                # on whatever dtype autocast happens to produce.
                encoded_embeds.append(embeds.float().cpu().numpy())
    finally:
        # remove model from GPU, also when encoding raised
        model = model.cpu()
    return np.concatenate(encoded_embeds, axis=0)
def search_queries(dataset_name, q_reps, depth=1000):
    """Retrieve the top ``depth`` documents for every query embedding.

    Args:
        dataset_name: Key into the global ``retrievers`` / ``corpus_lookups``.
        q_reps: Query embeddings passed straight to the searcher.
        depth: Number of hits requested per query.

    Returns:
        ``(all_scores, psg_indices)``: the scores as returned by the searcher
        and a NumPy array of document ids (as strings) in matching order.
    """
    all_scores, all_indices = retrievers[dataset_name].search(q_reps, depth)
    doc_id_at = corpus_lookups[dataset_name]

    psg_indices = []
    for hit_positions in all_indices:
        # Translate index positions into document ids.
        psg_indices.append([str(doc_id_at[pos]) for pos in hit_positions])
    return all_scores, np.array(psg_indices)
def evaluate(qrels, results, k_values):
    """Compute mean NDCG@k and Recall@k over all evaluated queries.

    Args:
        qrels: ``{query_id: {doc_id: relevance}}`` judgements.
        results: ``{query_id: {doc_id: score}}`` retrieval output.
        k_values: Cut-offs to report.

    Returns:
        Dict with ``"NDCG@k"`` and ``"Recall@k"`` for each k, rounded to three
        decimals.
    """
    # pytrec_eval takes measure names with a dot ("ndcg_cut.10") but reports
    # them with an underscore ("ndcg_cut_10").
    measures = set()
    for k in k_values:
        measures.add(f"ndcg_cut.{k}")
        measures.add(f"recall.{k}")

    evaluator = pytrec_eval.RelevanceEvaluator(qrels, measures)
    scores = evaluator.evaluate(results)

    def _mean(measure_key):
        per_query = [query_scores[measure_key] for query_scores in scores.values()]
        return round(np.mean(per_query), 3)

    metrics = {}
    for k in k_values:
        metrics[f"NDCG@{k}"] = _mean(f"ndcg_cut_{k}")
        metrics[f"Recall@{k}"] = _mean(f"recall_{k}")
    return metrics
def run_evaluation(dataset, postfix):
    """Encode, search and score one dataset with the given prompt postfix.

    Loads the dataset's corpus index and queries first if either is missing
    from the global caches, and records the dataset in ``current_dataset``.

    Args:
        dataset: Dataset name.
        postfix: Instruction text appended to every query.

    Returns:
        Dict with the two reported metrics, ``"NDCG@10"`` and ``"Recall@100"``.
    """
    global current_dataset

    already_loaded = dataset in retrievers and dataset in queries
    if not already_loaded:
        load_corpus_embeddings(dataset)
        load_queries(dataset)
    current_dataset = dataset

    q_reps = encode_queries(dataset, postfix)
    all_scores, psg_indices = search_queries(dataset, q_reps)

    # Query ids are paired with result rows purely by position: this relies
    # on q_lookups preserving the order in which the queries were encoded.
    results = {}
    for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices):
        results[qid] = dict(zip(doc_ids, map(float, scores)))

    metrics = evaluate(qrels[dataset], results, k_values=[10, 100])
    reported = ("NDCG@10", "Recall@100")
    return {name: metrics[name] for name in reported}
def gradio_interface(dataset, postfix):
    """Gradio callback: evaluate the selected dataset with the given prompt.

    On the first call (while the global ``model`` is still None) the model
    and every dataset listed in ``datasets`` are loaded before evaluating.

    Args:
        dataset: Dataset chosen in the UI.
        postfix: Prompt text appended to every query.

    Returns:
        The metrics dict produced by ``run_evaluation``.
    """
    if model is None:
        # Load model and initial datasets
        load_model()
        # Use a distinct loop variable: reusing `dataset` here would overwrite
        # the user's selection with the last entry of `datasets`.
        for preload_name in datasets:
            logger.info(f"Loading dataset: {preload_name}")
            load_corpus_embeddings(preload_name)
            load_queries(preload_name)
    return run_evaluation(dataset, postfix)
# Build the UI: a dataset dropdown and a free-text prompt go in, the metrics
# dict returned by gradio_interface comes out as JSON.
iface = gr.Interface(
    title="Promptriever Demo",
    description="Select a dataset and enter a prompt to evaluate the model's performance. Note: it takes about **ten seconds** to evaluate.",
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(label="Dataset", choices=datasets, value="scifact"),
        gr.Textbox(label="Prompt"),
    ],
    outputs=gr.JSON(label="Evaluation Results"),
    # Each example is [dataset, prompt]; the first one runs with no prompt.
    examples=[
        ["scifact", ""],
        ["scifact", "When judging the relevance of a document, focus on the pragmatics of the query and consider irrelevant any documents for which the user would have used a different query."],
    ],
)

# Start serving as soon as the module is executed.
iface.launch()