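"""Streamlit app for benchmarking LLMs on medical multiple-choice datasets.

Models are queried through the Together AI API; available datasets and model
IDs are declared in config.py (DATASETS and MODELS).
"""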
import streamlit as st
import pandas as pd
from together import Together
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
import os
from config import DATASETS, MODELS
import altair as alt

load_dotenv()
client = Together(api_key=os.getenv('TOGETHERAI_API_KEY'))
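# config.py is not shown here. The sketch below is an assumption inferred from
# how the keys ("loader", "model_id") are used in this file; the dataset name
# and model ID are illustrative examples, not the actual configuration:
#
#   DATASETS = {
#       "MedMCQA": {"loader": "openlifescienceai/medmcqa"},
#   }
#   MODELS = {
#       "Llama 3.1 8B": {"model_id": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"},
#   }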
def load_dataset_by_name(dataset_name, split="train"):
    """Load a dataset from the Hugging Face Hub and convert it to a list of question dicts."""
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    # Keep only single-answer multiple-choice questions.
    df = df[df['choice_type'] == 'single']

    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        correct_answer = options[row['cop']]  # 'cop' is the index of the correct option
        questions.append({
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp'],
        })

    st.write(f"Loaded {len(questions)} single-select questions from {dataset_name}")
    return questions
def get_model_response(question, options, prompt_template, model_name):
    """Query a model and return (parsed_answer, raw_response_text).

    On failure the parsed answer is an "Error: ..." string, so the caller can
    still record and display what went wrong. Returning the raw text here
    avoids a second API call just to show the unparsed response.
    """
    try:
        model_config = MODELS[model_name]
        options_text = "\n".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options))
        prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)
        response = client.chat.completions.create(
            model=model_config["model_id"],
            messages=[{"role": "user", "content": prompt}],
        )
        response_text = response.choices[0].message.content.strip()
        # Best-effort extraction of the first JSON object in the response.
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if json_match is None:
            return "Error: No JSON object found in model response", response_text
        json_response = json.loads(json_match.group(0))
        answer = json_response['answer'].strip()
        # Strip a leading option letter such as "A. " if the model included one.
        answer = re.sub(r'^[A-D]\.\s*', '', answer)
        if not any(answer.lower() == opt.lower() for opt in options):
            return f"Error: Answer '{answer}' does not match any options", response_text
        return answer, response_text
    except Exception as e:
        return f"Error: {str(e)}", ""
def evaluate_response(model_response, correct_answer):
    if model_response.startswith("Error:"):
        return False
    return model_response.lower().strip() == correct_answer.lower().strip()
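# Illustrative behaviour of the case- and whitespace-insensitive comparison
# above (hypothetical answer strings):
#   evaluate_response("Aspirin ", "aspirin")        -> True
#   evaluate_response("Error: timeout", "Aspirin")  -> False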
def main():
    st.set_page_config(page_title="LLM Benchmarking in Healthcare", layout="wide")
    st.title("LLM Benchmarking in Healthcare")

    # Initialize session state.
    if 'all_results' not in st.session_state:
        st.session_state.all_results = {}
    if 'detailed_model' not in st.session_state:
        st.session_state.detailed_model = None
    if 'detailed_dataset' not in st.session_state:
        st.session_state.detailed_dataset = None
    if 'last_evaluated_dataset' not in st.session_state:
        st.session_state.last_evaluated_dataset = None

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on",
        )
    with col2:
        selected_models = st.multiselect(
            "Select Model(s)",
            options=list(MODELS.keys()),
            default=[list(MODELS.keys())[0]],
            help="Choose one or more models to evaluate.",
        )
    models_to_evaluate = selected_models
    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.

Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format containing an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out the other options

Example response format:
{
    "answer": "exact option text here (e.g., A. xxx, B. xxx, C. xxx)",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out the other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''
    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template",
            default_prompt,
            height=400,
            help="This prompt is editable; feel free to adjust it before running the evaluation.",
        )
    with col2:
        st.markdown("""
        ### Prompt Variables
        - `{question}`: The medical question
        - `{options}`: The multiple choice options
        """)
    with st.spinner("Loading dataset..."):
        questions = load_dataset_by_name(selected_dataset)

    if not questions:
        st.error("No questions were loaded successfully.")
        return

    # Sort the subjects so the filter options keep a stable order across reruns.
    subjects = sorted(set(q['subject_name'] for q in questions))
    selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
    if selected_subject != "All":
        questions = [q for q in questions if q['subject_name'] == selected_subject]

    num_questions = st.number_input("Number of questions to evaluate", 1, len(questions))
    if st.button("Start Evaluation"):
        if not os.getenv('TOGETHERAI_API_KEY'):
            st.error("Please set the TOGETHERAI_API_KEY in your .env file")
            return

        progress_container = st.container()
        with progress_container:
            progress_bar = st.progress(0)
            status_text = st.empty()
            substatus_text = st.empty()

        all_results = {}
        total_iterations = len(models_to_evaluate) * num_questions
        current_iteration = 0

        for model_name in models_to_evaluate:
            substatus_text.markdown(
                f"<small>Evaluating model: {model_name} on {selected_dataset}</small>",
                unsafe_allow_html=True,
            )
            results = []
            for i in range(num_questions):
                question = questions[i]
                current_iteration += 1
                progress_bar.progress(current_iteration / total_iterations)
                status_text.text(f"Progress: {current_iteration}/{total_iterations} evaluations")
                # A single API call returns both the parsed answer and the raw
                # response text, so the "Raw Response" shown later is the same
                # output that was actually evaluated (the original code issued
                # a second, separate call here).
                model_response, raw_response = get_model_response(
                    question['question'],
                    question['options'],
                    prompt_template,
                    model_name,
                )
                options_text = "\n".join(
                    f"{chr(65 + i)}. {opt}" for i, opt in enumerate(question['options'])
                )
                formatted_prompt = prompt_template.replace("{question}", question['question']).replace("{options}", options_text)
                is_correct = evaluate_response(model_response, question['correct_answer'])
                results.append({
                    'question': question['question'],
                    'options': question['options'],
                    'model_response': model_response,
                    'raw_llm_response': raw_response,
                    'prompt_sent': formatted_prompt,
                    'correct_answer': question['correct_answer'],
                    'subject': question['subject_name'],
                    'is_correct': is_correct,
                    'explanation': question['explanation'],
                })
            all_results[model_name] = results

        st.session_state.all_results = all_results
        st.session_state.last_evaluated_dataset = selected_dataset
        if st.session_state.detailed_model is None and all_results:
            st.session_state.detailed_model = list(all_results.keys())[0]
        if st.session_state.detailed_dataset is None:
            st.session_state.detailed_dataset = selected_dataset
        st.rerun()
    if st.session_state.all_results:
        st.subheader("Evaluation Results")
        model_metrics = {}
        for model_name, results in st.session_state.all_results.items():
            df = pd.DataFrame(results)
            model_metrics[model_name] = {'Accuracy': df['is_correct'].mean()}
        metrics_df = pd.DataFrame(model_metrics).T

        st.subheader("Model Performance Comparison")
        accuracy_chart = alt.Chart(
            metrics_df.reset_index().melt(id_vars=['index'], value_vars=['Accuracy'])
        ).mark_bar().encode(
            x=alt.X('index:N', title=None, axis=None),
            y=alt.Y('value:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
            color='index:N',
        ).properties(
            height=300,
            title={
                "text": "Model Accuracy",
                "baseline": "bottom",
                "orient": "bottom",
                "dy": 20,
            },
        )
        st.altair_chart(accuracy_chart, use_container_width=True)
    if st.session_state.all_results:
        st.subheader("Detailed Results")

        def update_model():
            st.session_state.detailed_model = st.session_state.model_select

        def update_dataset():
            st.session_state.detailed_dataset = st.session_state.dataset_select

        col1, col2 = st.columns(2)
        with col1:
            selected_model_details = st.selectbox(
                "Select model",
                options=list(st.session_state.all_results.keys()),
                key="model_select",
                on_change=update_model,
                index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model)
                if st.session_state.detailed_model in st.session_state.all_results else 0,
            )
        with col2:
            selected_dataset_details = st.selectbox(
                "Select dataset",
                options=[st.session_state.last_evaluated_dataset],
                key="dataset_select",
                on_change=update_dataset,
            )

        if selected_model_details in st.session_state.all_results:
            results = st.session_state.all_results[selected_model_details]
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            st.metric("Accuracy", f"{accuracy:.2%}")

            for idx, result in enumerate(results):
                with st.expander(f"Question {idx + 1} - {result['subject']}"):
                    st.write("Question:", result['question'])
                    st.write("Options:")
                    for i, opt in enumerate(result['options']):
                        st.write(f"{chr(65 + i)}. {opt}")

                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("Prompt Used:")
                        st.code(result['prompt_sent'])
                    with col2:
                        st.write("Raw Response:")
                        st.code(result['raw_llm_response'])

                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("Correct Answer:", result['correct_answer'])
                        st.write("Model Answer:", result['model_response'])
                    with col2:
                        if result['is_correct']:
                            st.success("Correct!")
                        else:
                            st.error("Incorrect")
                    st.write("Explanation:", result['explanation'])
        else:
            st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")
        st.markdown("---")

        # Build one wide row per question, with each model's response and
        # correctness flag side by side.
        all_data = []
        first_model = list(st.session_state.all_results.keys())[0]
        base_results = st.session_state.all_results[first_model]
        for question_idx in range(len(base_results)):
            row = {
                'dataset': selected_dataset_details,
                'question': base_results[question_idx]['question'],
                'correct_answer': base_results[question_idx]['correct_answer'],
                'subject': base_results[question_idx]['subject'],
                'options': ' | '.join(base_results[question_idx]['options']),
            }
            for model_name in st.session_state.all_results.keys():
                model_results = st.session_state.all_results[model_name]
                row[f'{model_name}_response'] = model_results[question_idx]['model_response']
                row[f'{model_name}_is_correct'] = model_results[question_idx]['is_correct']
            all_data.append(row)

        complete_df = pd.DataFrame(all_data)
        csv = complete_df.to_csv(index=False)
        st.download_button(
            label="Download All Results as CSV",
            data=csv,
            file_name=f"all_models_{selected_dataset_details}_results.csv",
            mime="text/csv",
            key="download_all_results",
        )
if __name__ == "__main__":
    main()