|
|
|
import gradio as gr |
|
import torch |
|
import random |
|
from sentence_transformers import SentenceTransformer, util |
|
from datasets import load_dataset |
|
from spaces import GPU |
|
|
|
|
|
# Load the sentence-embedding model once at import time so that every request
# handled by the app reuses the same instance.
model = SentenceTransformer("Shuu12121/CodeSearch-ModernBERT-Owl")
# Switch the model to evaluation mode before it is used for encoding.
model.eval()

# Test split of the "code_search_net" dataset; rows are filtered by their
# "language" field at request time.
dataset_all = load_dataset("code_search_net", split="test")

# Language names the demo works with (same values as the UI dropdown choices).
lang_filter = ["python", "java", "javascript", "ruby", "go", "php"]
|
|
|
|
|
def get_random_query(lang: str, seed: int = 42):
    """Pick one pseudo-random sample of ``lang`` from ``dataset_all``.

    Args:
        lang: Value matched against each row's ``"language"`` field.
        seed: Seed for the index draw; the same seed and language always
            select the same row.

    Returns:
        A ``(function, docstring)`` tuple of strings. A missing or empty
        field is returned as ``""``.

    Raises:
        ValueError: If no row in ``dataset_all`` has the requested language.
    """
    subset = dataset_all.filter(lambda row: row["language"] == lang)
    if len(subset) == 0:
        raise ValueError(f"No samples found for language: {lang!r}")

    # A private generator draws the same index that random.seed(seed) followed
    # by random.randint(...) would, without resetting the process-wide RNG.
    rng = random.Random(seed)
    sample = subset[rng.randint(0, len(subset) - 1)]

    # NOTE(review): the Hub card for code_search_net lists the code/docstring
    # columns as "func_code_string" / "func_documentation_string"; fall back to
    # those when the short names are absent — confirm against the loaded schema.
    code = sample.get("function") or sample.get("func_code_string") or ""
    doc = sample.get("docstring") or sample.get("func_documentation_string") or ""
    return code, doc
|
|
|
@GPU
def code_search_demo(lang: str, seed: int):
    """Rank random ``lang`` functions against a randomly chosen docstring.

    A docstring is drawn with ``get_random_query`` and embedded as the query.
    Up to ten functions of the same language are then drawn by a seeded
    shuffle, embedded, and ordered by cosine similarity to the query.

    Args:
        lang: Language name used to filter ``dataset_all``.
        seed: Seed used for both the query draw and the candidate shuffle.

    Returns:
        A Markdown string with the query docstring followed by the candidates
        from most to least similar; each candidate is truncated to 1000
        characters and shown in a fenced code block.
    """
    # The value is used as an RNG seed below; coerce it so a float-valued
    # reading cannot reach shuffle() — TODO confirm whether the UI ever sends one.
    seed = int(seed)

    # NOTE(review): the sampled function itself is discarded and is not added
    # to the candidate pool, so the true match for the query may be absent
    # from the ranking.
    _, doc_str = get_random_query(lang, seed)
    query_emb = model.encode(doc_str, convert_to_tensor=True)

    subset = dataset_all.filter(lambda row: row["language"] == lang)
    # Never ask for more rows than the language actually has.
    pool_size = min(10, len(subset))
    candidates = subset.shuffle(seed=seed).select(range(pool_size))
    # NOTE(review): falls back to the "func_code_string" column named on the
    # Hub card for code_search_net when "function" is absent — confirm against
    # the loaded schema.
    candidate_texts = [
        row.get("function") or row.get("func_code_string") or ""
        for row in candidates
    ]
    candidate_embeddings = model.encode(candidate_texts, convert_to_tensor=True)

    # Row 0 holds the similarities of the single query against every
    # candidate; convert once to plain floats for sorting and formatting.
    cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    ranked = sorted(
        zip(candidate_texts, cos_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )

    medals = ["🥇", "🥈", "🥉"]
    parts = [
        f"### 🔍 Query Docstring (Language: {lang})\n\n{doc_str}\n\n",
        "## 🏆 Top Matches:\n",
    ]
    for rank, (code, score) in enumerate(ranked):
        # Medals for the top three, "#4", "#5", ... afterwards.
        label = medals[rank] if rank < len(medals) else f"#{rank + 1}"
        parts.append(
            f"\n**{label}** - Similarity: {score:.4f}\n\n"
            f"```\n{code.strip()[:1000]}\n```\n"
        )

    return "".join(parts)
|
|
|
|
|
# Gradio UI: a language dropdown and a seed slider feed code_search_demo, and
# the Markdown string it returns is rendered as the single output.
# The description is Japanese user-facing text, roughly: "Demo that searches
# for functions from code documentation".
demo = gr.Interface(
    fn=code_search_demo,
    inputs=[
        # Reuse the module-level language list instead of repeating it.
        gr.Dropdown(lang_filter, label="Language", value="python"),
        gr.Slider(0, 100000, value=42, step=1, label="Random Seed"),
    ],
    outputs=gr.Markdown(label="Search Result"),
    title="🔎 CodeSearch-ModernBERT-Owl Demo",
    description="コードドキュメントから関数検索を行うデモ(CodeSearchNet + CodeModernBERT-Owl)",
)


# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()