import argparse
import ast
import json
import os
import traceback

import cairosvg
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image

import utils
from prompts import sketch_first_prompt, system_prompt, gt_example

# Number of times the script tries to generate a sketch before giving up.
MAX_ATTEMPTS = 3


def call_argparse():
    """Parse the command-line arguments and prepare the output directory.

    Besides the raw CLI values, the returned namespace carries three
    derived attributes:

    * ``grid_size``: ``(res + 1) * cell_size``.
    * ``save_name``: the concept with spaces replaced by underscores.
    * ``path2save``: the given path extended with ``/<save_name>``.

    If the output directory does not exist yet it is created and seeded
    with an empty ``experiment_log.json``.

    Returns:
        argparse.Namespace: the parsed and augmented arguments.
    """
    parser = argparse.ArgumentParser(description='Process Arguments')

    # General
    parser.add_argument('--concept_to_draw', type=str, default="cat")
    parser.add_argument('--path2save', type=str, default="results/test")
    parser.add_argument('--temperature', type=float, default=0.3)
    # This flag used to be parsed but ignored ("gpt-4o" was hard-coded in
    # SketchApp). It is honoured now; the default is "gpt-4o" so that a run
    # without flags talks to the same model as before.
    parser.add_argument('--model', type=str, default='gpt-4o')
    parser.add_argument('--gen_mode', type=str, default='generation',
                        choices=['generation', 'completion'])

    # Grid params
    parser.add_argument('--res', type=int, default=50,
                        help="the resolution of the grid is set to 50x50")
    parser.add_argument('--cell_size', type=int, default=12,
                        help="size of each cell in the grid")
    parser.add_argument('--stroke_width', type=float, default=7.0)

    args = parser.parse_args()
    args.grid_size = (args.res + 1) * args.cell_size

    args.save_name = args.concept_to_draw.replace(" ", "_")
    args.path2save = f"{args.path2save}/{args.save_name}"
    if not os.path.exists(args.path2save):
        # exist_ok guards against another process creating the directory
        # between the check above and this call.
        os.makedirs(args.path2save, exist_ok=True)
        with open(f"{args.path2save}/experiment_log.json", 'w') as json_file:
            json.dump([], json_file, indent=4)
    return args


class SketchApp:
    """Ask an OpenAI chat model for a sketch and render it to SVG/PNG files."""

    def __init__(self, args):
        """Set up the grid canvas and the LLM client.

        Args:
            args: namespace produced by ``call_argparse`` (reads
                ``path2save``, ``concept_to_draw``, ``res``, ``cell_size``,
                ``grid_size``, ``stroke_width``, ``model``, ``gen_mode`` and
                ``temperature``).
        """
        # General
        self.path2save = args.path2save
        self.target_concept = args.concept_to_draw

        # Grid related
        self.res = args.res
        self.num_cells = args.res
        self.cell_size = args.cell_size
        self.grid_size = (args.grid_size, args.grid_size)
        self.init_canvas, self.positions = utils.create_grid_image(
            res=args.res, cell_size=args.cell_size, header_size=args.cell_size)
        self.init_canvas_str = utils.image_to_str(self.init_canvas)
        self.cells_to_pixels_map = utils.cells_to_pixels(
            args.res, args.cell_size, header_size=args.cell_size)

        # SVG related
        self.stroke_width = args.stroke_width

        # LLM Setup (you need to provide your OPENAI_API_KEY in your .env file)
        self.max_tokens = 3000
        # Without this call the .env file is never read and os.getenv only
        # sees variables already exported in the process environment.
        load_dotenv()
        openai_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=openai_key)
        # NOTE(review): call_llm always sends max_tokens and temperature;
        # confirm the chosen model accepts both before switching models.
        self.model = args.model
        self.input_prompt = sketch_first_prompt.format(
            concept=args.concept_to_draw, gt_sketches_str=gt_example)
        self.gen_mode = args.gen_mode
        self.temperature = args.temperature

    @staticmethod
    def _normalize_stop(stop):
        """Return ``stop`` as a list of non-empty strings, or None if empty."""
        if stop is None:
            return None
        if isinstance(stop, str):
            stop = [stop]
        cleaned = [sequence for sequence in stop if sequence]
        return cleaned or None

    def call_llm(self, system_message, other_msg, additional_args):
        """Send one chat-completion request and return the reply text.

        Args:
            system_message: text of the system prompt.
            other_msg: list of chat messages that follow the system prompt.
            additional_args: dict of optional settings; only ``"stop"``
                (a string or a list of strings) is read.

        Returns:
            The ``content`` of the first choice in the response.
        """
        request = {
            "model": self.model,
            "messages": [{"role": "system", "content": system_message}] + other_msg,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }
        # An empty stop string carries no information, so empty entries are
        # dropped and the parameter is omitted when nothing is left.
        stop = self._normalize_stop(additional_args.get("stop"))
        if stop is not None:
            request["stop"] = stop

        response = self.client.chat.completions.create(**request)
        return response.choices[0].message.content

    def define_input_to_llm(self, msg_history, init_canvas_str, msg):
        """Build the message list for a request (everything but the system prompt).

        Args:
            msg_history: list of earlier chat messages; it is not modified.
            init_canvas_str: base64 string of an image to attach, or None
                to send text only.
            msg: the text of the new user message.

        Returns:
            A new list: ``msg_history`` followed by one user message.
        """
        content = []
        if init_canvas_str is not None:
            # The Chat Completions API documents the image part as an object
            # with a "url" field rather than a bare string.
            content.append({
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64," + init_canvas_str},
            })
        content.append({"type": "text", "text": msg})

        other_msg = msg_history + [{"role": "user", "content": content}]
        return other_msg

    def get_response_from_llm(
        self,
        msg,
        system_message,
        msg_history=None,
        init_canvas_str=None,
        prefill_msg=None,
        stop=None,
        gen_mode="generation"
    ):
        """Query the model, log the exchange to JSON and return the reply.

        Args:
            msg: text of the new user message.
            system_message: text of the system prompt.
            msg_history: earlier chat messages (defaults to none).
            init_canvas_str: optional base64 image to attach to the message.
            prefill_msg: in ``"completion"`` mode, text placed in a trailing
                assistant message before the request is sent.
            stop: optional stop sequence(s) forwarded to the API.
            gen_mode: ``"generation"`` or ``"completion"``.

        Returns:
            The reply text; when a prefill was sent, the prefill followed by
            the reply text.
        """
        # A fresh list per call instead of a shared mutable default.
        if msg_history is None:
            msg_history = []
        additional_args = {}

        # other_msg should contain all messages without the system prompt
        other_msg = self.define_input_to_llm(msg_history, init_canvas_str, msg)

        # The prefill is only in play when completion mode was requested AND
        # a prefill was actually supplied.
        use_prefill = gen_mode == "completion" and bool(prefill_msg)
        if use_prefill:
            other_msg = other_msg + [{"role": "assistant", "content": f"{prefill_msg}"}]

        # In case of stroke by stroke generation
        if stop:
            additional_args["stop"] = stop

        content = self.call_llm(system_message, other_msg, additional_args)

        if use_prefill:
            # Remove the assistant prefill from the logged history and glue
            # it back onto the reply. Doing this without a prefill would drop
            # the user message and prepend the literal text "None".
            other_msg = other_msg[:-1]
            content = f"{prefill_msg}{content}"

        # saves to json
        if self.path2save is not None:
            system_message_json = [{"role": "system", "content": system_message}]
            new_msg_history = other_msg + [
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "text",
                            "text": content,
                        }
                    ],
                }
            ]
            with open(f"{self.path2save}/experiment_log.json", 'w') as json_file:
                json.dump(system_message_json + new_msg_history, json_file, indent=4)
            print(f"Data has been saved to [{self.path2save}/experiment_log.json]")

        return content

    def call_model_for_sketch_generation(self):
        """Request the sketching commands for the target concept.

        Returns:
            The model output as returned by ``get_response_from_llm``.
        """
        print("Calling LLM...")

        # NOTE(review): this method used to pass an empty string as the stop
        # sequence and append an empty string to the output. Both were no-ops
        # and are gone; presumably a closing tag was lost from the file at
        # some point - confirm against what utils.parse_xml_string expects.
        return self.get_response_from_llm(
            msg=self.input_prompt,
            system_message=system_prompt.format(res=self.res),
            msg_history=[],
            init_canvas_str=None,  # the grid image is currently not sent
            gen_mode=self.gen_mode,
        )

    def parse_model_to_svg(self, model_rep_sketch):
        """Convert the model's textual sketch into an SVG document string.

        Args:
            model_rep_sketch: the text returned by the model.

        Returns:
            The value produced by ``utils.format_svg`` for the sketch.
        """
        # Parse model_rep with xml
        strokes_list_str, t_values_str = utils.parse_xml_string(model_rep_sketch, self.res)
        # literal_eval only accepts Python literals, so model output is never
        # executed as code.
        strokes_list = ast.literal_eval(strokes_list_str)
        t_values = ast.literal_eval(t_values_str)

        # extract control points from sampled lists
        all_control_points = utils.get_control_points(
            strokes_list, t_values, self.cells_to_pixels_map)

        # define SVG based on control point
        sketch_text_svg = utils.format_svg(
            all_control_points, dim=self.grid_size, stroke_width=self.stroke_width)
        return sketch_text_svg

    def generate_sketch(self):
        """Generate a sketch and write it out as SVG and PNG files.

        Three files named after the target concept are written to
        ``self.path2save``: ``<concept>.svg``, ``<concept>.png`` (white
        background) and ``<concept>_canvas.png`` (sketch pasted onto the
        grid canvas). ``self.init_canvas`` is modified in place by the paste.
        """
        sketching_commands = self.call_model_for_sketch_generation()
        model_strokes_svg = self.parse_model_to_svg(sketching_commands)

        svg_path = f"{self.path2save}/{self.target_concept}.svg"
        png_path = f"{self.path2save}/{self.target_concept}.png"
        output_png_path = f"{self.path2save}/{self.target_concept}_canvas.png"

        # save the SVG sketch
        with open(svg_path, "w") as svg_file:
            svg_file.write(model_strokes_svg)

        # vector->pixel
        # save the sketch to png with blank background
        cairosvg.svg2png(url=svg_path, write_to=png_path, background_color="white")

        # save the sketch to png on the canvas
        cairosvg.svg2png(url=svg_path, write_to=output_png_path)
        # Open the overlay once and close it before the same path is
        # overwritten below; the image doubles as its own paste mask.
        with Image.open(output_png_path) as foreground:
            self.init_canvas.paste(foreground, (0, 0), foreground)
        self.init_canvas.save(output_png_path)


# Initialize and run the SketchApp
if __name__ == '__main__':
    args = call_argparse()
    sketch_app = SketchApp(args)
    for attempt in range(MAX_ATTEMPTS):
        try:
            sketch_app.generate_sketch()
        except Exception as e:
            # Top-level boundary: report the failure and try again.
            print(f"An error has occurred: {e}")
            traceback.print_exc()
        else:
            raise SystemExit(0)
    # Every attempt failed: report that through the exit status instead of
    # falling off the end of the script with status 0.
    raise SystemExit(1)