File size: 1,021 Bytes
498ffec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import re


def extract_judge_hash(response):
    """
    Encode the per-checklist-item verdicts of a judge response as a string.

    The text of ``response['response']`` is lower-cased and has its space
    characters removed; the part after the first ``<answer>`` tag (up to the
    following ``</answer>``, or to the end of the text when the closing tag
    is missing) is then scanned for ``:yes``, ``:inprogress`` and ``:no``.
    Each occurrence contributes one character -- ``y``, ``i`` or ``n`` -- in
    order of appearance.

    Args:
        response: Mapping whose ``'response'`` entry holds the judge's raw text.

    Returns:
        The concatenated verdict characters (an empty string when the answer
        section contains no verdicts), or ``None`` when the text has no
        ``<answer>`` tag or is not a string.

    Raises:
        KeyError: If ``response`` has no ``'response'`` entry.
    """
    content = response['response']

    try:
        normalized = content.lower().replace(' ', '')
        judge_content = normalized.split('<answer>')[1].split('</answer>')[0]
    except (AttributeError, IndexError, TypeError):
        # Best-effort parsing: a non-string payload (AttributeError/TypeError)
        # or a missing <answer> tag (IndexError) is reported and mapped to
        # None. Anything else -- notably KeyboardInterrupt/SystemExit -- is
        # deliberately allowed to propagate instead of being swallowed.
        import traceback
        traceback.print_exc()
        return None

    verdict_codes = {':yes': 'y', ':inprogress': 'i', ':no': 'n'}
    matches = re.findall(r":yes|:inprogress|:no", judge_content)
    return ''.join(verdict_codes[match] for match in matches)

def average_logits(response):
    """
    Score a response from the averaged 'yes' and 'in' judge values.

    Args:
        response: Mapping whose ``'judge_probs'`` entry is a sequence of
            mappings, each providing ``'yes'`` and ``'in'`` values. Any other
            entries (e.g. ``'no'``) are not read here.

    Returns:
        ``mean(yes) + 0.5 * mean(in)`` computed with ``np.mean`` over all
        entries of ``judge_probs``.
    """
    samples = response['judge_probs']

    def _mean_for(label):
        # Average one label's value across every judge sample.
        return np.mean([sample[label] for sample in samples])

    yes_mean = _mean_for('yes')
    in_mean = _mean_for('in')

    # 'in' counts for half as much as 'yes'.
    return yes_mean + 0.5 * in_mean


# Registry mapping a reward-processor name to the function implementing it.
REWARD_PROCESSORS = dict(
    avg_logits=average_logits,
)

# Per-processor sample count, keyed by the same names as REWARD_PROCESSORS.
# NOTE(review): presumably the number of judge samples drawn per response
# for that processor -- confirm against the caller.
REWARD_PROCESSOR_N_SAMPLES = dict(avg_logits=5)