File size: 2,621 Bytes
efb129d
1abd701
efb129d
 
 
 
 
1abd701
efb129d
 
 
1abd701
efb129d
99ff800
efb129d
1abd701
efb129d
 
 
 
 
 
c0b2459
1abd701
efb129d
 
 
 
1abd701
efb129d
 
c0b2459
efb129d
1abd701
efb129d
 
 
1abd701
efb129d
 
 
 
 
 
c30e35d
efb129d
1abd701
efb129d
 
 
 
 
 
1abd701
efb129d
 
 
1abd701
 
 
efb129d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# CodeSearch-ModernBERT-Owl Demo Space using CodeSearchNet Dataset
import gradio as gr
import torch
import random
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from spaces import GPU

# --- Model: a single SentenceTransformer embeds both docstrings and code ---
model = SentenceTransformer("Shuu12121/CodeSearch-ModernBERT-Owl")
model.eval()  # the demo only runs inference

# --- Data: CodeSearchNet test split; rows from several languages share one
# table, so callers filter on the "language" column ---
dataset_all = load_dataset(
    "code_search_net",
    split="test",
    trust_remote_code=True,
)
# Languages the demo supports (the UI dropdown lists the same six).
lang_filter = "python java javascript ruby go php".split()

# --- Query sampling ---
def get_random_query(lang: str, seed: int = 42) -> tuple[str, str]:
    """Pick one reproducible random test-split function of language *lang*.

    Args:
        lang: Value matched against the dataset's ``language`` column.
        seed: Controls which row is chosen; the same ``(lang, seed)`` pair
            always selects the same row.

    Returns:
        ``(code, docstring)`` of the chosen row, with a missing (falsy) field
        replaced by ``""``.

    Raises:
        ValueError: If the dataset has no rows for *lang*.
    """
    # Only the "language" column is handed to the predicate, so the filter does
    # not have to decode the (large) code/docstring fields of every row.
    subset = dataset_all.filter(lambda row_lang: row_lang == lang, input_columns="language")
    if len(subset) == 0:
        raise ValueError(f"No CodeSearchNet test rows for language {lang!r}")
    # A private RNG reproduces exactly what random.seed(seed) + random.randint
    # would pick, without reseeding the process-wide `random` module.
    idx = random.Random(seed).randint(0, len(subset) - 1)
    sample = subset[idx]
    return sample["func_code_string"] or "", sample["func_documentation_string"] or ""

@GPU
def code_search_demo(lang: str, seed: int):
    """Run one docstring-to-code retrieval round and render it as Markdown.

    The query docstring comes from ``get_random_query(lang, seed)``. The
    candidate pool is the first 10 functions of a seeded shuffle of the same
    language. If that pool does not already contain the query's own function,
    the function replaces the last candidate. This way the ranking shows
    whether the true match comes out on top.

    Args:
        lang: CodeSearchNet language name, e.g. ``"python"``.
        seed: Seed for both the query pick and the candidate shuffle.

    Returns:
        Markdown: the query docstring, then the candidates ranked by cosine
        similarity (highest first). The ground-truth function is tagged with ✅.
    """
    num_candidates = 10
    seed = int(seed)  # Dataset.shuffle seeds NumPy, which rejects float seeds

    code_str, doc_str = get_random_query(lang, seed)
    query_emb = model.encode(doc_str, convert_to_tensor=True)

    # Same-language pool, shuffled deterministically by the seed.
    subset = dataset_all.filter(lambda row_lang: row_lang == lang, input_columns="language")
    pool = subset.shuffle(seed=seed).select(range(min(num_candidates, len(subset))))
    candidate_texts = [code or "" for code in pool["func_code_string"]]
    # A 10-row random sample almost never contains the query's own function;
    # without this swap, every "match" shown would be a wrong one.
    if code_str not in candidate_texts:
        candidate_texts[-1] = code_str
    candidate_embeddings = model.encode(candidate_texts, convert_to_tensor=True)

    # Cosine similarity of the single query against every candidate, as floats.
    cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    results = sorted(zip(candidate_texts, cos_scores), key=lambda pair: pair[1], reverse=True)

    # Render the ranking; the top three get medals, the rest "#N".
    medals = ["🥇", "🥈", "🥉"]
    parts = [
        f"### 🔍 Query Docstring (Language: {lang})\n\n" + doc_str + "\n\n",
        "## 🏆 Top Matches:\n",
    ]
    for rank, (code, score) in enumerate(results):
        label = medals[rank] if rank < len(medals) else f"#{rank + 1}"
        marker = " ✅" if code == code_str else ""
        # Real newlines (not "\\n") so the Markdown code fence actually renders.
        parts.append(
            f"\n**{label}**{marker} - Similarity: {score:.4f}\n\n```\n{code.strip()[:1000]}\n```\n"
        )
    return "".join(parts)

# --- Gradio Interface ---
demo = gr.Interface(
    fn=code_search_demo,
    inputs=[
        # Reuse the module-level language list (copied) instead of repeating the
        # literal, so the dropdown and the supported languages cannot drift apart.
        gr.Dropdown(list(lang_filter), label="Language", value="python"),
        gr.Slider(0, 100000, value=42, step=1, label="Random Seed"),
    ],
    outputs=gr.Markdown(label="Search Result"),
    title="🔎 CodeSearch-ModernBERT-Owl Demo",
    description="コードドキュメントから関数検索を行うデモ(CodeSearchNet + CodeModernBERT-Owl)"
)

# Start the web UI only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()