Spaces:
Runtime error
Runtime error
File size: 2,542 Bytes
6ae27e8 a41bdbc 6ae27e8 6e03e5d 6ae27e8 6e03e5d 0be3a1a 6e03e5d 6ae27e8 6e03e5d 6ae27e8 6e03e5d 6ae27e8 31f3439 6e03e5d 6ae27e8 6e03e5d 6ae27e8 6e03e5d a41bdbc 31f3439 6e03e5d 6ae27e8 6e03e5d 31f3439 6e03e5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import streamlit as st
import pandas as pd
from backend import inference
from backend.config import MODELS_ID
# Page header, sidebar task menu and the introductory text shown for every task.
st.title('Demo using Flax-Sentence-Transformers')
st.sidebar.title('Tasks')
# The radio label is intentionally empty: the sidebar title above already names it.
menu = st.sidebar.radio("", options=["Sentence Similarity", "Search", "Clustering"], index=0)
# Blank lines between paragraphs matter in Markdown: without one, the paragraph
# right above `---` is rendered as a heading and adjacent lines are merged.
st.markdown('''
Hi! This is the demo for the [flax sentence embeddings](https://huggingface.co/flax-sentence-embeddings) created for the **Flax/JAX community week 🤗**. We are going to use three flax-sentence-embeddings models: a **distilroberta base**, an **mpnet base** and a **minilm-l6**. All of them were trained on the full 1B+ training corpus with the v3 setup.

---

**Instructions**: You can compare the similarity of a main text with other texts of your choice. In the background, we'll create an embedding for each text, and then we'll use the cosine similarity function to calculate a similarity metric between our main sentence and the others.

For more cool information on sentence embeddings, see the [SBERT project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).

Please enjoy!!
''')
# Model picker shared by every task. It used to be the first widget of each
# branch below, so rendering it once before dispatching looks identical.
select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])

if menu == "Sentence Similarity":
    anchor = st.text_input(
        'Please enter here the main text you want to compare:'
    )
    # Keep this label constant: Streamlit identifies a widget by its label, so
    # interpolating `anchor` here reset the count to its default every time the
    # main text was edited.
    n_texts = st.number_input(
        'How many texts do you want to compare with the main text?',
        value=2,
        min_value=2)
    # One text box per comparison text; labels "Text 1", "Text 2", ... are unique.
    inputs = [st.text_input(f'Text {i + 1}:') for i in range(int(n_texts))]
    if st.button('Tell me the similarity.'):
        if not select_models:
            # With no model there is nothing to score and nothing to chart.
            st.warning('Please choose at least one model.')
        else:
            results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
            # Row labels: position plus the first 15 characters of each text;
            # the position prefix keeps labels unique when texts share a prefix.
            index = [f"{idx}:{text[:15]}..." for idx, text in enumerate(inputs)]
            df_total = pd.DataFrame(index=index)
            for model, result in results.items():
                # NOTE(review): assumes text_similarity returns a DataFrame with a
                # 'score' column holding one row per entry of `inputs`, in order.
                df_total[model] = list(result['score'].values)
            st.write('Here are the results for selected models:')
            st.write(df_total)
            st.write('Visualize the results of each model:')
            st.line_chart(df_total)
elif menu == "Search":
    # Placeholder: only the shared model picker above is rendered for this task.
    pass
elif menu == "Clustering":
    # Placeholder: only the shared model picker above is rendered for this task.
    pass