# HouseMD bot — Gradio app for a Hugging Face Space.
# Retrieves House MD's answers via a bi-encoder (FAISS recall) + cross-encoder (re-rank) pipeline.
from functools import lru_cache

import datasets
import faiss
import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
# Copy shown on the Gradio page.
title = "HouseMD bot"
description = (
    "Gradio Demo for telegram bot "
    "To use it, simply add your text message. "
    "I've used the API on this Space to deploy the model on a Telegram bot."
)
def embed_bert_cls(text, model, tokenizer):
    """Return the L2-normalized [CLS] embedding of *text* as a 1-D numpy array.

    The tokenizer output is moved onto the model's device before the forward
    pass; gradients are disabled since this is inference-only.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    on_device = {key: tensor.to(model.device) for key, tensor in encoded.items()}
    with torch.no_grad():
        output = model(**on_device)
    # First token of the last hidden state is the [CLS] vector.
    cls_vectors = torch.nn.functional.normalize(output.last_hidden_state[:, 0, :])
    return cls_vectors[0].cpu().numpy()
def get_ranked_docs(query, vec_query_base, data,
                    bi_model, bi_tok, cross_model, cross_tok):
    """Return the single best-matching answer for *query*.

    Pipeline: embed the query with the bi-encoder, recall the 50 nearest
    neighbours from *vec_query_base* with a FAISS L2 index, then re-score the
    candidates with the cross-encoder and return the top one.

    Args:
        query: user message (str).
        vec_query_base: 2-D float32 numpy array of precomputed answer embeddings.
        data: list of dataset rows, each a dict with an 'answer' key
            (as produced by load_dataset in this file).
        bi_model / bi_tok: bi-encoder model and tokenizer for query embedding.
        cross_model / cross_tok: cross-encoder model and tokenizer for re-ranking.

    Returns:
        The selected answer string.
    """
    vec_shape = vec_query_base.shape[1]
    index = faiss.IndexFlatL2(vec_shape)
    index.add(vec_query_base)
    xq = embed_bert_cls(query, bi_model, bi_tok)
    _, I = index.search(xq.reshape(1, vec_shape), 50)

    # BUGFIX: was data['answer'][i] — *data* is a plain list of row dicts
    # (see load_dataset), so it must be indexed by position first.
    corpus = [data[i]['answer'] for i in I[0]]
    queries = [query] * len(corpus)
    tokenized_texts = cross_tok(
        queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
    ).to(cross_model.device)  # BUGFIX: was config.model.device — `config` is undefined (NameError)
    with torch.no_grad():
        ce_scores = cross_model(
            tokenized_texts['input_ids'], tokenized_texts['attention_mask']
        ).last_hidden_state[:, 0, :]
    ce_scores = ce_scores @ ce_scores.T
    scores = ce_scores.cpu().numpy()
    # NOTE(review): argsort over a 2-D similarity matrix followed by [::-1]
    # reverses the row order rather than producing a flat ranking; kept as-is
    # to preserve behaviour, but the ranking logic deserves a second look.
    scores_ix = np.argsort(scores)[::-1]
    return corpus[scores_ix[0][0]]
def load_dataset(url='ekaterinatao/house_md_context3'):
    """Download the dialogue dataset and keep only rows with labels == 0.

    Returns a plain list of row dicts (House's lines, per the label filter).
    """
    full_dataset = datasets.load_dataset(url, split='train')
    return [row for row in full_dataset if row['labels'] == 0]
def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
    """Download precomputed [CLS] embeddings and stack them into one 2-D array."""
    embeds_dataset = datasets.load_dataset(url, split='train')
    embeds_column = pd.DataFrame(embeds_dataset)['cls_embeds']
    return np.stack(list(embeds_column))
def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):
    """Load the bi-encoder and its tokenizer; returns (model, tokenizer)."""
    return (
        AutoModel.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
def load_cross_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-cross-encoder'):
    """Load the cross-encoder and its tokenizer; returns (model, tokenizer)."""
    return (
        AutoModel.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
@lru_cache(maxsize=1)
def _load_resources():
    """Load the dataset, embedding base and both encoder pairs exactly once.

    Cached so that repeated chat messages do not re-download datasets and
    re-instantiate the models on every call.
    """
    return (
        load_dataset(),
        load_cls_base(),
        load_bi_enc_model(),
        load_cross_enc_model(),
    )


def get_answer(message):
    """Gradio callback: return House MD's answer for the user's *message*.

    BUGFIX(perf): the original reloaded the full dataset, the embedding base
    and all four model/tokenizer objects on every single message; resources
    are now loaded once via the cached _load_resources helper.
    """
    dataset, cls_base, (bi_model, bi_tok), (cross_model, cross_tok) = _load_resources()
    return get_ranked_docs(
        query=message, vec_query_base=cls_base, data=dataset,
        bi_model=bi_model, bi_tok=bi_tok,
        cross_model=cross_model, cross_tok=cross_tok,
    )
# Build and launch the Gradio UI.
# BUGFIX(consistency): the input used the legacy `gr.inputs.Textbox` (Gradio 2.x
# namespace, removed in later versions) while the output one line below already
# used the modern `gr.Textbox`; both now use the same API.
interface = gr.Interface(
    fn=get_answer,
    inputs=gr.Textbox(lines=3, label="Input message to House MD"),
    outputs=gr.Textbox(label="House MD's answer"),
    title=title,
    description=description,
    enable_queue=True,  # NOTE(review): deprecated in Gradio 3.x, removed in 4.x — migrate to interface.queue()
)
interface.launch(debug=True)