File size: 2,559 Bytes
efb129d
1abd701
efb129d
 
 
 
 
1abd701
efb129d
 
 
1abd701
efb129d
 
 
1abd701
efb129d
 
 
 
 
 
 
1abd701
efb129d
 
 
 
1abd701
efb129d
 
 
 
1abd701
efb129d
 
 
1abd701
efb129d
 
 
 
 
 
 
 
 
1abd701
efb129d
1abd701
efb129d
 
 
 
 
 
1abd701
efb129d
 
 
1abd701
 
 
efb129d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# CodeSearch-ModernBERT-Owl Demo Space using CodeSearchNet Dataset
import gradio as gr
import torch
import random
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from spaces import GPU

# --- Embedding model ---------------------------------------------------------
# The encoder is only used for inference, so it is switched to eval mode
# immediately after loading.
_MODEL_ID = "Shuu12121/CodeSearch-ModernBERT-Owl"
model = SentenceTransformer(_MODEL_ID)
model.eval()

# --- CodeSearchNet data ------------------------------------------------------
# Only the test split is loaded; per-language subsets are filtered on demand.
dataset_all = load_dataset("code_search_net", split="test")
# Supported language names (same values and order as the UI dropdown choices).
lang_filter = "python java javascript ruby go php".split()

# --- UI for language choice ---
def get_random_query(lang: str, seed: int = 42):
    """Pick a reproducible random sample of one language from the test split.

    Args:
        lang: Value matched against each row's ``"language"`` column.
        seed: Seed for the pick; the same ``(lang, seed)`` pair always selects
            the same row.

    Returns:
        A ``(code, docstring)`` tuple for the chosen row; missing (``None``)
        values are returned as ``""``.

    Raises:
        ValueError: If no row has ``language == lang``.
    """
    subset = dataset_all.filter(lambda x: x["language"] == lang)
    if len(subset) == 0:
        raise ValueError(f"No samples found for language {lang!r}")

    # A private RNG reproduces exactly the pick that seeding the global
    # ``random`` module would give, without clobbering global random state.
    rng = random.Random(seed)
    sample = subset[rng.randint(0, len(subset) - 1)]

    # NOTE(review): the Hugging Face ``code_search_net`` loader names these
    # columns ``func_code_string`` / ``func_documentation_string``; fall back
    # to them when the ``function`` / ``docstring`` keys are absent.
    code = sample.get("function", sample.get("func_code_string"))
    doc = sample.get("docstring", sample.get("func_documentation_string"))
    return code or "", doc or ""

@GPU
def code_search_demo(lang: str, seed: int):
    """Run one docstring-to-code retrieval round and render it as Markdown.

    A docstring picked by ``get_random_query(lang, seed)`` is the query. The
    candidate pool is that docstring's own function (the ground truth) plus up
    to nine other same-language functions (excluding copies of the ground
    truth) taken from a seeded shuffle of the dataset. Candidates are ranked
    by cosine similarity between their embeddings and the query embedding.

    Args:
        lang: Value matched against each row's ``"language"`` column.
        seed: Seed for both the query pick and the distractor shuffle.

    Returns:
        Markdown with the query docstring followed by the ranked candidates;
        the ground-truth function is flagged in the ranking.
    """
    # NOTE(review): Gradio sliders may deliver the seed as a float; the dataset
    # shuffle expects an integer seed, so normalize it here.
    seed = int(seed)
    num_candidates = 10

    code_str, doc_str = get_random_query(lang, seed)
    query_emb = model.encode(doc_str, convert_to_tensor=True)

    # Build the candidate pool with the ground truth first. Drawing the pool
    # independently of the query (a plain shuffle) would usually leave the
    # correct function out, which makes the ranking meaningless.
    candidate_texts = [code_str]
    pool = dataset_all.filter(lambda x: x["language"] == lang).shuffle(seed=seed)
    for row in pool:
        # Stop once the pool is full; also copes with languages that have
        # fewer rows than ``num_candidates``.
        if len(candidate_texts) >= num_candidates:
            break
        # NOTE(review): the Hugging Face ``code_search_net`` loader names this
        # column ``func_code_string``; fall back to it if ``function`` is absent.
        text = row.get("function", row.get("func_code_string")) or ""
        if text != code_str:  # skip the ground truth and exact copies of it
            candidate_texts.append(text)

    # Similarity between the query and every candidate.
    candidate_embeddings = model.encode(candidate_texts, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]

    # Rank candidate indices by score, highest first. sorted() is stable even
    # with reverse=True, so ties keep pool order; keeping indices lets the
    # ground truth (index 0) be flagged after sorting.
    order = sorted(
        range(len(candidate_texts)),
        key=lambda i: cos_scores[i].item(),
        reverse=True,
    )

    # Format the ranked results.
    medals = ["🥇", "🥈", "🥉"]
    parts = [
        f"### 🔍 Query Docstring (Language: {lang})\n\n",
        doc_str,
        "\n\n",
        "## 🏆 Top Matches:\n",
    ]
    for rank, idx in enumerate(order):
        label = medals[rank] if rank < len(medals) else f"#{rank + 1}"
        truth = " ✅ (ground truth)" if idx == 0 else ""
        parts.append(
            f"\n**{label}** - Similarity: {cos_scores[idx].item():.4f}{truth}\n\n"
            f"```\n{candidate_texts[idx].strip()[:1000]}\n```\n"
        )
    return "".join(parts)

# --- Gradio Interface ---
# Two inputs (language, seed) feed ``code_search_demo`` positionally; its
# Markdown result is rendered in the output panel.
demo = gr.Interface(
    fn=code_search_demo,
    inputs=[
        # Reuse the module-level language list instead of duplicating it, so
        # the dropdown choices cannot drift from ``lang_filter``.
        gr.Dropdown(lang_filter, label="Language", value="python"),
        gr.Slider(0, 100000, value=42, step=1, label="Random Seed"),
    ],
    outputs=gr.Markdown(label="Search Result"),
    title="🔎 CodeSearch-ModernBERT-Owl Demo",
    description="コードドキュメントから関数検索を行うデモ(CodeSearchNet + CodeModernBERT-Owl)",
)

def _main() -> None:
    """Launch the Gradio demo server."""
    demo.launch()


if __name__ == "__main__":
    _main()