hotchpotch commited on
Commit
05edb6c
·
1 Parent(s): 2eba95c
Files changed (2) hide show
  1. app.py +77 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import streamlit as st
4
+ from yasem import SpladeEmbedder
5
+
6
+ if os.getenv("SPACE_ID"):
7
+ USE_HF_SPACE = True
8
+ os.environ["HF_HOME"] = "/data/.huggingface"
9
+ os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
10
+ else:
11
+ USE_HF_SPACE = False
12
+
13
+ MODEL_NAME = os.environ.get("MODEL_NAME", "hotchpotch/wip-tmp-model-base-v1-pre-2")
14
+
15
+
16
+ @st.cache_resource
17
+ def get_embedder(model_name: str = MODEL_NAME) -> SpladeEmbedder:
18
+ revision = None
19
+ if MODEL_NAME == "hotchpotch/wip-tmp-model-base-v1-pre-2":
20
+ revision = "a7db67721ea22165faefba6f0b7ee726b0f3a78f"
21
+ embedder = SpladeEmbedder(
22
+ model_name,
23
+ revision=revision,
24
+ )
25
+ return embedder
26
+
27
+
28
+ def get_token_values_sorted(input_text: str) -> list[tuple[float, str]]:
29
+ embedder = get_embedder()
30
+ embeddings = embedder.encode([input_text])
31
+ token_values = embedder.get_token_values(embeddings[0])
32
+ sorted_tokens = sorted(token_values.items(), key=lambda item: item[1], reverse=True) # type: ignore
33
+ return [(value, key) for key, value in sorted_tokens]
34
+
35
+
36
+ def main():
37
+ st.set_page_config(
38
+ page_title="SPLADE 日本語 demo",
39
+ layout="centered",
40
+ initial_sidebar_state="auto",
41
+ )
42
+
43
+ st.title("SPLADE 日本語 demo")
44
+
45
+ st.markdown("""
46
+ ### 入力
47
+ 以下のテキストエリアに解析したいテキストを入力してください。
48
+ """)
49
+
50
+ input_text = st.text_area("テキスト入力", height=200)
51
+
52
+ if st.button("解析開始"):
53
+ if input_text.strip():
54
+ with st.spinner("解析中..."):
55
+ sorted_tokens = get_token_values_sorted(input_text)
56
+
57
+ st.success("解析が完了しました。")
58
+
59
+ st.markdown("### 結果")
60
+ if sorted_tokens:
61
+ formatted_data = [
62
+ {"頻度": freq, "単語": word} for freq, word in sorted_tokens
63
+ ]
64
+ st.table(formatted_data)
65
+ else:
66
+ st.warning("入力テキストから有効な単語が見つかりませんでした。")
67
+ else:
68
+ st.warning("テキストを入力してください。")
69
+ else:
70
+ get_embedder()
71
+
72
+ st.markdown("---")
73
+ st.markdown("© 2024 SPLADE 日本語 demo")
74
+
75
+
76
+ if __name__ == "__main__":
77
+ main()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ yasem