|
import os |
|
|
|
import streamlit as st |
|
from yasem import SpladeEmbedder |
|
|
|
if os.getenv("SPACE_ID"): |
|
USE_HF_SPACE = True |
|
os.environ["HF_HOME"] = "/data/.huggingface" |
|
os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface" |
|
else: |
|
USE_HF_SPACE = False |
|
|
|
MODEL_NAME = os.environ.get("MODEL_NAME", "hotchpotch/wip-tmp-model-base-v1-pre-2") |
|
|
|
|
|
@st.cache_resource |
|
def get_embedder(model_name: str = MODEL_NAME) -> SpladeEmbedder: |
|
revision = None |
|
if MODEL_NAME == "hotchpotch/wip-tmp-model-base-v1-pre-2": |
|
revision = "a7db67721ea22165faefba6f0b7ee726b0f3a78f" |
|
embedder = SpladeEmbedder( |
|
model_name, |
|
revision=revision, |
|
) |
|
return embedder |
|
|
|
|
|
def get_token_values_sorted(input_text: str) -> list[tuple[float, str]]: |
|
embedder = get_embedder() |
|
embeddings = embedder.encode([input_text]) |
|
token_values = embedder.get_token_values(embeddings[0]) |
|
sorted_tokens = sorted(token_values.items(), key=lambda item: item[1], reverse=True) |
|
return [(value, key) for key, value in sorted_tokens] |
|
|
|
|
|
def main(): |
|
st.set_page_config( |
|
page_title="SPLADE 日本語 demo", |
|
layout="centered", |
|
initial_sidebar_state="auto", |
|
) |
|
|
|
st.title("SPLADE 日本語 demo") |
|
|
|
st.markdown(""" |
|
### 入力 |
|
以下のテキストエリアに解析したいテキストを入力してください。 |
|
""") |
|
|
|
input_text = st.text_area("テキスト入力", height=200) |
|
|
|
if st.button("解析開始"): |
|
if input_text.strip(): |
|
with st.spinner("解析中..."): |
|
sorted_tokens = get_token_values_sorted(input_text) |
|
|
|
st.success("解析が完了しました。") |
|
|
|
st.markdown("### 結果") |
|
if sorted_tokens: |
|
formatted_data = [ |
|
{"頻度": freq, "単語": word} for freq, word in sorted_tokens |
|
] |
|
st.table(formatted_data) |
|
else: |
|
st.warning("入力テキストから有効な単語が見つかりませんでした。") |
|
else: |
|
st.warning("テキストを入力してください。") |
|
else: |
|
get_embedder() |
|
|
|
st.markdown("---") |
|
st.markdown("© 2024 SPLADE 日本語 demo") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|