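# Spaces: Runtime error
# (The line above appears to be page-status text captured together with this
# file; it is not part of the program.)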
# CodeSearch-ModernBERT-Owl Demo Space using CodeSearchNet Dataset
import gradio as gr | |
import torch | |
import random | |
from sentence_transformers import SentenceTransformer, util | |
from datasets import load_dataset | |
from spaces import GPU | |
# --- Load model ---
# Sentence-embedding model used to encode both the docstring query and the
# candidate code snippets.
model = SentenceTransformer("Shuu12121/CodeSearch-ModernBERT-Owl")
# The demo only runs inference, so switch the model to evaluation mode.
model.eval()

# --- Load CodeSearchNet dataset (test split only) ---
# Loaded once at import time; the functions below filter it per language.
dataset_all = load_dataset("code_search_net", split="test")

# Languages covered by the demo (rows are matched on their "language" column).
lang_filter = ["python", "java", "javascript", "ruby", "go", "php"]
# --- Query sampling ---
def get_random_query(lang: str, seed: int = 42):
    """Pick one pseudo-random (code, docstring) pair for ``lang``.

    Args:
        lang: Value compared against each row's ``"language"`` column.
        seed: Seed for the row choice; the same seed yields the same row.

    Returns:
        A ``(code, docstring)`` tuple of strings; a missing or empty field is
        returned as ``""``.

    Raises:
        ValueError: If the dataset has no rows for ``lang``.
    """
    subset = dataset_all.filter(lambda row: row["language"] == lang)
    if len(subset) == 0:
        raise ValueError(f"no samples available for language {lang!r}")

    # A private generator gives the same draw as random.seed(seed) followed by
    # random.randint(...), without reseeding the process-wide RNG.
    rng = random.Random(int(seed))
    sample = subset[rng.randint(0, len(subset) - 1)]

    # The original keys are tried first. The fallbacks are the column names
    # listed on the Hub "code_search_net" dataset card, where plain
    # "function"/"docstring" columns are not present.
    code = sample.get("function") or sample.get("func_code_string") or ""
    doc = sample.get("docstring") or sample.get("func_documentation_string") or ""
    return code, doc
def code_search_demo(lang: str, seed: int):
    """Rank sampled functions of ``lang`` against one sampled docstring.

    Args:
        lang: Value compared against each row's ``"language"`` column.
        seed: Seed used both to pick the query docstring and to shuffle the
            candidate pool.

    Returns:
        A Markdown string: the query docstring followed by up to ten candidate
        functions sorted by descending cosine similarity, each truncated to
        its first 1000 characters inside a fenced code block.
    """
    seed = int(seed)

    # Only the docstring serves as the query; the paired code is not used here.
    _, doc_str = get_random_query(lang, seed)
    query_emb = model.encode(doc_str, convert_to_tensor=True)

    # Pick ten randomly chosen functions of the same language as the
    # comparison targets (fewer if the language has fewer rows).
    # NOTE(review): the function that belongs to the query docstring is not
    # explicitly added to this pool, so the true match may be absent - confirm
    # whether that is intended.
    pool = dataset_all.filter(lambda row: row["language"] == lang).shuffle(seed=seed)
    candidates = pool.select(range(min(10, len(pool))))
    # Original key first; fallback is the column name listed on the Hub
    # "code_search_net" dataset card.
    candidate_texts = [
        row.get("function") or row.get("func_code_string") or ""
        for row in candidates
    ]
    candidate_embeddings = model.encode(candidate_texts, convert_to_tensor=True)

    # Similarity of the single query against every candidate.
    scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    ranked = sorted(
        zip(candidate_texts, scores), key=lambda pair: pair[1], reverse=True
    )

    # Format the result as a ranking: medals for the top three, "#n" after.
    medals = ["🥇", "🥈", "🥉"]
    sections = [
        f"### 🔍 Query Docstring (Language: {lang})\n\n" + doc_str + "\n\n",
        "## 🏆 Top Matches:\n",
    ]
    for rank, (code, score) in enumerate(ranked):
        label = medals[rank] if rank < len(medals) else f"#{rank + 1}"
        # Real newlines are required here: an escaped "\\n" would be emitted
        # as literal text and the Markdown code fence would never open.
        sections.append(
            f"\n**{label}** - Similarity: {score:.4f}\n\n"
            f"```\n{code.strip()[:1000]}\n```\n"
        )
    return "".join(sections)
# --- Gradio Interface ---
# Two inputs (language, random seed) feed code_search_demo; its Markdown
# string is rendered as the single output.
demo = gr.Interface(
    fn=code_search_demo,
    inputs=[
        # Reuse the module-level language list instead of repeating it.
        gr.Dropdown(lang_filter, label="Language", value="python"),
        gr.Slider(0, 100000, value=42, step=1, label="Random Seed"),
    ],
    outputs=gr.Markdown(label="Search Result"),
    title="🔎 CodeSearch-ModernBERT-Owl Demo",
    description="コードドキュメントから関数検索を行うデモ(CodeSearchNet + CodeModernBERT-Owl)",
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()