File size: 2,920 Bytes
efb129d
1abd701
efb129d
 
 
 
 
1abd701
efb129d
 
 
1abd701
efb129d
066cc0b
1abd701
182358f
 
efb129d
182358f
 
 
1abd701
efb129d
182358f
 
efb129d
1abd701
6396265
182358f
6396265
 
 
182358f
 
1abd701
efb129d
182358f
1abd701
6396265
 
 
 
 
 
 
 
 
 
182358f
6396265
 
efb129d
6396265
efb129d
 
 
6396265
 
 
 
 
182358f
efb129d
1abd701
182358f
efb129d
 
182358f
efb129d
 
182358f
1abd701
 
 
182358f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# CodeSearch-ModernBERT-Owl Demo Space using CodeSearchNet Dataset
import gradio as gr
import torch
import random
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from spaces import GPU

# --- Embedding model ---------------------------------------------------------
# One bi-encoder embeds both the natural-language query and the code
# candidates; it is put into eval mode once, at import time.
_MODEL_ID = "Shuu12121/CodeSearch-ModernBERT-Owl"
model = SentenceTransformer(_MODEL_ID)
model.eval()

# --- Search corpus -----------------------------------------------------------
# Only the "test" split of the CodeXGLUE NL-to-code search (adv) dataset is
# loaded; queries and candidate snippets are both drawn from it.
# NOTE(review): trust_remote_code=True runs the dataset's loading script from
# the Hub; confirm it is still needed with the installed `datasets` version.
_DATASET_ID = "code_x_glue_tc_nl_code_search_adv"
dataset = load_dataset(
    _DATASET_ID,
    split="test",
    trust_remote_code=True,
)

# --- Query & Candidate Generator ---
def get_random_query(seed: int = 42, data=None):
    """Pick one example reproducibly and return it as ``(code, docstring)``.

    Args:
        seed: Seed for the choice; the same seed always selects the same row.
        data: Optional indexable collection of mappings with ``"code"`` and
            ``"docstring"`` keys. Defaults to the module-level ``dataset``.

    Returns:
        A ``(code, docstring)`` tuple for the selected example.

    Raises:
        ValueError: If ``data`` is empty (no row can be chosen).
    """
    if data is None:
        data = dataset
    # A private Random instance yields exactly the same index as
    # ``random.seed(seed); random.randint(...)`` but leaves the process-wide
    # RNG alone, so other users of ``random`` (including concurrently handled
    # calls) are neither disturbed nor able to disturb this draw.
    rng = random.Random(seed)
    idx = rng.randint(0, len(data) - 1)
    sample = data[idx]
    return sample["code"], sample["docstring"]

def _sample_distractor_codes(query_code: str, seed: int, count: int) -> list:
    """Return up to ``count`` snippets drawn reproducibly from ``dataset`` that differ from ``query_code``."""
    rng = random.Random(seed)
    n_rows = len(dataset)
    # Draw one spare row so that, if the query's own row (or an identical
    # snippet) is drawn and filtered out below, ``count`` distractors remain.
    indices = rng.sample(range(n_rows), min(n_rows, count + 1))
    target = query_code.strip()
    codes = [dataset[i]["code"] for i in indices]
    return [code for code in codes if code.strip() != target][:count]


def _reciprocal_rank(ranked_codes, correct_code: str) -> float:
    """Return 1/rank of the first snippet equal to ``correct_code`` (whitespace-stripped), or 0.0 if absent."""
    target = correct_code.strip()
    for rank, code in enumerate(ranked_codes, start=1):
        if code.strip() == target:
            return 1.0 / rank
    return 0.0


@GPU
def code_search_demo(seed: int):
    """Run one docstring -> code retrieval round and render it as Markdown.

    ``get_random_query(seed)`` picks the query example; its code is the ground
    truth. That code is ranked by cosine similarity against up to nine other
    snippets sampled from ``dataset``. The report states whether the ground
    truth reached the top-k and gives its reciprocal rank (MRR@k for a single
    query).

    Args:
        seed: Seed controlling both the query and the distractor sample. Any
            integer is accepted; it is not used as a row index.

    Returns:
        A Markdown string for the Gradio output component.
    """
    seed = int(seed)  # defensive: accept a float-valued slider input
    num_candidates = 10
    top_k = 10

    code_str, doc_str = get_random_query(seed)
    query_emb = model.encode(doc_str, convert_to_tensor=True)

    # The candidate pool is the sampled distractors plus the query's own code,
    # so it always contains exactly one correct answer. The answer is appended
    # last; sorted() is stable, so on an exact score tie it ranks below a
    # distractor (a conservative choice).
    candidate_codes = _sample_distractor_codes(code_str, seed, num_candidates - 1)
    candidate_codes.append(code_str)

    candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)
    # cos_sim returns a similarity matrix; row 0 holds this query's scores.
    # Convert to plain floats for sorting and formatting.
    cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    results = sorted(zip(candidate_codes, cos_scores), key=lambda pair: pair[1], reverse=True)

    # Check whether the correct code appears in the top-k.
    mrr = _reciprocal_rank((code for code, _ in results[:top_k]), code_str)
    correct_in_top_k = mrr > 0.0

    # Build the Markdown report.
    target = code_str.strip()
    medals = ("🥇", "🥈", "🥉")
    parts = [
        f"### 🔍 Query Docstring\n\n{doc_str}\n\n",
        f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n",
        f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n",
        "## 🏆 Top Matches:\n",
    ]
    for i, (code, score) in enumerate(results):
        label = medals[i] if i < len(medals) else f"#{i+1}"
        is_correct = "✅" if code.strip() == target else ""
        parts.append(
            f"\n**{label}** - Similarity: {score:.4f} {is_correct}\n\n```python\n{code.strip()[:1000]}\n```\n"
        )

    return "".join(parts)

# --- Gradio UI ---------------------------------------------------------------
# A single integer seed drives the whole search round; the rendered Markdown
# report is the only output.
_seed_input = gr.Slider(minimum=0, maximum=100000, value=42, step=1, label="Random Seed")
_result_output = gr.Markdown(label="Search Result")

demo = gr.Interface(
    fn=code_search_demo,
    inputs=_seed_input,
    outputs=_result_output,
    title="🔎 CodeSearch-ModernBERT-Owl Demo",
    description="docstring から類似 Python 関数を検索(CodeXGlue + ModernBERT-Owl)",
)


def _main() -> None:
    """Start the Gradio server for this demo."""
    demo.launch()


if __name__ == "__main__":
    _main()