{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PSET 1: Bottom-Up Synthesis\n", "\n", "I follow Algorithm 1 in the BUSTLE paper:\n", "\n", "> Odena, A. *et al.* BUSTLE: Bottom-Up Program Synthesis Through Learning-Guided Exploration. In *9th International Conference on Learning Representations*; 2021 May 3-7; Austria.\n", "\n", "First, I import the required libraries." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import itertools\n", "\n", "# argument parser for command line arguments\n", "import argparse\n", "\n", "# import arithmetic module and the AST node used in the demo cell below\n", "from arithmetic import *\n", "from abstract_syntax_tree import OperatorNode\n", "from examples import example_set, check_examples\n", "import config" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "24" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "add_node = OperatorNode(Add(), [IntegerConstant(7), IntegerConstant(5)])\n", "subtract_node = OperatorNode(Subtract(), [IntegerConstant(3), IntegerConstant(1)])\n", "multiply_node = OperatorNode(Multiply(), [add_node, subtract_node])\n", "multiply_node.evaluate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, I define variables that stand in for the command-line arguments passed to the synthesizer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "domain = \"arithmetic\"\n", "examples_key = \"addition\"\n", "examples = example_set[examples_key]\n", "max_weight = 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `check_examples` function imported above checks that, across all input-output pairs, every input has the same length and that argument types are consistent across inputs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, I define classes for integer variables and constants, along with the arithmetic operators." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class IntegerVariable:\n", "    '''\n", "    Class to represent an integer variable. `position` is the zero-based index of the variable in the input.\n", "    For example, if the input is [4, 5, 6] and the variable is the third element (i.e., 6), then position = 2.\n", "    '''\n", "    def __init__(self, position):\n", "        self.value = None  # value of the variable, initially None\n", "        self.position = position  # position of the variable in the arguments to program\n", "        self.type = int  # type of the variable\n", "\n", "    def assign(self, value):\n", "        self.value = value\n", "\n", "class IntegerConstant:\n", "    '''\n", "    Class to represent an integer constant.\n", "    '''\n", "    def __init__(self, value):\n", "        self.value = value  # value of the constant\n", "        self.type = int  # type of the constant\n", "\n", "class Add:\n", "    '''\n", "    Operator to add two numerical values.\n", "    '''\n", "    def __init__(self):\n", "        self.arity = 2  # number of arguments\n", "        self.arg_types = [int, int]  # argument types\n", "        self.return_type = int  # return type\n", "        self.weight = 1  # weight\n", "\n", "    def __call__(self, x, y):\n", "        return x + y\n", "\n", "    def str(self, x, y):\n", "        return f\"{x} + {y}\"\n", "\n", "class Subtract:\n", "    '''\n", "    Operator to subtract two numerical values.\n", "    '''\n", "    def __init__(self):\n", "        self.arity = 2  # number of arguments\n", "        self.arg_types = [int, int]  # argument types\n", "        self.return_type = int  # return type\n", "        self.weight = 1  # weight\n", "\n", "    def __call__(self, x, y):\n", "        return x - y\n", "\n", "    def str(self, x, y):\n", "        return f\"{x} - {y}\"\n", "\n", "class Multiply:\n", "    '''\n", "    Operator to multiply two numerical values.\n", "    '''\n", "    def __init__(self):\n", "        self.arity = 2  # number of arguments\n", "        self.arg_types = [int, int]  # argument types\n", "        self.return_type = int  # return type\n", "        self.weight = 1  # weight\n", "\n", "    def __call__(self, x, y):\n", "        return x * y\n", "\n", "    def str(self, x, y):\n", "        return f\"{x} * {y}\"\n", "\n", "class Divide:\n", "    '''\n", "    Operator to divide two numerical values.\n", "    '''\n", "    def __init__(self):\n", "        self.arity = 2  # number of arguments\n", "        self.arg_types = [int, int]  # argument types\n", "        self.return_type = int  # return type\n", "        self.weight = 1  # weight\n", "\n", "    def __call__(self, x, y):\n", "        try:  # guard against division by zero\n", "            return x / y\n", "        except ZeroDivisionError:\n", "            return None\n", "\n", "    def str(self, x, y):\n", "        return f\"{x} / {y}\"\n", "\n", "\n", "'''\n", "GLOBAL CONSTANTS\n", "'''\n", "\n", "# define operators\n", "arithmetic_operators = [Add(), Subtract(), Multiply(), Divide()]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I define a function to extract constants from examples." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def extract_constants(examples):\n", "    '''\n", "    Extracts the constants from the input-output examples. Also constructs variables as needed\n", "    based on the input-output examples, and adds them to the list of constants.\n", "    '''\n", "\n", "    # check validity of provided examples\n", "    # if valid, extract arity and argument types\n", "    arity, arg_types = check_examples(examples)\n", "\n", "    # initialize list of constants\n", "    constants = []\n", "\n", "    # get unique set of input values\n", "    inputs = {value for example in examples for value in example[0]}\n", "\n", "    # always include 1 in the set of inputs\n", "    inputs.add(1)\n", "\n", "    # extract constants in input\n", "    for value in inputs:\n", "\n", "        if type(value) == int:\n", "            constants.append(IntegerConstant(value))\n", "        elif type(value) == str:\n", "            # constants.append(StringConstant(value))\n", "            pass\n", "        else:\n", "            raise Exception(\"Input of unknown type.\")\n", "\n", "    # initialize list of variables\n", "    variables = []\n", "\n", "    # extract variables in input\n", "    for position, arg in enumerate(arg_types):\n", "        if arg == int:\n", "            variables.append(IntegerVariable(position))\n", "        elif arg == str:\n", "            # variables.append(StringVariable(position))\n", "            pass\n", "        else:\n", "            raise Exception(\"Input of unknown type.\")\n", "\n", "    return constants + variables" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# initialize program bank\n", "program_bank = extract_constants(examples)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "I define a function to determine observational equivalence: two programs are observationally equivalent if they produce the same output on every example input." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def observationally_equivalent(a, b):\n", "    \"\"\"\n", "    Returns True if a and b are observationally equivalent, False otherwise.\n", "    \"\"\"\n", "\n", "    pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, I define the bottom-up synthesis algorithm." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# iterate over each weight level, up to and including max_weight\n", "for weight in range(2, max_weight + 1):\n", "\n", "    # define level program bank\n", "    level_program_bank = []\n", "\n", "    for op in arithmetic_operators:\n", "\n", "        # placeholder: enumeration of candidate programs is not implemented yet\n", "        break" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }