File size: 2,621 Bytes
efb129d 1abd701 efb129d 1abd701 efb129d 1abd701 efb129d 99ff800 efb129d 1abd701 efb129d c0b2459 1abd701 efb129d 1abd701 efb129d c0b2459 efb129d 1abd701 efb129d 1abd701 efb129d c30e35d efb129d 1abd701 efb129d 1abd701 efb129d 1abd701 efb129d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
# CodeSearch-ModernBERT-Owl Demo Space using CodeSearchNet Dataset
import gradio as gr
import torch
import random
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from spaces import GPU
# --- Model ---
# The sentence-embedding checkpoint is loaded once at import time and shared
# by every request handled below.
MODEL_ID = "Shuu12121/CodeSearch-ModernBERT-Owl"
model = SentenceTransformer(MODEL_ID)
model.eval()

# --- Data ---
# Only the "test" split of CodeSearchNet is loaded.
dataset_all = load_dataset(
    "code_search_net",
    split="test",
    trust_remote_code=True,
)

# Language names; the same six values appear in the UI's language dropdown.
lang_filter = ["python", "java", "javascript", "ruby", "go", "php"]
# --- Query sampling ---
def get_random_query(lang: str, seed: int = 42):
    """Pick one pseudo-random record of the given language.

    Args:
        lang: Value compared against each record's ``"language"`` field.
        seed: Seed for the draw. With an unchanged dataset, the same
            ``lang`` and ``seed`` select the same record.

    Returns:
        A ``(code, docstring)`` tuple taken from the record's
        ``"func_code_string"`` and ``"func_documentation_string"`` fields.
        A falsy field is returned as ``""``.

    Raises:
        ValueError: If no record in the dataset matches ``lang``.
    """
    subset = dataset_all.filter(lambda x: x["language"] == lang)
    if len(subset) == 0:
        raise ValueError(f"No samples found for language: {lang!r}")

    # A private generator yields the same draw as seeding the module-level
    # RNG, but does not reset the global random state for other callers.
    rng = random.Random(seed)
    sample = subset[rng.randint(0, len(subset) - 1)]
    return (
        sample["func_code_string"] or "",
        sample["func_documentation_string"] or "",
    )
@GPU
def code_search_demo(lang: str, seed: int):
    """Rank up to ten same-language functions against a random docstring.

    A docstring is drawn with ``get_random_query`` and embedded as the query.
    Up to ten functions of the same language are then drawn by shuffling the
    language subset with the same seed, embedded, and sorted by cosine
    similarity to the query.

    Args:
        lang: Value compared against each record's ``"language"`` field.
        seed: Seed used for both the query draw and the candidate shuffle.

    Returns:
        A Markdown string: the query docstring followed by the ranked
        candidates, each with its similarity score and at most the first
        1000 characters of its (stripped) code inside a fenced block.
    """
    # The slider value may arrive as a float; the dataset shuffle below is
    # given an integer seed.
    seed = int(seed)

    # Only the docstring is used as the query; the paired code is not needed.
    # NOTE(review): the candidates are drawn independently of the query, so
    # the function that owns this docstring is not guaranteed to be among
    # them - confirm that this is intended.
    _, doc_str = get_random_query(lang, seed)
    query_emb = model.encode(doc_str, convert_to_tensor=True)

    # Select up to ten functions of the same language as comparison targets.
    subset = dataset_all.filter(lambda x: x["language"] == lang)
    n_candidates = min(10, len(subset))
    candidates = subset.shuffle(seed=seed).select(range(n_candidates))
    candidate_texts = [c["func_code_string"] or "" for c in candidates]
    candidate_embeddings = model.encode(candidate_texts, convert_to_tensor=True)

    # Similarity computation; plain floats keep sorting and formatting simple.
    cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    ranked = sorted(
        zip(candidate_texts, cos_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )

    # Result formatting (with ranking). Real newlines are required here:
    # an escaped "\\n" would print a literal backslash-n and the Markdown
    # code fences would never open or close.
    medals = ["🥇", "🥈", "🥉"]
    sections = [
        f"### 🔍 Query Docstring (Language: {lang})\n\n{doc_str}\n\n",
        "## 🏆 Top Matches:\n",
    ]
    for rank, (code, score) in enumerate(ranked):
        label = medals[rank] if rank < len(medals) else f"#{rank + 1}"
        sections.append(
            f"\n**{label}** - Similarity: {score:.4f}\n\n"
            f"```\n{code.strip()[:1000]}\n```\n"
        )
    return "".join(sections)
# --- Gradio Interface ---
# Two inputs (language, seed) are passed to code_search_demo; the Markdown
# string it returns is rendered as the search result.
demo = gr.Interface(
    fn=code_search_demo,
    inputs=[
        # Reuse the module-level language list so the choices are defined in
        # one place instead of being repeated here.
        gr.Dropdown(lang_filter, label="Language", value="python"),
        gr.Slider(0, 100000, value=42, step=1, label="Random Seed"),
    ],
    outputs=gr.Markdown(label="Search Result"),
    title="🔎 CodeSearch-ModernBERT-Owl Demo",
    description="コードドキュメントから関数検索を行うデモ(CodeSearchNet + CodeModernBERT-Owl)",
)

if __name__ == "__main__":
    demo.launch()