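"""Threaded computation of agent rewards.

Each dataset unit is scored by a ProgressJudgeAgent over an OpenAI-compatible
endpoint; per-unit sampling, retrying, and reward post-processing run
concurrently in a thread pool.
"""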
import time
from typing import List, Dict, Any, Optional, Union
import numpy as np
from .mini_bench.reward_agent import ProgressJudgeAgent
from .reward_postprocessor import REWARD_PROCESSORS, REWARD_PROCESSOR_N_SAMPLES, extract_judge_hash
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

def _process_unit(idx, unit, configs, n_samples, reward_processor, max_retries=5):
    """하나의 unit을 처리해 (idx, reward, thought)를 돌려준다."""
    agent = ProgressJudgeAgent(configs)
    current_temperature = configs["temperature"]

    rewards = []
    n_err = 0
    retry_count = 0
    judge_hash_count_thought = {}

    # Sample up to n_samples judge responses; parse failures are retried
    # (with a higher temperature) until max_retries is exhausted.
    while len(rewards) < n_samples and retry_count < max_retries:
        # External API call for however many samples are still missing.
        responses, _ = agent.generate_probs(
            unit, "ours", n=n_samples - len(rewards), temperature=current_temperature
        )

        for response in responses:
            thought = response["response"]  # keep the full response text as the log
            reward = REWARD_PROCESSORS[reward_processor](response)

            # Check None before np.isnan: np.isnan(None) raises a TypeError.
            if reward is None or np.isnan(reward):
                # Drop the failed sample so the while loop regenerates it.
                n_err += 1
            else:
                rewards.append(reward)
                judge_hash = extract_judge_hash(response)
                count, _ = judge_hash_count_thought.get(judge_hash, (0, None))
                judge_hash_count_thought[judge_hash] = (count + 1, thought)

        if n_err > 0 and n_samples == 1:
            # Greedy decoding (temperature 0.0) failed; retry with sampling.
            current_temperature = 0.5
        retry_count += 1

    reward = np.nanmean(rewards) if rewards else float("nan")
    if np.isnan(reward):
        print(f"[idx={idx}] Warning: reward is NaN after retries -> set 0")
        reward = 0.0

    # Report the thought attached to the most frequent judge hash, if any.
    thought = None
    if judge_hash_count_thought:
        thought = max(judge_hash_count_thought.values(), key=lambda x: x[0])[1]

    return idx, reward, thought


def get_ar_reward(dataset, base_url, model_name, reward_processor='avg_logits', max_workers=8):
    """원본 get_ar_reward를 스레드 버전으로 교체."""
    n_samples = REWARD_PROCESSOR_N_SAMPLES[reward_processor]

    temperature = 0.5 if n_samples > 1 else 0.0  # sample only when averaging several generations

    # Judge configuration forwarded to ProgressJudgeAgent: text-only AXTree
    # observations, checklist prompting, and log-prob based scoring.
    configs = {
        "model_name": model_name,
        "base_url": base_url,
        "api_key": "empty",
        "temperature": temperature,
        "num_generate": 1,
        "use_checklist": True,
        "input_type": "text_only",
        "text_obs_type": "axtree",
        "image_obs_type": "som",
        "use_in_progress": True,
        "use_multimodal": False,
        "use_log_probs": True,
    }

    t_start = time.time()
    results = [None] * len(dataset)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                _process_unit, idx, unit, configs, n_samples, reward_processor
            )
            for idx, unit in enumerate(dataset)
        ]

        for fut in as_completed(futures):
            idx, reward, thought = fut.result()
            results[idx] = (reward, thought)

    # Split back into lists, preserving dataset order.
    final_rewards = [float(r) for r, _ in results]
    thoughts = [t for _, t in results]

    print(f"Time taken (threaded): {time.time() - t_start:.2f} s")
    return final_rewards, thoughts
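

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions (illustrative only): an
    # OpenAI-compatible judge server at http://localhost:8000/v1 and a JSON
    # file "units.json" holding dataset units in whatever shape
    # ProgressJudgeAgent.generate_probs expects. Because of the relative
    # imports above, run this as a module (python -m <package>.<this_module>).
    with open("units.json") as f:
        dataset = json.load(f)

    rewards, thoughts = get_ar_reward(
        dataset,
        base_url="http://localhost:8000/v1",
        model_name="progress-judge",  # hypothetical model name
        reward_processor="avg_logits",
        max_workers=8,
    )
    print(f"Mean reward over {len(rewards)} units: {np.mean(rewards):.4f}")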