Spaces:
Sleeping
Sleeping
File size: 3,303 Bytes
498ffec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import time
from typing import List, Dict, Any, Optional, Union
import numpy as np
from .mini_bench.reward_agent import ProgressJudgeAgent
from .reward_postprocessor import REWARD_PROCESSORS, REWARD_PROCESSOR_N_SAMPLES, extract_judge_hash
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
def _process_unit(idx, unit, configs, n_samples, reward_processor, max_retries=5):
    """Score one unit with the progress judge and return ``(idx, reward, thought)``.

    The judge is queried until ``n_samples`` usable rewards have been
    collected or ``max_retries`` judge calls have been made, whichever comes
    first. A sample is unusable when the reward processor returns ``None`` or
    NaN; unusable samples are discarded and requested again on the next call.

    Args:
        idx: Position of ``unit`` in the caller's dataset; returned unchanged
            so the caller can restore ordering.
        unit: Item handed to ``ProgressJudgeAgent.generate_probs``.
        configs: Agent configuration. ``configs["temperature"]`` is the
            initial sampling temperature.
        n_samples: Number of usable rewards to collect.
        reward_processor: Key into ``REWARD_PROCESSORS`` selecting the
            function that converts one response into a reward.
        max_retries: Upper bound on the number of judge calls for this unit.

    Returns:
        A tuple ``(idx, reward, thought)``. ``reward`` is the mean of the
        usable rewards, or ``0.0`` when there are none. ``thought`` is the
        response text recorded for the most frequent judge hash; when no
        sample was usable it is the last response text received, or ``None``
        if the judge returned no responses at all.
    """
    agent = ProgressJudgeAgent(configs)
    to_reward = REWARD_PROCESSORS[reward_processor]
    temperature = configs["temperature"]

    rewards = []  # usable rewards only; failed samples are never stored
    votes = {}  # judge hash -> (occurrence count, latest thought for that hash)
    last_thought = None
    saw_failure = False
    attempts = 0

    while len(rewards) < n_samples and attempts < max_retries:
        # External API call: request only the samples that are still missing.
        responses, _ = agent.generate_probs(
            unit, "ours", n=n_samples - len(rewards), temperature=temperature
        )
        # Count every call so an empty or short reply cannot loop forever.
        attempts += 1

        for response in responses:
            thought = response["response"]  # keep the whole reply for logging
            last_thought = thought
            reward = to_reward(response)
            # Check None first: np.isnan(None) raises TypeError.
            if reward is None or np.isnan(reward):
                saw_failure = True
                continue
            rewards.append(reward)
            judge_hash = extract_judge_hash(response)
            seen = votes.get(judge_hash, (0, None))[0]
            votes[judge_hash] = (seen + 1, thought)

        if saw_failure and n_samples == 1:
            # A single-sample run failed: raise the temperature before retrying.
            temperature = 0.5

    reward = float(np.mean(rewards)) if rewards else float("nan")
    if np.isnan(reward):
        print(f"[idx={idx}] Warning: reward is NaN after retries -> set 0")
        reward = 0.0

    print(votes)
    if votes:
        # Majority vote over judge hashes; ties go to the hash seen first.
        thought = max(votes.values(), key=lambda entry: entry[0])[1]
    else:
        # No usable sample: surface the last raw reply (if any) instead of
        # failing on max() of an empty collection.
        thought = last_thought
    return idx, reward, thought
def get_ar_reward(dataset, base_url, model_name, reward_processor='avg_logits', max_workers=8):
    """Threaded replacement for the original ``get_ar_reward``.

    Every unit of ``dataset`` is submitted to a thread pool and scored by
    ``_process_unit``; the outputs are returned in dataset order.

    Args:
        dataset: Sized iterable of units to score.
        base_url: Stored under ``"base_url"`` in the judge configuration.
        model_name: Stored under ``"model_name"`` in the judge configuration.
        reward_processor: Key used to look up the number of samples in
            ``REWARD_PROCESSOR_N_SAMPLES``; also forwarded to each worker.
        max_workers: Size of the thread pool.

    Returns:
        ``(final_rewards, thoughts)``: a list of rewards converted to
        ``float`` and a list of the matching thoughts, both indexed like
        ``dataset``.
    """
    sample_count = REWARD_PROCESSOR_N_SAMPLES[reward_processor]
    # Sample with temperature only when more than one sample is requested.
    if sample_count > 1:
        sampling_temperature = 0.5
    else:
        sampling_temperature = 0.0

    judge_configs = dict(
        model_name=model_name,
        base_url=base_url,
        api_key="empty",
        temperature=sampling_temperature,
        num_generate=1,
        use_checklist=True,
        input_type="text_only",
        text_obs_type="axtree",
        image_obs_type="som",
        use_in_progress=True,
        use_multimodal=False,
        use_log_probs=True,
    )

    started_at = time.time()
    total = len(dataset)
    rewards_by_idx = [None] * total
    thoughts_by_idx = [None] * total

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for position, unit in enumerate(dataset):
            pending.append(
                pool.submit(
                    _process_unit,
                    position,
                    unit,
                    judge_configs,
                    sample_count,
                    reward_processor,
                )
            )

        # Workers finish in arbitrary order; the returned index restores it.
        for finished in as_completed(pending):
            position, reward, thought = finished.result()
            rewards_by_idx[position] = reward
            thoughts_by_idx[position] = thought

    # Separate, order-preserving output lists.
    final_rewards = list(map(float, rewards_by_idx))
    thoughts = list(thoughts_by_idx)

    print(f"Time taken (threaded): {time.time() - started_at:.2f} s")
    return final_rewards, thoughts
|