test_wprm3 / agent /reward_postprocessor.py
iruno's picture
Upload 245 files
498ffec verified
import numpy as np
import re
def extract_judge_hash(response):
    """Encode the judge's per-checklist-item verdicts as a compact code string.

    The judge text in ``response['response']`` is lower-cased and stripped of
    spaces. The part between the first ``<answer>`` and the following
    ``</answer>`` is then scanned for ``:yes``, ``:inprogress`` and ``:no``.
    Each hit is mapped to one character, in order of appearance:

    * ``:yes``        -> ``'y'``
    * ``:inprogress`` -> ``'i'``
    * ``:no``         -> ``'n'``

    Args:
        response: Mapping whose ``'response'`` entry holds the judge's raw text.

    Returns:
        The concatenated codes (e.g. ``'yin'``); ``''`` when the answer section
        contains no recognised verdicts; ``None`` when the text has no
        ``<answer>`` tag or cannot be processed as a string. In the ``None``
        case the traceback is printed to stderr.

    Raises:
        KeyError: if ``response`` has no ``'response'`` key.
    """
    content = response['response']
    try:
        # Normalise so that e.g. "Item 1: In Progress" becomes
        # "item1:inprogress". Only spaces are removed, not other whitespace.
        normalized = content.lower().replace(' ', '')
        # Keep the text after the first <answer>. If </answer> is missing,
        # the whole remainder is used.
        judge_content = normalized.split('<answer>')[1].split('</answer>')[0]
    except (AttributeError, IndexError, TypeError):
        # Best-effort: a malformed or non-string response is reported and
        # signalled with None. The exceptions are deliberately narrow so that
        # KeyboardInterrupt/SystemExit still propagate.
        import traceback
        traceback.print_exc()
        return None
    codes = {':yes': 'y', ':inprogress': 'i', ':no': 'n'}
    # NOTE(review): ':no' also matches the prefix of words such as ':none' or
    # ':not...'. Confirm that the judge never emits such labels.
    matches = re.findall(r":yes|:inprogress|:no", judge_content)
    return ''.join(codes[match] for match in matches)
def average_logits(response, *, in_weight=0.5):
    """Combine sampled judge 'yes' / 'in' scores into a single scalar reward.

    Computes ``mean(yes) + in_weight * mean(in)`` over all entries of
    ``response['judge_probs']``. The original note described this as
    computing yes/in/no "at the logits level". Whether the stored values are
    probabilities or logits is not established here — TODO confirm with the
    producer of ``judge_probs``.

    Args:
        response: Mapping with a ``'judge_probs'`` entry: an iterable of
            mappings, each holding numeric ``'yes'`` and ``'in'`` values.
        in_weight: Weight given to the averaged in-progress score. It is
            keyword-only; the default of 0.5 reproduces the original reward.

    Returns:
        The reward as a float; NaN when ``judge_probs`` is empty.

    Raises:
        KeyError: if ``'judge_probs'`` or a per-sample ``'yes'``/``'in'`` key
            is missing.
    """
    # Materialise once: a one-shot iterable would otherwise be exhausted by the
    # first comprehension and leave the second one empty.
    judge_probs = list(response['judge_probs'])
    if not judge_probs:
        # Same value np.mean([]) would give, without its RuntimeWarning.
        return float('nan')
    yes_mean = np.mean([sample['yes'] for sample in judge_probs])
    in_mean = np.mean([sample['in'] for sample in judge_probs])
    return yes_mean + in_weight * in_mean
# Single source of truth for reward processors. Each entry is:
#   (registry name, function mapping a judge response to a reward, sample count)
# NOTE(review): the sample count is presumably the number of judge samples
# (entries of ``judge_probs``) to request per item — confirm against the caller.
_PROCESSOR_TABLE = (
    ('avg_logits', average_logits, 5),
)

# Name -> reward function.
REWARD_PROCESSORS = {name: func for name, func, _ in _PROCESSOR_TABLE}

# Name -> number of samples for that processor.
REWARD_PROCESSOR_N_SAMPLES = {name: n_samples for name, _, n_samples in _PROCESSOR_TABLE}