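# Promptriever demo: evaluate prompt-conditioned dense retrieval (a RepLLaMA-style
# LoRA adapter on Llama-2) over BEIR datasets using FAISS.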
import gradio as gr
import pickle
import numpy as np
import glob
import tqdm
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel
import logging
import os
import spaces
import ir_datasets
import pytrec_eval
from huggingface_hub import login
import faiss
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Authenticate with HF_TOKEN
login(token=os.environ['HF_TOKEN'])
# Global variables
CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"
BASE_MODEL = "meta-llama/Llama-2-7b-hf"
tokenizer = None
model = None
corpus_lookups = {}
queries = {}
q_lookups = {}
qrels = {}
datasets = ["scifact"]
current_dataset = "scifact"
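# Last-token pooling: mask out padding, then take each sequence's final non-padding hidden state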
def pool(last_hidden_states, attention_mask):
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden.shape[0]
return last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
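# Tokenize a batch, append an EOS token to every sequence, and right-pad to a multiple of 8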
def create_batch_dict(tokenizer, input_texts, max_length=512):
batch_dict = tokenizer(
input_texts,
max_length=max_length - 1,
return_token_type_ids=False,
return_attention_mask=False,
padding=False,
truncation=True
)
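    # Append the EOS token manually; truncation above used max_length - 1 to leave room for it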
batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
return tokenizer.pad(
batch_dict,
padding=True,
pad_to_multiple_of=8,
return_attention_mask=True,
return_tensors="pt",
)
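# Load the Llama-2 base model and tokenizer, then apply the retrieval LoRA adapter with PEFT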
def load_model():
global tokenizer, model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
base_model_instance = AutoModel.from_pretrained(BASE_MODEL, device_map="auto", torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
model.eval()
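# Memory-map a prebuilt FAISS index from disk rather than reading it fully into RAM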
def load_faiss_index(dataset_name):
index_path = f"{dataset_name}/faiss_index.bin"
if os.path.exists(index_path):
logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
return faiss.read_index(index_path, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
return None
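# Search the FAISS index for the top-depth passages per query and map row ids back to corpus doc ids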
def search_queries(dataset_name, q_reps, depth=1000):
faiss_index = load_faiss_index(dataset_name)
if faiss_index is None:
raise ValueError(f"No FAISS index found for dataset {dataset_name}")
# Ensure q_reps is a 2D numpy array of the correct type
q_reps = np.ascontiguousarray(q_reps.astype('float32'))
# Perform the search
all_scores, all_indices = faiss_index.search(q_reps, depth)
psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
# Clean up
del faiss_index
return all_scores, np.array(psg_indices)
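# Rebuild the FAISS-row-to-doc-id lookup from the pickled corpus embedding shards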
def load_corpus_lookups(dataset_name):
global corpus_lookups
corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
index_files = glob.glob(corpus_path)
corpus_lookups[dataset_name] = []
for file in index_files:
with open(file, 'rb') as f:
_, p_lookup = pickle.load(f)
corpus_lookups[dataset_name] += p_lookup
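# Load queries and relevance judgments (qrels) for a BEIR dataset via ir_datasets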
def load_queries(dataset_name):
global queries, q_lookups, qrels
dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
queries[dataset_name] = []
q_lookups[dataset_name] = {}
qrels[dataset_name] = {}
for query in dataset.queries_iter():
queries[dataset_name].append(query.text)
q_lookups[dataset_name][query.query_id] = query.text
for qrel in dataset.qrels_iter():
if qrel.query_id not in qrels[dataset_name]:
qrels[dataset_name][qrel.query_id] = {}
qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
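# Encode all queries for the dataset, appending the user-supplied prompt (postfix) to each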
@spaces.GPU
def encode_queries(dataset_name, postfix):
global queries, tokenizer, model
input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[dataset_name]]
encoded_embeds = []
batch_size = 64
for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc="Encoding queries"):
batch_input_texts = input_texts[start_idx: start_idx + batch_size]
batch_dict = create_batch_dict(tokenizer, batch_input_texts)
batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}
        # Inference only: disable autograd so no graph is built during encoding
        with torch.no_grad(), torch.cuda.amp.autocast():
            outputs = model(**batch_dict)
            embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            embeds = F.normalize(embeds, p=2, dim=-1)
            encoded_embeds.append(embeds.cpu().numpy())
return np.concatenate(encoded_embeds, axis=0)
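# Score retrieval runs against qrels with pytrec_eval, reporting nDCG and Recall at each cutoff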
def evaluate(qrels, results, k_values):
evaluator = pytrec_eval.RelevanceEvaluator(
qrels, {f"ndcg_cut.{k}" for k in k_values} | {f"recall.{k}" for k in k_values}
)
scores = evaluator.evaluate(results)
metrics = {}
for k in k_values:
metrics[f"NDCG@{k}"] = round(np.mean([query_scores[f"ndcg_cut_{k}"] for query_scores in scores.values()]), 3)
metrics[f"Recall@{k}"] = round(np.mean([query_scores[f"recall_{k}"] for query_scores in scores.values()]), 3)
return metrics
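# Full pipeline for one dataset: lazy-load its data, encode queries, search, and evaluate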
def run_evaluation(dataset, postfix):
global current_dataset
if dataset not in corpus_lookups or dataset not in queries:
load_corpus_lookups(dataset)
load_queries(dataset)
current_dataset = dataset
q_reps = encode_queries(dataset, postfix)
all_scores, psg_indices = search_queries(dataset, q_reps)
results = {qid: dict(zip(doc_ids, map(float, scores)))
for qid, scores, doc_ids in zip(q_lookups[dataset].keys(), all_scores, psg_indices)}
metrics = evaluate(qrels[dataset], results, k_values=[10, 100])
return {
"NDCG@10": metrics["NDCG@10"],
"Recall@100": metrics["Recall@100"]
}
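# Gradio entry point: load the model lazily on the first request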
def gradio_interface(dataset, postfix):
    if model is None:
        load_model()
    # run_evaluation lazily loads the corpus lookups and queries for the selected dataset
    return run_evaluation(dataset, postfix)
# Create Gradio interface
iface = gr.Interface(
fn=gradio_interface,
inputs=[
gr.Dropdown(choices=datasets, label="Dataset", value="scifact"),
gr.Textbox(label="Prompt")
],
outputs=gr.JSON(label="Evaluation Results"),
title="Promptriever Demo",
    description="Select a dataset and enter a prompt to evaluate the model's performance. Note: each evaluation takes about **ten seconds**.",
examples=[
["scifact", ""],
["scifact", "Think carefully about these conditions when determining relevance."]
],
cache_examples=True,
)
# Launch the interface
iface.launch()