import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional, Union

import numpy as np

from .mini_bench.reward_agent import ProgressJudgeAgent
from .reward_postprocessor import REWARD_PROCESSORS, REWARD_PROCESSOR_N_SAMPLES, extract_judge_hash


def _process_unit(idx, unit, configs, n_samples, reward_processor, max_retries=5):
    """Score a single unit and return ``(idx, reward, thought)``.

    A fresh ``ProgressJudgeAgent`` is built from ``configs`` and queried until
    ``n_samples`` usable rewards have been collected or ``max_retries``
    incomplete rounds have been spent.  A reward is usable when the selected
    reward processor returns something that is neither ``None`` nor NaN.

    Args:
        idx: Position of the unit in the dataset; returned unchanged so the
            caller can restore ordering.
        unit: The item handed to ``agent.generate_probs``.
        configs: Agent configuration; ``configs["temperature"]`` is the
            initial sampling temperature.
        n_samples: Number of usable rewards to collect.
        reward_processor: Key into ``REWARD_PROCESSORS``.
        max_retries: Maximum number of rounds that may end with samples still
            missing before giving up.

    Returns:
        ``(idx, reward, thought)``.  ``reward`` is the mean of the usable
        rewards, or ``0.0`` when there are none.  ``thought`` is the most
        recently stored response text of the most frequent judge hash; when
        no sample was usable it falls back to the last raw response text
        seen, or ``""`` if there was no response at all.
    """
    agent = ProgressJudgeAgent(configs)
    to_reward = REWARD_PROCESSORS[reward_processor]
    current_temperature = configs["temperature"]

    rewards = []
    n_err = 0
    retry_count = 0
    # judge hash -> (occurrence count, latest response text with that hash)
    judge_hash_count_thought = {}
    last_thought = None

    while len(rewards) < n_samples and retry_count < max_retries:
        # External API call: request only the samples that are still missing.
        responses, _ = agent.generate_probs(
            unit,
            "ours",
            n=n_samples - len(rewards),
            temperature=current_temperature,
        )

        for response in responses:
            thought = response["response"]  # keep the full text for logging
            last_thought = thought
            reward = to_reward(response)

            # Check None first: np.isnan(None) raises TypeError.
            if reward is None or np.isnan(reward):
                # Failed samples are not stored, so the loop condition keeps
                # asking for replacements.
                n_err += 1
                continue

            rewards.append(reward)
            judge_hash = extract_judge_hash(response)
            count = judge_hash_count_thought.get(judge_hash, (0, None))[0]
            judge_hash_count_thought[judge_hash] = (count + 1, thought)

        if len(rewards) < n_samples:
            # With a single sample, raise the temperature after a failure so
            # the retry can produce a different output.
            if n_err > 0 and n_samples == 1:
                current_temperature = 0.5
            # Count every incomplete round (including short or empty
            # responses) so the loop is guaranteed to terminate.
            retry_count += 1

    reward = np.mean(rewards) if rewards else float("nan")
    if np.isnan(reward):
        print(f"[idx={idx}] Warning: reward is NaN after retries -> set 0")
        reward = 0.0

    print(judge_hash_count_thought)
    if judge_hash_count_thought:
        thought = max(judge_hash_count_thought.values(), key=lambda entry: entry[0])[1]
    else:
        # Every sample failed: there is no majority judgement to report.
        thought = last_thought if last_thought is not None else ""

    return idx, reward, thought


def get_ar_reward(dataset, base_url, model_name, reward_processor='avg_logits', max_workers=8):
    """Compute rewards for every unit in ``dataset`` using a thread pool.

    Threaded replacement for the original ``get_ar_reward``: each unit is
    submitted to ``_process_unit`` and the results are written back by index,
    so the output order matches the dataset order.

    Args:
        dataset: Sized iterable of units (``len`` and ``enumerate`` are used).
        base_url: Stored in the agent config under ``"base_url"``.
        model_name: Stored in the agent config under ``"model_name"``.
        reward_processor: Key into ``REWARD_PROCESSORS`` and
            ``REWARD_PROCESSOR_N_SAMPLES``.
        max_workers: Size of the thread pool.

    Returns:
        ``(final_rewards, thoughts)``: a list of float rewards and a list of
        the corresponding thoughts, both in dataset order.
    """
    n_samples = REWARD_PROCESSOR_N_SAMPLES[reward_processor]
    # Temperature 0.0 for a single sample, 0.5 when several are drawn.
    temperature = 0.5 if n_samples > 1 else 0.0

    configs = {
        "model_name": model_name,
        "base_url": base_url,
        "api_key": "empty",
        "temperature": temperature,
        "num_generate": 1,
        "use_checklist": True,
        "input_type": "text_only",
        "text_obs_type": "axtree",
        "image_obs_type": "som",
        "use_in_progress": True,
        "use_multimodal": False,
        "use_log_probs": True,
    }

    t_start = time.time()
    results = [None] * len(dataset)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                _process_unit, idx, unit, configs, n_samples, reward_processor
            )
            for idx, unit in enumerate(dataset)
        ]

        # Futures finish in arbitrary order; the returned idx restores it.
        for fut in as_completed(futures):
            idx, reward, thought = fut.result()
            results[idx] = (reward, thought)

    # Split into two lists that keep the dataset order.
    final_rewards = [float(r) for r, _ in results]
    thoughts = [t for _, t in results]

    print(f"Time taken (threaded): {time.time() - t_start:.2f} s")
    return final_rewards, thoughts