File size: 6,062 Bytes
fc35a48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
"""
This file comes from: https://github.com/microsoft/ToRA/blob/main/src/utils/python_executor.py
"""
import copy
import datetime
import io
import multiprocessing
import pickle
import traceback
from concurrent.futures import TimeoutError
from contextlib import redirect_stdout
from functools import partial
from typing import Any, Dict, List, Optional

import dateutil.relativedelta
import multiprocess
import regex
from multiprocess import Pool
from pebble import ProcessPool
from timeout_decorator import timeout
from tqdm import tqdm
class GenericRuntime:
    """Namespace in which generated code is executed and evaluated.

    Subclasses configure the runtime through class attributes:

    * ``GLOBAL_DICT`` -- names pre-seeded into the globals of every new
      runtime (shallow-copied, so the class-level dict is never mutated).
    * ``LOCAL_DICT`` -- optional mapping, shallow-copied into
      ``_local_vars``; it is not passed to ``exec``/``eval`` by this class.
    * ``HEADERS`` -- code strings executed, in order, when the runtime is
      created.
    """

    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
        for header in self.HEADERS:
            self.exec_code(header)

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` with this runtime's globals.

        Raises:
            RuntimeError: if the text contains ``input(`` or ``os.system(``.
        """
        # Coarse text guard, not a sandbox: the patterns match anywhere in
        # the source, including inside strings, comments and names such as
        # ``my_input(``. Presumably input() is rejected because it would
        # block waiting on stdin -- NOTE(review): confirm intent.
        if regex.search(r'input\(', code_piece) or regex.search(r'os\.system\(', code_piece):
            # A message makes the reported exec_info explain the rejection
            # instead of being a bare "RuntimeError".
            raise RuntimeError('Code calling input() or os.system() is not allowed')
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` against this runtime's globals and return its value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Bind every key/value pair of ``var_dict`` as a global of this runtime."""
        for name, value in var_dict.items():
            self._global_vars[name] = value

    @property
    def answer(self):
        """Value of the global ``answer`` (raises ``KeyError`` if it was never set)."""
        return self._global_vars['answer']
class DateRuntime(GenericRuntime):
    """Runtime whose globals expose date helpers.

    ``datetime`` is ``datetime.datetime``. Both ``timedelta`` and
    ``relativedelta`` are bound to ``dateutil.relativedelta.relativedelta``;
    note that ``timedelta`` is therefore NOT ``datetime.timedelta`` here.
    NOTE(review): presumably so month/year offsets work in generated code --
    confirm intent.
    """

    GLOBAL_DICT = dict(
        datetime=datetime.datetime,
        timedelta=dateutil.relativedelta.relativedelta,
        relativedelta=dateutil.relativedelta.relativedelta,
    )
class CustomDict(dict):
    """``dict`` whose iteration walks a snapshot of its keys.

    The keys are copied into a list before iterating, so entries may be
    added or removed inside a ``for key in d`` loop without raising
    "dictionary changed size during iteration".
    """

    def __iter__(self):
        snapshot = list(super().__iter__())
        return iter(snapshot)
class ColorObjectRuntime(GenericRuntime):
    """Runtime in which the global name ``dict`` refers to ``CustomDict``.

    Only explicit ``dict(...)`` calls in generated code are affected; ``{...}``
    literals still build plain dicts.
    """

    GLOBAL_DICT = dict(dict=CustomDict)
class PythonExecutor:
    """Run generated Python programs in worker processes and collect answers.

    Every program is split into lines and executed by :meth:`execute` inside a
    pebble ``ProcessPool``. Each program yields a ``(result, exec_info)`` pair,
    where ``exec_info`` is ``"Done"`` on success or a one-line error description.
    """

    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
        timeout_length: int = 5,
    ) -> None:
        """Configure how programs are run and how their answer is extracted.

        Args:
            runtime: Object providing ``exec_code``, ``eval_code`` and
                ``_global_vars``. A new ``GenericRuntime`` is used when falsy.
            get_answer_symbol: Name of a global to read the answer from.
            get_answer_expr: Expression evaluated after execution to get the answer.
            get_answer_from_stdout: If true, the answer is the last printed line.
            timeout_length: Per-program timeout in seconds. It is used both
                inside the worker and as the pool's per-item timeout.
        """
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        self.timeout_length = timeout_length

    def process_generation_to_code(self, gens: List[str]) -> List[List[str]]:
        """Split each generated program into its list of source lines."""
        return [g.split('\n') for g in gens]

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout=None,
        runtime=None,
        answer_symbol=None,
        answer_expr=None,
        timeout_length=10,
    ):
        """Execute one program (a list of source lines) and extract its answer.

        The answer source is chosen in this priority order:

        1. the last line written to stdout (``get_answer_from_stdout``);
        2. the global named ``answer_symbol``;
        3. the value of ``answer_expr``;
        4. otherwise, all lines but the last are executed and the last line
           is evaluated as an expression.

        Returns:
            ``(result, exec_info)``. ``exec_info`` is ``"Done"`` on success.
            On any failure ``result`` is ``''`` and ``exec_info`` is the last
            line of the traceback. A program that prints nothing in stdout
            mode gives ``('', "Done")``.
        """
        try:
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                program_io.seek(0)
                lines = program_io.readlines()
                # No output means no answer. Do not misreport an internal
                # IndexError as if the user program had raised it.
                result = lines[-1] if lines else ''
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
                result = timeout(timeout_length)(runtime.eval_code)(code[-1])
            exec_info = "Done"
            # Fail here, with a reported error, if the result cannot be
            # stringified or pickled, rather than failing later.
            str(result)
            pickle.dumps(result)  # serialization check
        except BaseException:
            # Deliberately broad: user code may call exit()/sys.exit()
            # (SystemExit). That must be reported as a failure for this
            # program, not propagate out of the worker.
            result = ''
            exec_info = traceback.format_exc().split('\n')[-2]
        return result, exec_info

    def apply(self, code):
        """Execute a single program. This is shorthand for ``batch_apply([code])[0]``."""
        return self.batch_apply([code])[0]

    def batch_apply(self, batch_code):
        """Execute every program in ``batch_code`` in parallel worker processes.

        Returns:
            A list of ``(result, exec_info)`` pairs, in the same order as
            ``batch_code``. A program that exceeds the pool timeout yields
            ``("", "Timeout Error")``. Any other pool-level failure for a
            program (e.g. a crashed worker or an unpicklable task) yields
            ``("", "<ExceptionType>: <message>")`` instead of terminating
            the interpreter.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        if not all_code_snippets:
            # Nothing to run. This also avoids creating a pool with zero workers.
            return []

        all_exec_results = []
        max_workers = min(len(all_code_snippets), multiprocessing.cpu_count())
        with ProcessPool(max_workers=max_workers) as pool:
            executor = partial(
                self.execute,
                get_answer_from_stdout=self.get_answer_from_stdout,
                runtime=self.runtime,
                answer_symbol=self.answer_symbol,
                answer_expr=self.answer_expr,
                # NOTE(review): the original author reported that this
                # in-worker timeout does not take effect; the pool.map timeout
                # below is the one relied on.
                timeout_length=self.timeout_length,
            )
            future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
            iterator = future.result()
            if len(all_code_snippets) > 100:
                progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
            else:
                progress_bar = None
            while True:
                # Each failed item raises on next() and iteration then moves
                # on to the following item, so results stay aligned with inputs.
                try:
                    result = next(iterator)
                except StopIteration:
                    break
                except TimeoutError as error:
                    print(error)
                    result = ("", "Timeout Error")
                except Exception as error:
                    print(error)
                    result = ("", "{}: {}".format(type(error).__name__, error))
                all_exec_results.append(result)
                if progress_bar is not None:
                    progress_bar.update(1)
            if progress_bar is not None:
                progress_bar.close()
        return all_exec_results
def _test():
    """Smoke test: run one print() program and show its (result, exec_info) pair."""
    snippet = '\nprint("Hello world!")\n'
    executor = PythonExecutor(get_answer_from_stdout=True)
    print(executor.apply(snippet))
# Run the smoke test when this module is executed directly.
if __name__ == '__main__':
    _test()