Spaces: Running on Zero

zamalali committed · Commit 8d67bd2
Parent(s): 7c0e46f

Added changes to enhance the metrics

Files changed:
- __pycache__/main.cpython-313.pyc  +0 -0
- app.py   +36 -21
- main.py  +127 -36

__pycache__/main.cpython-313.pyc
ADDED
Binary file (17.6 kB).
app.py
CHANGED
@@ -2,9 +2,8 @@ import gradio as gr
 import time
 import threading
 import logging
-import spaces
 from main import run_repository_ranking  # Your repository ranking function
-
+import spaces
 # ---------------------------
 # Global Logging Buffer Setup
 # ---------------------------
@@ -38,10 +37,17 @@ def filter_logs(logs):
     return filtered
 
 def parse_result_to_html(raw_result: str) -> str:
-
+    """
+    Parses the raw string output from run_repository_ranking to an HTML table.
+    Only the top 10 results are displayed.
+    """
     entries = raw_result.strip().split("Final Rank:")
+    # Only use the first 10 entries (if available)
+    entries = entries[1:11]
+    if not entries:
+        return "<p>No repositories found for your query.</p>"
     html = """
-    <table border="1" style="width:
+    <table border="1" style="width:80%; margin: auto; border-collapse: collapse;">
     <thead>
       <tr>
         <th>Rank</th>
@@ -52,10 +58,10 @@ def parse_result_to_html(raw_result: str) -> str:
     </thead>
     <tbody>
     """
-    for entry in entries
+    for entry in entries:
         lines = entry.strip().split("\n")
         data = {}
-        data["Final Rank"] = lines[0].strip()
+        data["Final Rank"] = lines[0].strip() if lines else ""
         for line in lines[1:]:
             if ": " in line:
                 key, val = line.split(": ", 1)
@@ -115,9 +121,18 @@ def lite_runner(topic):
     yield status, details
 
 # ---------------------------
-# App UI Setup Using Gradio Soft Theme
+# App UI Setup Using Gradio Soft Theme with Centered Layout
 # ---------------------------
-with gr.Blocks(theme="gstaff/sketch", title="DeepGit Lite", fill_width=True) as demo:
+with gr.Blocks(
+    theme="gstaff/sketch",
+    title="DeepGit Lite",
+    css="""
+    /* Center header and footer */
+    #header { text-align: center; margin-bottom: 20px; }
+    #main-container { max-width: 800px; margin: auto; }
+    #footer { text-align: center; margin-top: 20px; }
+    """
+) as demo:
     gr.HTML(
         """
         <head>
@@ -131,20 +146,20 @@ with gr.Blocks(theme="gstaff/sketch", title="DeepGit Lite", fill_width=True) as demo:
         # DeepGit Lite
         Explore GitHub repositories with deep semantic search.
         Check out our [GitHub](https://github.com/zamalali/DeepGit) for more details.
-        """
+        """,
+        elem_id="header"
     )
 
-
-
-
-
-
-
-
-
-
-
-    detail_display = gr.HTML(label="Results")
+    # Centered main container for inputs and outputs.
+    with gr.Column(elem_id="main-container"):
+        research_input = gr.Textbox(
+            label="Research Query",
+            placeholder="Enter your research topic here, e.g., 'data augmentation pipelines for LLM fine-tuning'",
+            lines=3
+        )
+        run_button = gr.Button("Run DeepGit Lite", variant="primary")
+        status_display = gr.Markdown(label="Status")
+        detail_display = gr.HTML(label="Results")
 
     run_button.click(
         fn=lite_runner,
@@ -164,7 +179,7 @@ with gr.Blocks(theme="gstaff/sketch", title="DeepGit Lite", fill_width=True) as demo:
 
     gr.HTML(
         """
-        <div>
+        <div id="footer">
             Made with ❤️ by <b>Zamal</b>
         </div>
         """
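Note: the following is a minimal standalone sketch, not part of the commit, showing how the new entry slicing in parse_result_to_html behaves. It assumes the "Final Rank:" / "Title:" / "Link:" line format that run_repository_ranking in main.py emits; the repository names and URLs below are made up.

# Sketch: replicate the new parsing logic on a toy ranking string.
raw_result = (
    "\n=== Ranked Repositories ===\n"
    "Final Rank: 1\nTitle: repo-a\nLink: https://github.com/org/repo-a\n"
    "Combined Score: 0.9123\n" + "-" * 80 + "\n"
    "Final Rank: 2\nTitle: repo-b\nLink: https://github.com/org/repo-b\n"
    "Combined Score: 0.8456\n" + "-" * 80 + "\n"
)

entries = raw_result.strip().split("Final Rank:")
entries = entries[1:11]  # entries[0] is the header chunk; keep at most 10 results

for entry in entries:
    lines = entry.strip().split("\n")
    data = {"Final Rank": lines[0].strip() if lines else ""}
    for line in lines[1:]:
        if ": " in line:  # "key: value" rows become table cells
            key, val = line.split(": ", 1)
            data[key.strip()] = val.strip()
    print(data)  # e.g. {'Final Rank': '1', 'Title': 'repo-a', ...}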
main.py
CHANGED
@@ -4,9 +4,10 @@ import requests
 import numpy as np
 import faiss
 import re
-from sentence_transformers import SentenceTransformer
-from dotenv import load_dotenv
+import logging
 from pathlib import Path
+from dotenv import load_dotenv
+from sentence_transformers import SentenceTransformer, CrossEncoder
 from langchain_groq import ChatGroq
 from langchain_core.prompts import ChatPromptTemplate
 
@@ -20,8 +21,10 @@ except ImportError:
 # Environment Setup
 # ---------------------------
 load_dotenv()
+# Set the cross-encoder model from environment or use a default SOTA model.
+CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")
 
-# Setup a persistent session for GitHub API requests
+# Setup a persistent session for GitHub API requests.
 session = requests.Session()
 session.headers.update({
     "Authorization": f"token {os.getenv('GITHUB_API_KEY')}",
@@ -29,7 +32,7 @@ session.headers.update({
 })
 
 # ---------------------------
-# Langchain Groq Setup
+# Langchain Groq Setup for Search Tag Conversion
 # ---------------------------
 llm = ChatGroq(
     model="deepseek-r1-distill-llama-70b",
@@ -62,30 +65,44 @@ Rules:
 - If your output does not strictly match the required format, correct it after your internal reasoning.
 - Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.
 
+Excellent Examples:
+
+Input: "No code tool to augment image and annotation"
+Output: image-augmentation:albumentations
+
+Input: "Repos around chain of thought prompting mainly for finetuned models"
+Output: chain-of-thought:finetuned-llm
+
+Input: "Find repositories implementing data augmentation pipelines in JavaScript"
+Output: data-augmentation:target-javascript
+
 Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
 """),
     ("human", "{query}")
 ])
 chain = prompt | llm
 
-def
+def valid_tags(tags: str) -> bool:
     """
-
+    Validates that the output is one to six colon-separated tokens composed
+    of lowercase letters, numbers, and hyphens.
     """
-
-
-        end_index = response_str.index("</think>") + len("</think>")
-        tags = response_str[end_index:].strip()
-        return tags
-    else:
-        return response_str.strip()
+    pattern = r'^[a-z0-9-]+(?::[a-z0-9-]+){1,5}$'
+    return re.match(pattern, tags) is not None
 
-def
+def parse_search_tags(response: str) -> str:
     """
-
+    Extracts a valid colon-separated tag string from the LLM response.
+    This function removes any chain-of-thought commentary.
     """
-
-
+    # Remove any text inside <think>...</think> blocks.
+    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
+    # Use regex to find a valid tag pattern.
+    pattern = r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})'
+    match = re.search(pattern, cleaned)
+    if match:
+        return match.group(1).strip()
+    return cleaned.strip()
 
 def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
     print(f"\n🧠 [iterative_convert_to_search_tags] Input Query: {query}")
@@ -110,6 +127,7 @@ def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str
 # GitHub API Helper Functions
 # ---------------------------
 def fetch_readme_content(repo_full_name):
+    """Fetch the README content (if available) using the GitHub API."""
     readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
     response = session.get(readme_url)
     if response.status_code == 200:
@@ -120,6 +138,30 @@ def fetch_readme_content(repo_full_name):
         return ""
     return ""
 
+def fetch_markdown_contents(repo_full_name):
+    """
+    Fetch all markdown files (except the README already fetched) from the root of the repository.
+    """
+    url = f"https://api.github.com/repos/{repo_full_name}/contents"
+    response = session.get(url)
+    contents = ""
+    if response.status_code == 200:
+        items = response.json()
+        for item in items:
+            if item.get("type") == "file" and item.get("name", "").lower().endswith(".md"):
+                file_url = item.get("download_url")
+                if file_url:
+                    file_resp = requests.get(file_url)
+                    if file_resp.status_code == 200:
+                        contents += "\n" + file_resp.text
+    return contents
+
+def fetch_all_markdown(repo_full_name):
+    """Combine README with all markdown contents from the repository root."""
+    readme = fetch_readme_content(repo_full_name)
+    other_md = fetch_markdown_contents(repo_full_name)
+    return readme + "\n" + other_md
+
 def fetch_github_repositories(query, max_results=10):
     """
     Searches GitHub repositories using the provided query and retrieves key information.
@@ -137,9 +179,8 @@ def fetch_github_repositories(query, max_results=10):
     for repo in response.json().get('items', []):
         repo_link = repo.get('html_url')
         description = repo.get('description') or ""
-
-
-        combined_text = (description + "\n" + readme_content).strip()
+        combined_markdown = fetch_all_markdown(repo.get('full_name'))
+        combined_text = (description + "\n" + combined_markdown).strip()
         repo_list.append({
             "title": repo.get('name', 'No title available'),
             "link": repo_link,
@@ -148,9 +189,9 @@ def fetch_github_repositories(query, max_results=10):
     return repo_list
 
 # ---------------------------
-# Initialize SentenceTransformer Model
+# Initialize SentenceTransformer Model for Dense Retrieval
 # ---------------------------
-model = SentenceTransformer('all-
+model = SentenceTransformer('all-mpnet-base-v2')
 
 def robust_min_max_norm(scores):
     """
@@ -163,19 +204,65 @@ def robust_min_max_norm(scores):
     return (scores - min_val) / (max_val - min_val)
 
 # ---------------------------
-#
+# Cross-Encoder Re-Ranking Function
+# ---------------------------
+def cross_encoder_rerank_candidates(candidates, query, model_name, top_n=10):
+    """
+    Re-ranks candidate repositories using a cross-encoder model.
+    For long documents, the text is split into chunks and scores are aggregated.
+    """
+    cross_encoder = CrossEncoder(model_name)
+    CHUNK_SIZE = 2000      # characters per chunk
+    MAX_DOC_LENGTH = 5000  # cap for long docs
+    MIN_DOC_LENGTH = 200   # threshold for short docs
+
+    def split_text(text, chunk_size=CHUNK_SIZE):
+        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
+
+    for candidate in candidates:
+        doc = candidate.get("combined_text", "")
+        if len(doc) > MAX_DOC_LENGTH:
+            doc = doc[:MAX_DOC_LENGTH]
+        try:
+            if len(doc) < MIN_DOC_LENGTH:
+                score = cross_encoder.predict([[query, doc]])
+                candidate["cross_encoder_score"] = float(score[0])
+            else:
+                chunks = split_text(doc)
+                pairs = [[query, chunk] for chunk in chunks]
+                scores = cross_encoder.predict(pairs)
+                max_score = np.max(scores) if len(scores) > 0 else 0.0
+                avg_score = np.mean(scores) if len(scores) > 0 else 0.0
+                candidate["cross_encoder_score"] = float(0.5 * max_score + 0.5 * avg_score)
+        except Exception as e:
+            logging.error(f"Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
+            candidate["cross_encoder_score"] = 0.0
+
+    all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
+    if all_scores:
+        min_score = min(all_scores)
+        if min_score < 0:
+            for candidate in candidates:
+                candidate["cross_encoder_score"] += -min_score
+
+    reranked = sorted(candidates, key=lambda x: x["cross_encoder_score"], reverse=True)
+    return reranked[:top_n]
+
+# ---------------------------
+# Main Function: Repository Ranking with Hybrid Retrieval and Cross-Encoder Re-Ranking
 # ---------------------------
 def run_repository_ranking(query: str) -> str:
     """
     Converts the user query into search tags, runs multiple GitHub queries (individual and combined),
-    deduplicates results, and applies hybrid
+    deduplicates results, and applies a hybrid ranking strategy:
+      - Dense embeddings (via SentenceTransformer) combined with BM25 scoring.
+      - Re-ranks top candidates using a cross-encoder for improved contextual alignment.
     """
     # Step 1: Generate search tags from the query.
     search_tags = iterative_convert_to_search_tags(query)
     tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]
 
     # Step 2: Handle target language extraction.
-    target_lang = None
     if any(tag.startswith("target-") for tag in tag_list):
         target_tag = next(tag for tag in tag_list if tag.startswith("target-"))
         target_lang = target_tag.replace("target-", "")
@@ -195,7 +282,7 @@ def run_repository_ranking(query: str) -> str:
         repos = fetch_github_repositories(github_query, max_results=15)
         all_repositories.extend(repos)
 
-    #
+    # Combined query using OR logic.
     combined_query = " OR ".join(tag_list)
     combined_query = f"({combined_query}) {advanced_qualifier} {lang_query}"
     print("Combined GitHub Query:", combined_query)
@@ -208,7 +295,6 @@ def run_repository_ranking(query: str) -> str:
         if repo["link"] not in unique_repositories:
             unique_repositories[repo["link"]] = repo
         else:
-            # Merge content if the repository appears in multiple queries.
            existing_text = unique_repositories[repo["link"]]["combined_text"]
            unique_repositories[repo["link"]]["combined_text"] = existing_text + "\n" + repo["combined_text"]
     repositories = list(unique_repositories.values())
@@ -216,10 +302,10 @@ def run_repository_ranking(query: str) -> str:
     if not repositories:
         return "No repositories found for your query."
 
-    # Step 4: Prepare documents
+    # Step 4: Prepare documents.
     docs = [repo.get("combined_text", "") for repo in repositories]
 
-    # Step 5:
+    # Step 5: Dense retrieval.
     doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
     if doc_embeddings.ndim == 1:
         doc_embeddings = doc_embeddings.reshape(1, -1)
@@ -239,7 +325,7 @@ def run_repository_ranking(query: str) -> str:
     dense_scores = D.squeeze()
     norm_dense_scores = robust_min_max_norm(dense_scores)
 
-    # Step 6:
+    # Step 6: BM25 scoring.
     if BM25Okapi is not None:
         tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
         bm25 = BM25Okapi(tokenized_docs)
@@ -249,22 +335,27 @@ def run_repository_ranking(query: str) -> str:
     else:
         norm_bm25_scores = np.zeros_like(norm_dense_scores)
 
-    # Step 7: Combine scores
-    alpha = 0.8
+    # Step 7: Combine scores.
+    alpha = 0.8
     combined_scores = alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores
-
     for idx, repo in enumerate(repositories):
         repo["combined_score"] = float(combined_scores[idx])
 
-    # Step 8:
+    # Step 8: Initial ranking.
     ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)
 
+    # Step 9: Cross-Encoder Re-Ranking.
+    top_candidates = ranked_repositories[:100] if len(ranked_repositories) > 100 else ranked_repositories
+    final_ranked = cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=10)
+
+    # Step 10: Format output.
     output = "\n=== Ranked Repositories ===\n"
-    for rank, repo in enumerate(
+    for rank, repo in enumerate(final_ranked, 1):
         output += f"Final Rank: {rank}\n"
         output += f"Title: {repo['title']}\n"
         output += f"Link: {repo['link']}\n"
         output += f"Combined Score: {repo.get('combined_score', 0):.4f}\n"
+        output += f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0):.4f}\n"
         snippet = repo['combined_text'][:300].replace('\n', ' ')
         output += f"Snippet: {snippet}...\n"
         output += '-' * 80 + "\n"
@@ -275,6 +366,6 @@ def run_repository_ranking(query: str) -> str:
 # Main Entry Point for Testing
 # ---------------------------
 if __name__ == "__main__":
-    test_query = "
+    test_query = "Chain of thought prompting for reasoning models"
     result = run_repository_ranking(test_query)
     print(result)
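Note: a minimal sketch, not part of the commit, of the score arithmetic introduced above, with made-up numbers in place of real model outputs: the alpha-weighted blend of normalized dense and BM25 scores, and the 0.5 * max + 0.5 * mean aggregation the new cross-encoder re-ranker applies to per-chunk scores.

import numpy as np

# Hybrid retrieval score: 80% dense similarity, 20% BM25 (alpha = 0.8 as in main.py).
alpha = 0.8
norm_dense_scores = np.array([0.95, 0.40, 0.10])   # made-up, already min-max normalized
norm_bm25_scores = np.array([0.20, 0.90, 0.30])    # made-up, already min-max normalized
combined_scores = alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores
print(combined_scores)                              # [0.8  0.5  0.14]

# Cross-encoder aggregation for a long document: per-chunk scores are folded into
# one value as 0.5 * max + 0.5 * mean, matching cross_encoder_rerank_candidates.
chunk_scores = np.array([0.2, 0.7, 0.4])            # made-up per-chunk scores
doc_score = float(0.5 * np.max(chunk_scores) + 0.5 * np.mean(chunk_scores))
print(doc_score)                                    # 0.5666...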