File size: 2,288 Bytes
05edb6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import os
import streamlit as st
from yasem import SpladeEmbedder
if os.getenv("SPACE_ID"):
USE_HF_SPACE = True
os.environ["HF_HOME"] = "/data/.huggingface"
os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
USE_HF_SPACE = False
MODEL_NAME = os.environ.get("MODEL_NAME", "hotchpotch/wip-tmp-model-base-v1-pre-2")
@st.cache_resource
def get_embedder(model_name: str = MODEL_NAME) -> SpladeEmbedder:
revision = None
if MODEL_NAME == "hotchpotch/wip-tmp-model-base-v1-pre-2":
revision = "a7db67721ea22165faefba6f0b7ee726b0f3a78f"
embedder = SpladeEmbedder(
model_name,
revision=revision,
)
return embedder
def get_token_values_sorted(input_text: str) -> list[tuple[float, str]]:
embedder = get_embedder()
embeddings = embedder.encode([input_text])
token_values = embedder.get_token_values(embeddings[0])
sorted_tokens = sorted(token_values.items(), key=lambda item: item[1], reverse=True) # type: ignore
return [(value, key) for key, value in sorted_tokens]
def main():
st.set_page_config(
page_title="SPLADE 日本語 demo",
layout="centered",
initial_sidebar_state="auto",
)
st.title("SPLADE 日本語 demo")
st.markdown("""
### 入力
以下のテキストエリアに解析したいテキストを入力してください。
""")
input_text = st.text_area("テキスト入力", height=200)
if st.button("解析開始"):
if input_text.strip():
with st.spinner("解析中..."):
sorted_tokens = get_token_values_sorted(input_text)
st.success("解析が完了しました。")
st.markdown("### 結果")
if sorted_tokens:
formatted_data = [
{"頻度": freq, "単語": word} for freq, word in sorted_tokens
]
st.table(formatted_data)
else:
st.warning("入力テキストから有効な単語が見つかりませんでした。")
else:
st.warning("テキストを入力してください。")
else:
get_embedder()
st.markdown("---")
st.markdown("© 2024 SPLADE 日本語 demo")
if __name__ == "__main__":
main()
|