schroneko's picture
add files
babe057
raw
history blame
3.52 kB
# app.py
from functools import lru_cache

import gradio as gr
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
# Dropdown label -> Hugging Face Hub repository ID passed to SentenceTransformer.
_MODEL_IDS = {
    "GLuCoSE-base-ja-v2": "pkshatech/GLuCoSE-base-ja-v2",
    # NOTE(review): the label says "-ja" but the repo ID has no such suffix;
    # confirm this is the intended repository.
    "RoSEtta-base-ja": "pkshatech/RoSEtta-base",
    "ruri-large": "cl-nagoya/ruri-large",
}

# Labels whose repositories are loaded with trust_remote_code=True.
_TRUST_REMOTE_CODE = frozenset({"RoSEtta-base-ja"})


# Loading a model is expensive; keep one instance per supported label instead
# of re-creating it on every request. Failed lookups raise and are not cached.
@lru_cache(maxsize=len(_MODEL_IDS))
def load_model(model_name):
    """Return the (cached) SentenceTransformer for a dropdown label.

    Args:
        model_name: One of the keys of ``_MODEL_IDS``.

    Returns:
        A ``SentenceTransformer`` instance; repeated calls with the same
        label return the same object.

    Raises:
        ValueError: If ``model_name`` is not a supported label.
    """
    try:
        repo_id = _MODEL_IDS[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown model {model_name!r}; expected one of {sorted(_MODEL_IDS)}"
        ) from None
    if model_name in _TRUST_REMOTE_CODE:
        return SentenceTransformer(repo_id, trust_remote_code=True)
    return SentenceTransformer(repo_id)
def get_similarities(model_name, sentences):
    """Embed ``sentences`` with the chosen model and return pairwise similarities.

    For "ruri-large" each sentence is first given a role prefix: even-indexed
    lines are treated as queries ("クエリ: ") and odd-indexed lines as passages
    ("文章: "). The pairwise matrix is then cosine similarity via
    ``torch.nn.functional.cosine_similarity``. The other models use their own
    ``model.similarity``.

    Args:
        model_name: Dropdown label understood by ``load_model``.
        sentences: List of sentence strings.

    Returns:
        The similarity matrix as a NumPy array (moved to CPU first).
    """
    model = load_model(model_name)

    if model_name == "ruri-large":
        role_prefixes = ("クエリ: ", "文章: ")
        sentences = [
            role_prefixes[position % 2] + text
            for position, text in enumerate(sentences)
        ]

    vectors = model.encode(sentences, convert_to_tensor=True)

    uses_builtin_similarity = model_name in ("GLuCoSE-base-ja-v2", "RoSEtta-base-ja")
    if uses_builtin_similarity:
        matrix = model.similarity(vectors, vectors)
    else:
        # Broadcast (1, n, d) against (n, 1, d) to compare every pair of rows.
        row_view = vectors.unsqueeze(0)
        col_view = vectors.unsqueeze(1)
        matrix = F.cosine_similarity(row_view, col_view, dim=2)

    return matrix.cpu().numpy()
def format_similarities(similarities):
    """Render a 2-D similarity matrix as plain text.

    Each row becomes one line. Values use four decimal places and are
    separated by single spaces.
    """

    def render_row(row):
        return " ".join(format(value, ".4f") for value in row)

    return "\n".join(map(render_row, similarities))
def process_input(model_name, input_text):
    """Gradio callback: split the textbox into sentences and return the formatted matrix.

    Args:
        model_name: Label selected in the model dropdown.
        input_text: Raw textbox contents, one sentence per line. Blank lines
            are ignored and each line is stripped.

    Returns:
        The similarity matrix rendered by ``format_similarities``. If no
        non-blank line was entered, a short prompt string is returned instead.
    """
    # Treat a missing value like empty text rather than crashing on .splitlines().
    sentences = [
        line.strip() for line in (input_text or "").splitlines() if line.strip()
    ]
    if not sentences:
        # Nothing to compare: skip loading/encoding, since an empty batch
        # would not yield a usable similarity matrix.
        return "Please enter at least one sentence."
    similarities = get_similarities(model_name, sentences)
    return format_similarities(similarities)
# Labels offered in the dropdown; each must be a name accepted by load_model.
models = ["GLuCoSE-base-ja-v2", "RoSEtta-base-ja", "ruri-large"]

with gr.Blocks() as demo:
    gr.Markdown("# Sentence Similarity Demo")
    with gr.Row():
        # Left column: model choice, sentence input, and the trigger button.
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=models, label="Select Model", value=models[0]
            )
            input_text = gr.Textbox(lines=5, label="Input Sentences (one per line)")
            submit_btn = gr.Button(value="Calculate Similarities")
        # Right column: the rendered similarity matrix.
        with gr.Column():
            output_text = gr.Textbox(label="Similarity Matrix", lines=10)

    submit_btn.click(
        process_input, inputs=[model_dropdown, input_text], outputs=output_text
    )

    # Clicking an example only fills in the inputs; no fn/outputs are given here.
    gr.Examples(
        examples=[
            [
                "GLuCoSE-base-ja-v2",
                "The weather is lovely today.\nIt's so sunny outside!\nHe drove to the stadium.",
            ],
            [
                "RoSEtta-base-ja",
                "The weather is lovely today.\nIt's so sunny outside!\nHe drove to the stadium.",
            ],
            [
                "ruri-large",
                "瑠璃色はどんな色?\n瑠璃色(るりいろ)は、紫みを帯びた濃い青。名は、半貴石の瑠璃(ラピスラズリ、英: lapis lazuli)による。JIS慣用色名では「こい紫みの青」(略号 dp-pB)と定義している[1][2]。\nワシやタカのように、鋭いくちばしと爪を持った大型の鳥類を総称して「何類」というでしょう?\nワシ、タカ、ハゲワシ、ハヤブサ、コンドル、フクロウが代表的である。これらの猛禽類はリンネ前後の時代(17~18世紀)には鷲類・鷹類・隼類及び梟類に分類された。ちなみにリンネは狩りをする鳥を単一の目(もく)にまとめ、vultur(コンドル、ハゲワシ)、falco(ワシ、タカ、ハヤブサなど)、strix(フクロウ)、lanius(モズ)の4属を含めている。",
            ],
        ],
        inputs=[model_dropdown, input_text],
    )

# Start the server only when run as a script. Importing this module
# (e.g. from tests or other tooling) then builds `demo` without side effects.
if __name__ == "__main__":
    demo.launch()