Spaces:
Sleeping
Sleeping
File size: 6,144 Bytes
26d32ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import streamlit as st
from langchain_community.llms import OpenAI
import argparse
from datasets import load_dataset
import yaml
from tqdm import tqdm
import re
def load_data(split="test"):
    """Load the bigcode/humanevalpack dataset and return the requested split.

    Prints the number of examples in the split as a quick sanity check.
    """
    dataset = load_dataset("bigcode/humanevalpack")
    print("=" * 11 + " dataset statistics " + "=" * 11)
    print(len(dataset[split]))
    print("=" * 42)
    return dataset[split]
def split_function_header_and_docstring(s):
    """Split a HumanEval-style prompt into its code part and its docstring.

    Args:
        s: Source text containing at least one triple-quoted docstring
           (either ``\"\"\"`` or ``'''`` delimited).

    Returns:
        Tuple ``(code_without_docstring, docstring)``: the input with the
        last triple-quoted docstring removed (and stripped of surrounding
        whitespace), and that docstring with its triple-quote delimiters
        removed.

    Raises:
        ValueError: if no triple-quoted docstring is found in ``s``.
    """
    pattern = re.compile(r"(\"\"\"(.*?)\"\"\"|\'\'\'(.*?)\'\'\')", re.DOTALL)
    matches = pattern.findall(s)
    if not matches:
        raise ValueError("no triple-quoted docstring found in input")
    # Use the last docstring in the prompt; HumanEval prompts put the
    # function-level docstring after the header.
    docstring = matches[-1][0]
    # Remove the docstring from the code; the extra 6-quote replace cleans up
    # the adjacent-empty-docstring edge case left by the first replace.
    code_without_docstring = s.replace(docstring, "").replace('"' * 6, "").strip()
    # Strip only the triple-quote *delimiters* — the original stripped every
    # double quote (mangling content) and never handled '''-style docstrings.
    docstring = docstring.replace('"""', "").replace("'''", "")
    return code_without_docstring, docstring
def prepare_model_input(code_data):
    """Format a feedback-generation prompt for one HumanEval example.

    Args:
        code_data: A dataset row with at least ``prompt`` and
            ``buggy_solution`` fields.

    Returns:
        Tuple ``(model_input, problem, function_header)``: the filled
        prompt, the problem description (docstring up to the first
        doctest marker), and the function header.
    """
    template = """Provide feedback on the errors in the given code and suggest the
correct code to address the described problem.
Problem Description:
{description}
Incorrect Code:
{wrong_code}"""
    function_header, docstring = split_function_header_and_docstring(code_data["prompt"])
    # Keep only the prose part of the docstring; drop the ">>>" doctest examples.
    problem = docstring.split(">>>")[0]
    wrong_code = function_header + code_data["buggy_solution"]
    model_input = template.format(description=problem, wrong_code=wrong_code)
    return model_input, problem, function_header
def load_and_prepare_data():
    """Build the critic-model input for every dataset example.

    Returns:
        Dict keyed by ``task_id``; each value holds the critic prompt
        (``model_input``), the function ``header``, the
        ``problem_description``, and the raw ``data`` row.
    """
    dataset = load_data()
    prepared = {}
    print("### load and prepare data")
    for example in tqdm(dataset):
        task_id = example['task_id']
        _, problem, header = prepare_model_input(example)
        # Critic prompt uses a compact single-line layout rather than the
        # multi-line template returned by prepare_model_input.
        critic_input = (
            "Provide feedback on the errors in the given code and suggest the "
            f"correct code to address the described problem.\nProblem Description:{problem}"
            f"\nIncorrect Code:\n{example['buggy_solution']}\nFeedback:"
        )
        prepared[task_id] = {
            "model_input": critic_input,
            "header": header,
            "problem_description": problem,
            "data": example,
        }
    return prepared
# --- Module-level app setup: data, CLI args, model clients, and static UI ---
# NOTE(review): load_dataset is called here with split='test' but without a
# language config name — confirm this matches how load_data() loads the same
# dataset above (which passes no split to load_dataset and indexes afterwards).
dataset = load_dataset("bigcode/humanevalpack", split='test', trust_remote_code=True) # Ensuring consistent split usage
problem_ids = [problem['task_id'] for problem in dataset]
all_model_inputs = load_and_prepare_data()
# CLI flags for local model ports.
# NOTE(review): args.editor_port / args.critic_port are never used below —
# the model base URLs are hardcoded ngrok endpoints; verify which is intended.
parser = argparse.ArgumentParser()
parser.add_argument("--editor_port", type=str, default="6000")
parser.add_argument("--critic_port", type=str, default="6001")
# Assuming args are passed via command line interface
args = parser.parse_args()
# LangChain OpenAI-compatible clients pointed at the hosted editor/critic
# models (api_key="EMPTY" — the endpoints do not require authentication here).
editor_model = OpenAI(model="Anonymous-COFFEE/COFFEEPOTS-editor", api_key="EMPTY", openai_api_base=f"https://editor.jp.ngrok.io/v1")
# critic_model = OpenAI(model="Anonymous-COFFEE/COFFEEPOTS-critic", api_key="EMPTY", openai_api_base=f"http://localhost:{args.critic_port}/v1")
critic_model = OpenAI(model="Anonymous-COFFEE/COFFEEPOTS-critic", api_key="EMPTY", openai_api_base=f"https://critic.jp.ngrok.io/v1")
# Static Streamlit UI: title, problem selector, and the selected problem's details.
st.title("Demo for COFFEEPOTS")
selected_task_id = st.selectbox("Select a problem ID:", problem_ids)
# Retrieve selected problem details
problem_details = dataset[problem_ids.index(selected_task_id)]
st.write(f"**Selected Problem ID:** {problem_details['task_id']}")
st.write(f"**Problem Description:**\n{all_model_inputs[selected_task_id]['problem_description']}")
# Display buggy code with syntax highlighting
st.code(problem_details['buggy_solution'], language='python')
# Placeholders that the button handler below streams model output into.
status_text = st.empty()
code_output = st.code("", language="python")
def generate_feedback():
    """Stream the critic model's feedback for the currently selected task.

    Reads the module-level ``critic_model``, ``all_model_inputs`` and
    ``selected_task_id``; returns a token stream for st.write_stream.
    """
    critic_prompt = f"{all_model_inputs[selected_task_id]['model_input']}"
    return critic_model.stream(input=critic_prompt, logit_bias=None)
def generate_corrected_code():
    """Stream the editor model's corrected code wrapped in a markdown fence.

    NOTE(review): this generator reads the module-level ``feedback``
    variable, which is only assigned inside the button handler below —
    it must not be consumed before the feedback stream completes; confirm
    the ordering holds if the UI flow changes.
    """
    yield "```python"
    editor_prompt = (
        f"[INST]Buggy Code:\n{problem_details['buggy_solution']}"
        f"\nFeedback: {feedback}[/INST]"
    )
    # Forward each streamed chunk as-is; chunks concatenate into the final code.
    for chunk in editor_model.stream(input=editor_prompt, logit_bias=None):
        yield chunk
    yield "```"
# Button handler: run the critic, then feed its output to the editor.
if st.button("Generate Feedback and Corrected Code"):
    # Phase 1: stream natural-language feedback from the critic model.
    with st.spinner("Generating feedback..."):
        # Log the exact prompt for debugging.
        print(f"model input for critic:")
        print(all_model_inputs[selected_task_id]['model_input'])
        # write_stream renders tokens incrementally into the placeholder and
        # returns the fully concatenated text once the stream ends; the result
        # is read as the global ``feedback`` by generate_corrected_code.
        feedback = status_text.write_stream(generate_feedback())
    # Phase 2: stream corrected code from the editor model, conditioned on
    # the feedback produced above.
    with st.spinner("Generating corrected code..."):
        corrected_code = code_output.write_stream(generate_corrected_code())
|