|
import os |
|
import streamlit as st |
|
import openai |
|
import pandas as pd |
|
import time |
|
from typing import List, Tuple |
|
from uuid import uuid4 |
|
|
|
|
|
# Configure the OpenAI client from the environment; this is None when the
# variable is unset, which main() checks before running STaR.
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
|
|
def get_session_id():
    """Return this Streamlit session's identifier, creating it on first use.

    The value is a random UUID4 rendered as a string and cached in
    ``st.session_state`` under ``session_id``, so later calls in the same
    session return the same string.
    """
    session = st.session_state
    if 'session_id' in session:
        return session.session_id
    session.session_id = str(uuid4())
    return session.session_id
|
|
|
|
|
class SelfTaughtReasoner:
    """Drive a simplified Self-Taught Reasoner (STaR) loop for the Streamlit demo.

    Each iteration asks the completion model for a rationale and an answer per
    problem, and records whether the answer matches the reference. Wrong
    answers are retried with the reference answer supplied as a hint
    ("rationalization"). Finally, a placeholder fine-tuning step is called.

    Attributes:
        model_engine: Engine name passed to ``openai.Completion.create``.
        prompt_examples: Few-shot examples; dicts with 'Problem', 'Rationale'
            and 'Answer' keys.
        iterations: Number of completed STaR iterations.
        generated_data: One row per problem per iteration holding the model's
            first (un-hinted) attempt.
        rationalized_data: Rows whose hinted (rationalized) answer matched the
            reference answer.
        fine_tuned_model: Name recorded by ``fine_tune_model``; None until it
            has run.

    NOTE(review): ``openai.Completion`` with ``engine=`` is the pre-1.0 OpenAI
    SDK interface; confirm the pinned openai version supports it.
    """

    # Column layout shared by generated_data and rationalized_data.
    _COLUMNS = ['Problem', 'Rationale', 'Answer', 'Is_Correct']

    def __init__(self, model_engine="text-davinci-003"):
        self.model_engine = model_engine
        self.prompt_examples: List[dict] = []
        self.iterations = 0
        self.generated_data = pd.DataFrame(columns=self._COLUMNS)
        self.rationalized_data = pd.DataFrame(columns=self._COLUMNS)
        self.fine_tuned_model = None

    def add_prompt_example(self, problem: str, rationale: str, answer: str):
        """Append one few-shot example used as a prefix of every prompt."""
        self.prompt_examples.append({
            'Problem': problem,
            'Rationale': rationale,
            'Answer': answer
        })

    def construct_prompt(self, problem: str, include_answer: bool = False, answer: str = "") -> str:
        """Build the completion prompt for *problem*.

        The prompt consists of every few-shot example as a
        "Problem/Rationale/Answer" triple. These are followed by the new
        problem and, when *include_answer* is true, an "Answer (as hint)"
        line. The prompt always ends with "Rationale:" so the model continues
        with a rationale.
        """
        parts = []
        for example in self.prompt_examples:
            parts.append(
                f"Problem: {example['Problem']}\n"
                f"Rationale: {example['Rationale']}\n"
                f"Answer: {example['Answer']}\n\n"
            )
        parts.append(f"Problem: {problem}\n")
        if include_answer:
            parts.append(f"Answer (as hint): {answer}\n")
        parts.append("Rationale:")
        return "".join(parts)

    def _rationale_then_answer(self, prompt: str) -> Tuple[str, str]:
        """Complete *prompt* with a rationale, then with a short final answer.

        Makes two completion calls. The first is sampled (temperature 0.7) and
        continues after "Rationale:". The second is greedy (temperature 0) and
        continues after "Answer:" once the rationale has been appended.
        Exceptions from the API are not caught here.
        """
        response = openai.Completion.create(
            engine=self.model_engine,
            prompt=prompt,
            max_tokens=150,
            temperature=0.7,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=["\n\n", "Problem:", "Answer:"]
        )
        rationale = response.choices[0].text.strip()

        answer_prompt = prompt + f" {rationale}\nAnswer:"
        answer_response = openai.Completion.create(
            engine=self.model_engine,
            prompt=answer_prompt,
            max_tokens=10,
            temperature=0,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=["\n", "\n\n", "Problem:"]
        )
        answer = answer_response.choices[0].text.strip()
        return rationale, answer

    def generate_rationale_and_answer(self, problem: str) -> Tuple[str, str]:
        """Return ``(rationale, answer)`` for *problem* without any hint.

        On any API error, shows it with ``st.error`` and returns ``("", "")``.
        """
        prompt = self.construct_prompt(problem)
        try:
            return self._rationale_then_answer(prompt)
        except Exception as e:
            st.error(f"Error generating rationale and answer: {e}")
            return "", ""

    def rationalize(self, problem: str, correct_answer: str) -> Tuple[str, str]:
        """Return ``(rationale, answer)`` for *problem* with *correct_answer* as a hint.

        On any API error, shows it with ``st.error`` and returns ``("", "")``.
        """
        prompt = self.construct_prompt(problem, include_answer=True, answer=correct_answer)
        try:
            return self._rationale_then_answer(prompt)
        except Exception as e:
            st.error(f"Error during rationalization: {e}")
            return "", ""

    def fine_tune_model(self):
        """Simulate fine-tuning on the collected rationales.

        This is a placeholder: no training or API call happens. The method
        sleeps for one second and records a synthetic model name built from
        ``model_engine`` and the session id. Generation keeps using
        ``model_engine``.
        """
        time.sleep(1)
        self.fine_tuned_model = f"{self.model_engine}-fine-tuned-{get_session_id()}"
        st.success(f"Model fine-tuned: {self.fine_tuned_model}")

    @staticmethod
    def _answers_match(predicted, expected) -> bool:
        """Compare answers as text, ignoring case and surrounding whitespace.

        ``str()`` makes non-string reference answers comparable, e.g. ints
        produced by ``pd.read_csv`` for a numeric 'Answer' column.
        """
        return str(predicted).strip().lower() == str(expected).strip().lower()

    @staticmethod
    def _append_row(frame: pd.DataFrame, row: dict) -> pd.DataFrame:
        """Return a new frame with *row* appended after *frame*'s rows.

        This replaces ``DataFrame.append``, which was removed in pandas 2.0.
        """
        new_row = pd.DataFrame([row], columns=frame.columns)
        if frame.empty:
            # Concatenating with an empty, object-typed frame can emit a pandas
            # FutureWarning about dtype inference; start from the new row instead.
            return new_row
        return pd.concat([frame, new_row], ignore_index=True)

    def run_iteration(self, dataset: pd.DataFrame):
        """Run one STaR iteration over *dataset*.

        *dataset* must have 'Problem' and 'Answer' columns. Every first attempt
        is appended to ``generated_data``. Wrong attempts are rationalized, and
        those whose hinted answer matches are appended to
        ``rationalized_data``. Afterwards, the placeholder ``fine_tune_model``
        runs and ``iterations`` is incremented.
        """
        st.write(f"### Iteration {self.iterations + 1}")
        progress_bar = st.progress(0)
        total = len(dataset)

        # Use the row position rather than the index label for progress: a
        # dataset whose index is not 0..n-1 would otherwise yield values
        # outside [0, 1], which st.progress rejects.
        for position, (_, row) in enumerate(dataset.iterrows(), start=1):
            problem = row['Problem']
            correct_answer = row['Answer']

            rationale, answer = self.generate_rationale_and_answer(problem)
            is_correct = self._answers_match(answer, correct_answer)
            self.generated_data = self._append_row(self.generated_data, {
                'Problem': problem,
                'Rationale': rationale,
                'Answer': answer,
                'Is_Correct': is_correct
            })

            if not is_correct:
                # Rationalization: retry with the reference answer as a hint and
                # keep the rationale only if it leads to the right answer.
                rationale, answer = self.rationalize(problem, correct_answer)
                is_correct = self._answers_match(answer, correct_answer)
                if is_correct:
                    self.rationalized_data = self._append_row(self.rationalized_data, {
                        'Problem': problem,
                        'Rationale': rationale,
                        'Answer': answer,
                        'Is_Correct': is_correct
                    })

            progress_bar.progress(position / total)

        st.write("Fine-tuning the model on correct rationales...")
        self.fine_tune_model()
        self.iterations += 1
|
|
|
|
|
def _render_prompt_examples_step(star):
    """Step 1: collect few-shot examples into *star* and list those added so far."""
    st.header("Step 1: Add Few-Shot Prompt Examples")
    st.write("Provide a few examples with problems, rationales, and answers to bootstrap the reasoning process.")

    with st.form(key='prompt_form'):
        # Recent Streamlit versions reject st.text_area heights below 68 px
        # (NOTE(review): confirm against the pinned Streamlit version).
        example_problem = st.text_area("Example Problem", height=68)
        example_rationale = st.text_area("Example Rationale", height=100)
        example_answer = st.text_input("Example Answer")
        submit_example = st.form_submit_button("Add Example")

    if submit_example:
        if not example_problem or not example_rationale or not example_answer:
            st.warning("Please fill in all fields to add an example.")
        else:
            star.add_prompt_example(example_problem, example_rationale, example_answer)
            st.success("Example added.")

    if star.prompt_examples:
        st.subheader("Current Prompt Examples:")
        for idx, example in enumerate(star.prompt_examples):
            st.write(f"**Example {idx + 1}:**")
            st.write(f"Problem: {example['Problem']}")
            st.write(f"Rationale: {example['Rationale']}")
            st.write(f"Answer: {example['Answer']}")


def _parse_problem_answer_lines(text):
    """Turn 'Problem | Answer' lines into row dicts for a dataset DataFrame.

    Each non-blank line is split on its first '|', and both halves are
    stripped. Lines without a '|' are reported with st.error and skipped.
    """
    rows = []
    for line in text.splitlines():
        if not line.strip():
            continue  # tolerate blank separator lines instead of flagging them
        if '|' not in line:
            st.error(f"Invalid format in line: {line}")
            continue
        problem, answer = line.split('|', 1)
        rows.append({'Problem': problem.strip(), 'Answer': answer.strip()})
    return rows


def _render_dataset_step():
    """Step 2: load the problem/answer dataset into ``st.session_state.dataset``."""
    st.header("Step 2: Input Dataset")
    st.write("Provide a dataset of problems and correct answers for the STaR process.")

    dataset_input_method = st.radio("How would you like to input the dataset?", ("Manual Entry", "Upload CSV"))

    if dataset_input_method == "Manual Entry":
        with st.form(key='dataset_form'):
            dataset_problems = st.text_area("Enter problems and answers in the format 'Problem | Answer', one per line.", height=200)
            submit_dataset = st.form_submit_button("Submit Dataset")

        if submit_dataset:
            if not dataset_problems.strip():
                st.warning("Please enter at least one problem and answer.")
            else:
                rows = _parse_problem_answer_lines(dataset_problems)
                if rows:
                    st.session_state.dataset = pd.DataFrame(rows)
                    st.success("Dataset loaded.")
    else:
        uploaded_file = st.file_uploader("Upload a CSV file with 'Problem' and 'Answer' columns.", type=['csv'])
        if uploaded_file:
            try:
                # Read every cell as literal text. Otherwise numeric answers
                # become ints/floats and empty cells become NaN, while answers
                # are compared as lower-cased strings.
                uploaded = pd.read_csv(uploaded_file, dtype=str, keep_default_na=False)
            except Exception as e:
                st.error(f"Error loading CSV: {e}")
            else:
                if 'Problem' not in uploaded.columns or 'Answer' not in uploaded.columns:
                    st.error("CSV must contain 'Problem' and 'Answer' columns.")
                    # Drop any previously loaded dataset so a stale one is not run.
                    st.session_state.pop('dataset', None)
                else:
                    st.session_state.dataset = uploaded
                    st.success("Dataset loaded.")

    if 'dataset' in st.session_state:
        st.subheader("Current Dataset:")
        st.dataframe(st.session_state.dataset.head())


def _render_run_step(star):
    """Step 3: run the requested number of STaR iterations and show the results."""
    st.header("Step 3: Run STaR Process")
    num_iterations = st.number_input("Number of Iterations to Run:", min_value=1, max_value=10, value=1)
    run_star = st.button("Run STaR")

    if not run_star:
        return

    if not star.prompt_examples:
        st.warning("Please add at least one prompt example before running STaR.")
    elif 'dataset' not in st.session_state:
        # Without this guard, st.session_state.dataset raises when nothing was loaded.
        st.warning("Please load a dataset in Step 2 before running STaR.")
    elif not openai.api_key:
        st.warning("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    else:
        for _ in range(num_iterations):
            star.run_iteration(st.session_state.dataset)

        st.header("Results")
        st.subheader("Generated Data")
        st.dataframe(star.generated_data)

        st.subheader("Rationalized Data")
        st.dataframe(star.rationalized_data)

        st.write("The model has been fine-tuned iteratively. You can now test it with new problems.")


def _render_test_step(star):
    """Step 4: solve a user-entered problem once STaR has run at least once."""
    st.header("Step 4: Test the Fine-Tuned Model")
    test_problem = st.text_area("Enter a new problem to solve:", height=100)
    test_button = st.button("Solve Problem")

    if not test_button:
        return

    if not test_problem:
        st.warning("Please enter a problem to solve.")
    elif not star.fine_tuned_model:
        st.warning("The model has not been fine-tuned yet. Please run the STaR process first.")
    else:
        # NOTE(review): fine-tuning is a placeholder, so generation still uses
        # star.model_engine; star.fine_tuned_model is only a display name.
        st.write("Generating rationale and answer using the fine-tuned model...")
        rationale, answer = star.generate_rationale_and_answer(test_problem)
        st.subheader("Rationale:")
        st.write(rationale)
        st.subheader("Answer:")
        st.write(answer)


def main():
    """Render the STaR demo page: four steps followed by a footer.

    The SelfTaughtReasoner instance lives in ``st.session_state.star`` so its
    examples and results persist across Streamlit reruns.
    """
    st.title("Self-Taught Reasoner (STaR) Demonstration")
    st.write("""
This app demonstrates the Self-Taught Reasoner (STaR) workflow. Enter problems to solve, and see how the model generates rationales, filters correct answers, and fine-tunes itself iteratively.
""")

    if 'star' not in st.session_state:
        st.session_state.star = SelfTaughtReasoner()
    star = st.session_state.star

    _render_prompt_examples_step(star)
    _render_dataset_step()
    _render_run_step(star)
    _render_test_step(star)

    st.write("---")
    st.write("Developed as a demonstration of the STaR method.")
|
|
|
# Entry point: build the Streamlit page when this module is executed as the
# main script.
if __name__ == "__main__":
    main()
|
|