Shuu12121 committed on
Commit
182358f
·
verified ·
1 Parent(s): c0b2459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -26
app.py CHANGED
@@ -11,51 +11,47 @@ model = SentenceTransformer("Shuu12121/CodeSearch-ModernBERT-Owl")
11
  model.eval()
12
 
13
  # --- Load CodeSearchNet dataset (test split only) ---
14
- dataset_all = load_dataset("code_search_net", split="test", trust_remote_code=True)
15
- lang_filter = ["python", "java", "javascript", "ruby", "go", "php"]
16
 
17
- # --- UI for language choice ---
18
- def get_random_query(lang: str, seed: int = 42):
19
- subset = dataset_all.filter(lambda x: x["language"] == lang)
20
  random.seed(seed)
21
- idx = random.randint(0, len(subset) - 1)
22
- sample = subset[idx]
23
- return sample["func_code_string"] or "", sample["func_documentation_string"] or ""
24
 
25
  @GPU
26
- def code_search_demo(lang: str, seed: int):
27
- code_str, doc_str = get_random_query(lang, seed)
28
  query_emb = model.encode(doc_str, convert_to_tensor=True)
29
 
30
- # ランダムに取得した同一言語の10件の関数とドキュメントを比較対象として選択
31
- candidates = dataset_all.filter(lambda x: x["language"] == lang).shuffle(seed=seed).select(range(10))
32
- candidate_texts = [c["func_code_string"] or "" for c in candidates]
33
- candidate_embeddings = model.encode(candidate_texts, convert_to_tensor=True)
34
 
35
- # 類似度計算
36
  cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
37
- results = sorted(zip(candidate_texts, cos_scores), key=lambda x: x[1], reverse=True)
38
 
39
- # 結果フォーマット(ランキング付き)
40
- output = f"### 🔍 Query Docstring (Language: {lang})\n\n" + doc_str + "\n\n"
41
  output += "## 🏆 Top Matches:\n"
42
  medals = ["🥇", "🥈", "🥉"] + [f"#{i+1}" for i in range(3, len(results))]
43
  for i, (code, score) in enumerate(results):
44
  label = medals[i] if i < len(medals) else f"#{i+1}"
45
- output += f"\n**{label}** - Similarity: {score.item():.4f}\\n\\n```\\n{code.strip()[:1000]}\\n```\\n"
 
46
  return output
47
 
48
- # --- Gradio Interface ---
49
  demo = gr.Interface(
50
  fn=code_search_demo,
51
- inputs=[
52
- gr.Dropdown(["python", "java", "javascript", "ruby", "go", "php"], label="Language", value="python"),
53
- gr.Slider(0, 100000, value=42, step=1, label="Random Seed")
54
- ],
55
  outputs=gr.Markdown(label="Search Result"),
56
  title="🔎 CodeSearch-ModernBERT-Owl Demo",
57
- description="コードドキュメントから関数検索を行うデモ(CodeSearchNet + CodeModernBERT-Owl)"
58
  )
59
 
60
  if __name__ == "__main__":
61
- demo.launch()
 
11
  model.eval()
12
 
13
  # --- Load CodeSearchNet dataset (test split only) ---
14
# Request the test split explicitly. Without `split=...`, load_dataset returns
# a DatasetDict keyed by split name, so len() would count splits and the
# integer indexing / .select() calls used below would fail.
dataset = load_dataset("code_x_glue_tc_nl_code_search_adv", split="test", trust_remote_code=True)
 
15
 
16
+ # --- Query & Candidate Generator ---
17
def get_random_query(seed: int = 42):
    """Pick one sample from ``dataset`` deterministically from ``seed``.

    Args:
        seed: Seed for the random index; the same seed always selects the
            same sample.

    Returns:
        A ``(code, docstring)`` pair taken from the sample's ``"code"`` and
        ``"docstring"`` fields.
    """
    # A private generator yields the same draw as random.seed(seed) followed
    # by random.randint(...), but does not reseed the process-wide RNG.
    rng = random.Random(seed)
    idx = rng.randint(0, len(dataset) - 1)
    sample = dataset[idx]
    return sample["code"], sample["docstring"]
22
 
23
  @GPU
24
def code_search_demo(seed: int) -> str:
    """Rank ten sampled functions against a randomly chosen docstring.

    Args:
        seed: Seed used both to pick the query sample and to shuffle the
            dataset before the ten candidates are taken.

    Returns:
        Markdown text: the query docstring followed by the candidates in
        descending order of cosine similarity, each truncated to 1000
        characters.
    """
    # Only the docstring serves as the query; the paired code is not used.
    _, doc_str = get_random_query(seed)
    query_emb = model.encode(doc_str, convert_to_tensor=True)

    # Take ten candidates from a seeded shuffle of the dataset.
    # NOTE(review): the query's own function is not guaranteed to be among
    # these candidates -- confirm whether that is intended for the demo.
    candidates = dataset.shuffle(seed=seed).select(range(10))
    candidate_codes = [c["code"] for c in candidates]
    candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)

    # Similarity of the single query against every candidate.
    cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
    results = sorted(zip(candidate_codes, cos_scores), key=lambda pair: pair[1], reverse=True)

    # Assemble the Markdown report; pieces are joined once at the end.
    medals = ["🥇", "🥈", "🥉"]
    sections = [
        f"### 🔍 Query Docstring\n\n{doc_str}\n\n",
        "## 🏆 Top Matches:\n",
    ]
    for rank, (code, score) in enumerate(results):
        # The top three get medals; later ranks get a numeric label.
        label = medals[rank] if rank < len(medals) else f"#{rank + 1}"
        sections.append(
            f"\n**{label}** - Similarity: {score.item():.4f}\n\n```python\n{code.strip()[:1000]}\n```\n"
        )

    return "".join(sections)
46
 
47
# --- Gradio UI ---
# One input (the random seed) mapped to one Markdown output.
demo = gr.Interface(
    fn=code_search_demo,
    inputs=gr.Slider(minimum=0, maximum=100000, value=42, step=1, label="Random Seed"),
    outputs=gr.Markdown(label="Search Result"),
    title="🔎 CodeSearch-ModernBERT-Owl Demo",
    description="docstring から類似 Python 関数を検索(CodeXGlue + ModernBERT-Owl)",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()