Spaces:
Runtime error
Runtime error
File size: 2,458 Bytes
da676c8 40c9d2b da676c8 12094be da676c8 eacbe96 da676c8 12094be f089045 da676c8 fa02d7f 6dd0ae0 12094be eacbe96 6dd0ae0 005c6a4 eacbe96 da676c8 eacbe96 da676c8 eacbe96 1ef9e65 eacbe96 1ef9e65 da676c8 eacbe96 6dd0ae0 eacbe96 4fd1747 12094be ac5b8a7 4fd1747 eacbe96 fa02d7f f089045 fa02d7f 38a8bac 12094be eacbe96 40c9d2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import streamlit as st
import pandas as pd
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
@st.cache(allow_output_mutation=True)
def _load_semantic_model(name='all-MiniLM-L6-v2'):
    """Load the sentence-embedding model used for semantic re-scoring.

    Streamlit re-executes this script on every widget interaction, so the
    loader is wrapped in ``st.cache`` to avoid re-instantiating the model on
    each rerun.

    Args:
        name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        The ``SentenceTransformer`` instance for ``name``.
    """
    return SentenceTransformer(name)


# Module-level name kept so the rest of the script can keep using it.
semantic_model = _load_semantic_model()
@st.cache(allow_output_mutation=True)
def get_model(model, top_k=100):
    """Build the fill-mask pipeline used to predict the next token.

    The function is decorated with ``st.cache`` so the pipeline is not
    rebuilt on every Streamlit rerun.

    Args:
        model: Model identifier passed to ``pipeline`` (the app offers
            "roberta-base" and "bert-base-uncased").
        top_k: Maximum number of candidate tokens retrieved per inference.
            Defaults to 100, the value previously hard-coded here.

    Returns:
        The object returned by ``pipeline("fill-mask", ...)``.
    """
    return pipeline("fill-mask", model=model, top_k=top_k)
# Multiplier applied to a candidate's score when its token matches one of the
# user's history keywords, so such candidates are prioritised in the ranking.
HISTORY_WEIGHT = 100

st.caption("This is a simple auto-completion where the next token is predicted per probability and a weigh if appears in user's history")

# --- user inputs -----------------------------------------------------------
history_keyword_text = st.text_input("Enter users's history keywords (optional, i.e., 'Gates')", value="")
text = st.text_input("Enter a text for auto completion...", value='Where is Bill')
semantic_text = st.text_input("Enter users's history semantic (optional, i.e., 'Microsoft or President')", value="Microsoft")
model = st.selectbox("choose a model", ["roberta-base", "bert-base-uncased"])

nlp = get_model(model)

if text:
    # Append the model's mask token so the pipeline predicts the next word.
    result = nlp(text + ' ' + nlp.tokenizer.mask_token)

    # Comma-separated semantic terms; empty entries (e.g. from a trailing
    # comma or a whitespace-only input) are dropped so they are not scored.
    sem_list = [term.strip() for term in semantic_text.split(',') if term.strip()]

    # History keywords as a set of whole, lower-cased words. Matching whole
    # words avoids the substring false positives of `token in history_text`
    # (e.g. the token "ate" matching the history "Gates").
    history_keywords = {
        word.strip(".;:!?\"'()")
        for word in history_keyword_text.lower().replace(',', ' ').split()
    }

    if sem_list:
        # Similarity of every predicted sequence to every semantic term:
        # one row per prediction, one column per term.
        predicted_seq = [rec['sequence'] for rec in result]
        predicted_embeddings = semantic_model.encode(predicted_seq, convert_to_tensor=True)
        semantic_history_embeddings = semantic_model.encode(sem_list, convert_to_tensor=True)
        cosine_scores = util.cos_sim(predicted_embeddings, semantic_history_embeddings)

    for index, candidate in enumerate(result):
        # Skip very short tokens such as "?" when adding the semantic bonus.
        if sem_list and len(candidate['token_str']) > 2:
            candidate['score'] += float(sum(cosine_scores[index]))

        token = candidate['token_str'].lower().strip()
        if len(token) > 1 and token in history_keywords:
            # Token found in the user's history: boost its score.
            candidate['score'] *= HISTORY_WEIGHT

    # Rank candidates by their adjusted score and show them as a table.
    df = pd.DataFrame(result).sort_values(by='score', ascending=False)
    st.table(df)