# Jintonic92 — "Update src/ThirdModule/module3.py" (commit a0a2f13, verified)
# module3.py
import re
import requests
from typing import Optional, Tuple
import logging
from dotenv import load_dotenv
import os
from collections import Counter
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load .env file
load_dotenv()
# Hugging Face API information
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
API_KEY = os.getenv("HUGGINGFACE_API_KEY")
if not API_KEY:
raise ValueError("API_KEY๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. .env ํŒŒ์ผ์„ ํ™•์ธํ•˜์„ธ์š”.")
class AnswerVerifier:
    """Check multiple-choice answers with an LLM using a self-consistency vote.

    The model is queried several times with the same prompt. The option letter
    extracted from each reply is tallied, and the majority answer is accepted
    only if it is unique and reaches a minimum share of the votes.
    """

    # Seconds to wait for each HTTP request; requests has no default timeout,
    # so without this a stalled connection would block forever.
    request_timeout: float = 60.0
    # Sampling temperature. Must be > 0 so repeated trials can disagree;
    # greedy decoding would make every trial identical and the vote pointless.
    temperature: float = 0.7
    # Upper bound on generated tokens; the prompt asks only for "Answer: X".
    max_new_tokens: int = 32

    # Marker after which the assistant's reply starts when the API echoes the prompt.
    _ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"
    # Whole reply is just a letter, optionally like "(b)", "B)" or "C.".
    _LONE_LETTER_RE = re.compile(r"\(?([A-Da-d])[).]?")
    # "Answer: X" / "Correct answer: (x)"; the trailing \b rejects e.g. "Answer: ALL".
    _ANSWER_COLON_RE = re.compile(r"\banswer\s*:\s*\(?([abcd])\b", re.IGNORECASE)
    # "The answer is X": keyword case-insensitive, letter must be uppercase so
    # the article in "the answer is a ..." is not taken as option A.
    _ANSWER_IS_RE = re.compile(r"(?i:\banswer\s+is)\s*:?\s*\(?([ABCD])\b")
    # Fallback: standalone uppercase option letters anywhere in the reply.
    _STANDALONE_RE = re.compile(r"\b([ABCD])\b")

    def verify_answer(self, question: str, choices: dict, num_checks: int = 5) -> Optional[str]:
        """Verify a multiple-choice answer using a self-consistency vote.

        The same prompt is sent ``num_checks`` times. Sampling is enabled and
        the Inference API response cache is disabled so each trial is an
        independent sample; otherwise every trial would return the same text.

        Args:
            question: The question text.
            choices: Mapping from the keys 'A', 'B', 'C', 'D' to option text.
            num_checks: How many times to query the model.

        Returns:
            The winning option letter, or None when no answer could be
            extracted, the vote was tied or under-confident, or any error
            occurred. The vote explanation is logged, not returned.
        """
        try:
            # Prompt, headers and payload are identical for every trial.
            prompt = self._create_prompt(question, choices)
            headers = {"Authorization": f"Bearer {API_KEY}"}
            payload = {
                "inputs": prompt,
                "parameters": {
                    "do_sample": True,
                    "temperature": self.temperature,
                    "max_new_tokens": self.max_new_tokens,
                },
                # Cached responses would make every trial identical.
                "options": {"use_cache": False},
            }

            answers = []
            for trial in range(1, num_checks + 1):
                generated_text = self._query_model(headers, payload)
                answer = self._extract_answer(generated_text)
                logger.debug("Trial %d: generated=%r extracted=%r",
                             trial, generated_text, answer)
                if answer:
                    answers.append(answer)

            # Handles the empty case too (returns None with an explanation).
            final_answer, explanation = self._get_majority_vote(answers)
            logger.info("Final verified answer: %s (%s)", final_answer, explanation)
            return final_answer
        except Exception as e:
            # Top-level boundary: log with traceback and report "no answer".
            # Returns None (not a tuple) so every path has the same type.
            logger.exception("Error in verify_answer: %s", e)
            return None

    def _query_model(self, headers: dict, payload: dict) -> str:
        """POST one generation request and return the assistant's reply text.

        Network errors, timeouts and HTTP error statuses propagate as
        ``requests`` exceptions; the caller handles them.
        """
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        response_data = response.json()
        logger.debug("Raw API response: %s", response_data)
        return self._process_response(response_data)

    def _create_prompt(self, question: str, choices: dict) -> str:
        """Build a Llama-3 chat-formatted prompt demanding a bare "Answer: X" reply.

        Raises:
            KeyError: if any of the keys 'A'-'D' is missing from ``choices``.
        """
        lines = [
            "<|begin_of_text|>",
            "<|start_header_id|>system<|end_header_id|>",
            "You are an expert mathematics teacher evaluating multiple-choice answers.",
            "Analyze the question and options carefully to select the correct answer.",
            'IMPORTANT: You must respond ONLY with "Answer: X" where X is A, B, C, or D.',
            "Do not include any explanation or additional text.",
            "<|eot_id|>",
            "<|start_header_id|>user<|end_header_id|>",
            f"Question: {question}",
            "Options:",
            *(f"{letter}) {choices[letter]}" for letter in "ABCD"),
            'Provide your answer in the format: "Answer: X" (where X is A, B, C, or D)',
            "<|eot_id|>",
            self._ASSISTANT_HEADER,
        ]
        return "\n".join(lines)

    def _process_response(self, response_data) -> str:
        """Extract the generated text from an Inference API JSON payload.

        Accepts a list whose first element is a dict with 'generated_text' (or
        a bare value), a dict with 'generated_text', or anything else (which is
        stringified). If the text echoes the prompt, only the part after the
        last assistant header is returned.
        """
        if isinstance(response_data, list):
            first = response_data[0] if response_data else ""
            generated = first.get("generated_text", "") if isinstance(first, dict) else first
        elif isinstance(response_data, dict):
            generated = response_data.get("generated_text", "")
        else:
            generated = response_data
        # Non-string items (e.g. nested lists) would otherwise break rpartition.
        if not isinstance(generated, str):
            generated = str(generated)

        _, sep, reply = generated.rpartition(self._ASSISTANT_HEADER)
        return (reply if sep else generated).strip()

    def _extract_answer(self, response: str) -> Optional[str]:
        """Return the option letter ('A'-'D') the model chose, or None.

        Tried in order: a reply that is just a letter; "Answer: X"; "answer is
        X"; otherwise the last standalone uppercase A-D in the reply. Case is
        preserved for the fallback so the English article "a" is not read as
        option A.
        """
        text = (response or "").strip()
        if not text:
            return None

        lone = self._LONE_LETTER_RE.fullmatch(text)
        if lone:
            return lone.group(1).upper()

        for pattern in (self._ANSWER_COLON_RE, self._ANSWER_IS_RE):
            match = pattern.search(text)
            if match:
                return match.group(1).upper()

        # The final answer usually comes last in a free-form reply.
        letters = self._STANDALONE_RE.findall(text)
        return letters[-1] if letters else None

    def _get_majority_vote(self, answers: list,
                           min_confidence: float = 60.0) -> Tuple[Optional[str], str]:
        """Combine per-trial answers by majority vote.

        Args:
            answers: Extracted option letters, one per successful trial.
            min_confidence: Minimum share of votes (percent) the winner needs.

        Returns:
            ``(answer, explanation)``. ``answer`` is None for an empty input,
            a tie for first place, or a winner below ``min_confidence``.
        """
        if not answers:
            return None, "No valid answers extracted"

        counter = Counter(answers)
        max_count = max(counter.values())
        top_answers = [ans for ans, count in counter.items() if count == max_count]
        if len(top_answers) > 1:
            return None, f"Tie between answers: {top_answers}"

        final_answer = top_answers[0]
        total_votes = len(answers)
        confidence = (max_count / total_votes) * 100
        if confidence < min_confidence:
            return None, f"Low confidence ({confidence:.1f}%) for answer {final_answer}"

        explanation = (f"Answer '{final_answer}' selected with {confidence:.1f}% confidence "
                       f"({counter[final_answer]}/{total_votes} votes). "
                       f"Distribution: {dict(counter)}")
        return final_answer, explanation