# house_md_bot/app.py

import gradio as gr
import torch
import faiss
import numpy as np
import pandas as pd
import datasets
from transformers import AutoTokenizer, AutoModel

title = "HouseMD bot"
description = (
    "Gradio demo for a Telegram bot. "
    "To use it, simply type your message. "
    "I've used this Space's API to deploy the model as a Telegram bot."
)
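
# The description above mentions that this Space's API backs a Telegram bot.
# A minimal client sketch (assumptions: the Space id "ekaterinatao/house_md_bot"
# and the default api_name "/predict" are guesses, not taken from this file):
#
#   from gradio_client import Client
#
#   client = Client("ekaterinatao/house_md_bot")
#   reply = client.predict("I have a headache", api_name="/predict")
#   print(reply)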


def embed_bert_cls(text, model, tokenizer):
    """Return the normalized [CLS] embedding of `text` as a numpy vector."""
    t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**{k: v.to(model.device) for k, v in t.items()})
    embeds = model_output.last_hidden_state[:, 0, :]
    embeds = torch.nn.functional.normalize(embeds)
    return embeds[0].cpu().numpy()
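
# A sketch of how the 'cls_embeds' base loaded by load_cls_base() below could have
# been pre-computed with embed_bert_cls (an assumption about the offline step, not
# code taken from this app):
#
#   bi_model, bi_tok = load_bi_enc_model()
#   cls_embeds = [embed_bert_cls(answer, bi_model, bi_tok)
#                 for answer in load_dataset()['answer']]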


def get_ranked_docs(query, vec_query_base, data,
                    bi_model, bi_tok, cross_model, cross_tok):
    """Retrieve 50 candidate answers with the bi-encoder + faiss, then rerank them."""
    vec_shape = vec_query_base.shape[1]
    index = faiss.IndexFlatL2(vec_shape)
    index.add(vec_query_base)

    # retrieve the 50 nearest answers to the query embedding
    xq = embed_bert_cls(query, bi_model, bi_tok)
    D, I = index.search(xq.reshape(1, vec_shape), 50)
    corpus = [data['answer'][i] for i in I[0]]

    # score (query, candidate) pairs via the cross-encoder's CLS embeddings
    queries = [query] * len(corpus)
    tokenized_texts = cross_tok(
        queries, corpus, max_length=128, padding=True, truncation=True, return_tensors="pt"
    ).to(cross_model.device)  # keep inputs on the same device as the cross-encoder
    with torch.no_grad():
        ce_scores = cross_model(
            tokenized_texts['input_ids'], tokenized_texts['attention_mask']
        ).last_hidden_state[:, 0, :]
        ce_scores = ce_scores @ ce_scores.T

    scores = ce_scores.cpu().numpy()
    scores_ix = np.argsort(scores)[::-1]
    return corpus[scores_ix[0][0]]


def load_dataset(url='ekaterinatao/house_md_context3'):
    dataset = datasets.load_dataset(url, split='train')
    # keep only rows with label 0; downstream code uses column indexing such as
    # data['answer'][i], so return a filtered Dataset rather than a Python list
    house_dataset = dataset.filter(lambda row: row['labels'] == 0)
    return house_dataset


def load_cls_base(url='ekaterinatao/house_md_cls_embeds'):
    cls_dataset = datasets.load_dataset(url, split='train')
    # stack the stored CLS embeddings into a single matrix
    cls_base = np.stack([embed for embed in pd.DataFrame(cls_dataset)['cls_embeds']])
    return cls_base.astype(np.float32)  # faiss requires float32 vectors


def load_bi_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-bi-encoder'):
    bi_model = AutoModel.from_pretrained(checkpoint)
    bi_tok = AutoTokenizer.from_pretrained(checkpoint)
    return bi_model, bi_tok


def load_cross_enc_model(checkpoint='ekaterinatao/house-md-bot-bert-cross-encoder'):
    cross_model = AutoModel.from_pretrained(checkpoint)
    cross_tok = AutoTokenizer.from_pretrained(checkpoint)
    return cross_model, cross_tok


def get_answer(message):
    # Note: the datasets and both models are reloaded on every request, so calls are slow.
    dataset = load_dataset()
    cls_base = load_cls_base()
    bi_enc_model = load_bi_enc_model()
    cross_enc_model = load_cross_enc_model()
    answer = get_ranked_docs(
        query=message, vec_query_base=cls_base, data=dataset,
        bi_model=bi_enc_model[0], bi_tok=bi_enc_model[1],
        cross_model=cross_enc_model[0], cross_tok=cross_enc_model[1]
    )
    return answer
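
# Quick local check (the message below is just an illustrative example):
#
#   print(get_answer("I have a headache"))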


interface = gr.Interface(
    fn=get_answer,
    inputs=gr.Textbox(lines=3, label="Input message to House MD"),
    outputs=gr.Textbox(label="House MD's answer"),
    title=title,
    description=description,
)

interface.queue().launch(debug=True)