FelixPhilip committed on
Commit
dc9d402
·
1 Parent(s): 3868d8d
Files changed (1) hide show
  1. Oracle/deepfundingoracle.py +33 -28
Oracle/deepfundingoracle.py CHANGED
@@ -20,6 +20,7 @@ import time
20
  import threading
21
  import logging
22
  import concurrent.futures
 
23
  import signal
24
  from tqdm import tqdm
25
  import sys
@@ -168,7 +169,10 @@ def fetch_github_features(df):
168
  def timeout_handler(signum, frame):
169
  raise TimeoutError("LLama model prediction timed out.")
170
 
171
- def assign_base_weight(df):
 
 
 
172
  print("[INFO] Starting base weight assignment using LLama model...", flush=True)
173
  logging.info("[INFO] Assigning base weights using LLama model...")
174
  start_time = time.time()
@@ -176,10 +180,10 @@ def assign_base_weight(df):
176
  base_weights = []
177
  llm_cache = {}
178
 
179
- for idx, row in tqdm(df.iterrows(), total=len(df), desc="Assigning weights"):
 
 
180
  repo = row.get("repo", "")
181
- print(f"[INFO] Assigning weight for repository {idx + 1}/{len(df)}: {repo}", flush=True)
182
- logging.info(f"[INFO] Processing repository {idx + 1}/{len(df)}: {repo}")
183
  parent = row.get("parent", "")
184
  stars = row.get("stars", 0)
185
  forks = row.get("forks", 0)
@@ -187,7 +191,7 @@ def assign_base_weight(df):
187
  issues = row.get("open_issues", 0)
188
  pulls = row.get("pulls", 0)
189
  activity = row.get("activity", "")
190
- prompt = (
191
  f"Repository: {repo}\n"
192
  f"GitHub Metrics: {stars} stars, {forks} forks, {watchers} watchers, {issues} open issues, {pulls} pull requests, activity: {activity}.\n"
193
  f"Parent or dependency: {parent}\n\n"
@@ -195,32 +199,33 @@ def assign_base_weight(df):
195
  "that reflects how influential the repository is as a source relative to its parent. "
196
  "Only output the numeric value."
197
  )
 
 
 
 
 
198
  try:
199
- if repo in llm_cache:
200
- weight = llm_cache[repo]
201
- else:
202
- print(f"[INFO] Sending prompt to LLama model for repo: {repo}", flush=True)
203
- start_llama_time = time.time()
204
- response = llama.predict(prompt)
205
- # Use regex to extract the first valid float from the response
206
- match = re.search(r"[-+]?\d*\.\d+|\d+", response)
207
- if match:
208
- weight = float(match.group())
209
- weight = min(max(weight, 0), 1)
210
- else:
211
- raise ValueError(f"No valid float found in response: {response}")
212
- end_llama_time = time.time()
213
- print(f"[INFO] Received weight {weight} for {repo} in {end_llama_time - start_llama_time:.2f} seconds.", flush=True)
214
- logging.info(f"[INFO] Processed repository {repo} in {end_llama_time - start_llama_time:.2f} seconds. Weight: {weight}")
215
- llm_cache[repo] = weight
216
  except Exception as e:
217
- print(f"[ERROR] Failed to process repository {repo}: {e}", flush=True)
218
- logging.error(f"[ERROR] Failed to process repository {repo}: {e}")
219
- weight = 0.0 # Default weight in case of failure (set to 0 for no work)
220
- base_weights.append(weight)
221
- print(f"[PROGRESS] Finished {idx + 1}/{len(df)} repositories.", flush=True)
 
 
 
 
 
 
 
 
 
222
 
223
- df["base_weight"] = base_weights
224
  end_time = time.time()
225
  print(f"[INFO] Base weights assigned successfully in {end_time - start_time:.2f} seconds.", flush=True)
226
  logging.info(f"[INFO] Base weights assigned successfully in {end_time - start_time:.2f} seconds.")
 
20
  import threading
21
  import logging
22
  import concurrent.futures
23
+ from concurrent.futures import ThreadPoolExecutor
24
  import signal
25
  from tqdm import tqdm
26
  import sys
 
169
  def timeout_handler(signum, frame):
170
  raise TimeoutError("LLama model prediction timed out.")
171
 
172
+ def assign_base_weight(df, max_workers=8):
173
+ """
174
+ Assign base weights using LLama model in parallel.
175
+ """
176
  print("[INFO] Starting base weight assignment using LLama model...", flush=True)
177
  logging.info("[INFO] Assigning base weights using LLama model...")
178
  start_time = time.time()
 
180
  base_weights = []
181
  llm_cache = {}
182
 
183
+ # Prepare prompts for all repositories
184
+ prompts = {}
185
+ for idx, row in df.iterrows():
186
  repo = row.get("repo", "")
 
 
187
  parent = row.get("parent", "")
188
  stars = row.get("stars", 0)
189
  forks = row.get("forks", 0)
 
191
  issues = row.get("open_issues", 0)
192
  pulls = row.get("pulls", 0)
193
  activity = row.get("activity", "")
194
+ prompts[idx] = (
195
  f"Repository: {repo}\n"
196
  f"GitHub Metrics: {stars} stars, {forks} forks, {watchers} watchers, {issues} open issues, {pulls} pull requests, activity: {activity}.\n"
197
  f"Parent or dependency: {parent}\n\n"
 
199
  "that reflects how influential the repository is as a source relative to its parent. "
200
  "Only output the numeric value."
201
  )
202
+
203
+ # Define the prediction function
204
+ def _predict(idx, prompt):
205
+ if idx in llm_cache:
206
+ return idx, llm_cache[idx]
207
  try:
208
+ resp = llama.predict(prompt)
209
+ match = re.search(r"[-+]?\d*\.\d+|\d+", resp)
210
+ weight = min(max(float(match.group()), 0), 1) if match else 0.0
211
+ llm_cache[idx] = weight
212
+ return idx, weight
 
 
 
 
 
 
 
 
 
 
 
 
213
  except Exception as e:
214
+ print(f"[ERROR] Failed to process repository {idx}: {e}", flush=True)
215
+ logging.error(f"[ERROR] Failed to process repository {idx}: {e}")
216
+ return idx, 0.0 # Default weight in case of failure
217
+
218
+ # Run predictions in parallel
219
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
220
+ futures = [executor.submit(_predict, idx, prompt) for idx, prompt in prompts.items()]
221
+ for fut in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="LLM Prompts"):
222
+ idx, weight = fut.result()
223
+ base_weights.append((idx, weight))
224
+
225
+ # Sort weights by index and assign to DataFrame
226
+ base_weights.sort(key=lambda x: x[0])
227
+ df["base_weight"] = [weight for _, weight in base_weights]
228
 
 
229
  end_time = time.time()
230
  print(f"[INFO] Base weights assigned successfully in {end_time - start_time:.2f} seconds.", flush=True)
231
  logging.info(f"[INFO] Base weights assigned successfully in {end_time - start_time:.2f} seconds.")