|
import argparse |
|
from openai import OpenAI |
|
import ast |
|
import cairosvg |
|
import json |
|
import os |
|
import utils |
|
import traceback |
|
|
|
from dotenv import load_dotenv |
|
from PIL import Image |
|
from prompts import sketch_first_prompt, system_prompt, gt_example |
|
|
|
|
|
def call_argparse():
    """Parse the command-line options and prepare the output directory.

    Besides parsing, this function derives three extra attributes on the
    returned namespace and touches the filesystem:

    * ``grid_size``  -- ``(res + 1) * cell_size``.
    * ``save_name``  -- ``concept_to_draw`` with spaces replaced by ``_``.
    * ``path2save``  -- rewritten to ``<path2save>/<save_name>``; the
      directory is created if missing and an empty ``experiment_log.json``
      (``[]``) is written into it, overwriting any previous log.

    Returns:
        argparse.Namespace: the parsed and augmented arguments.
    """
    parser = argparse.ArgumentParser(description='Process Arguments')

    # General
    parser.add_argument('--concept_to_draw', type=str, default="cat")
    parser.add_argument('--path2save', type=str, default="results/test")
    parser.add_argument('--temperature', type=float, default=0.3)
    parser.add_argument('--model', type=str, default='o3',
                        help="name of the chat model to query")
    parser.add_argument('--gen_mode', type=str, default='generation', choices=['generation', 'completion'])

    # Grid / rendering parameters
    parser.add_argument('--res', type=int, default=50, help="the resolution of the grid is set to 50x50")
    parser.add_argument('--cell_size', type=int, default=12, help="size of each cell in the grid")
    parser.add_argument('--stroke_width', type=float, default=7.0)

    args = parser.parse_args()
    # One extra cell per axis on top of the `res` drawable cells; SketchApp
    # passes header_size=cell_size to the grid helpers, which presumably
    # accounts for it -- confirm against utils.create_grid_image.
    args.grid_size = (args.res + 1) * args.cell_size

    args.save_name = args.concept_to_draw.replace(" ", "_")
    args.path2save = f"{args.path2save}/{args.save_name}"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.path2save, exist_ok=True)
    # Start every run from an empty experiment log.
    with open(f"{args.path2save}/experiment_log.json", 'w') as json_file:
        json.dump([], json_file, indent=4)
    return args
|
|
|
|
|
class SketchApp:
    """Ask an OpenAI chat model for a sketch and render the answer to SVG/PNG.

    The flow is: build the prompt (``__init__``), query the model
    (``call_model_for_sketch_generation``), convert the textual answer to an
    SVG string (``parse_model_to_svg``) and write the SVG plus two PNG
    renderings to ``path2save`` (``generate_sketch``).
    """

    def __init__(self, args):
        """Store run settings, build the grid canvas and create the API client.

        Args:
            args: namespace as produced by ``call_argparse``; the attributes
                ``path2save``, ``concept_to_draw``, ``res``, ``cell_size``,
                ``grid_size``, ``stroke_width``, ``gen_mode`` and
                ``temperature`` are read.
        """
        # General
        self.path2save = args.path2save
        self.target_concept = args.concept_to_draw

        # Grid related
        self.res = args.res
        self.num_cells = args.res
        self.cell_size = args.cell_size
        self.grid_size = (args.grid_size, args.grid_size)
        self.init_canvas, self.positions = utils.create_grid_image(
            res=args.res, cell_size=args.cell_size, header_size=args.cell_size
        )
        self.init_canvas_str = utils.image_to_str(self.init_canvas)
        self.cells_to_pixels_map = utils.cells_to_pixels(
            args.res, args.cell_size, header_size=args.cell_size
        )

        # SVG related
        self.stroke_width = args.stroke_width

        # LLM setup
        self.max_tokens = 3000
        # Populate os.environ from a local .env file (if any) so that the key
        # lookup below also works when the key is not exported in the shell.
        # Variables already present in the environment are not overridden.
        load_dotenv()
        openai_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=openai_key)
        # NOTE(review): the model is fixed here, so the --model command-line
        # option is currently ignored. Honouring it would also send
        # `max_tokens`/`temperature` to whatever model is selected -- confirm
        # those parameters are accepted before wiring args.model through.
        self.model = "gpt-4o"
        self.input_prompt = sketch_first_prompt.format(
            concept=args.concept_to_draw, gt_sketches_str=gt_example
        )
        self.gen_mode = args.gen_mode
        self.temperature = args.temperature

    def call_llm(self, system_message, other_msg, additional_args):
        """Send one chat-completion request and return the reply text.

        Args:
            system_message: text of the system prompt.
            other_msg: list of chat messages that follow the system prompt.
            additional_args: dict of optional request settings; only the
                ``"stop"`` key is consulted.

        Returns:
            The ``content`` of the first choice's message.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "system", "content": system_message}] + other_msg,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stop=additional_args.get("stop", None),
        )
        return response.choices[0].message.content

    def define_input_to_llm(self, msg_history, init_canvas_str, msg):
        """Append a new user turn (optional image + text) to the history.

        Args:
            msg_history: list of previous chat messages; it is not modified.
            init_canvas_str: base64-encoded image to attach, or ``None`` to
                send text only.
            msg: text of the user message.

        Returns:
            A new list: ``msg_history`` followed by the new user message.
        """
        content = []
        if init_canvas_str is not None:
            # The Chat Completions API expects `image_url` to be an object
            # with a `url` field, not a bare string.
            content.append({
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64," + init_canvas_str},
            })
        content.append({"type": "text", "text": msg})

        return msg_history + [{"role": "user", "content": content}]

    def get_response_from_llm(
        self,
        msg,
        system_message,
        msg_history=None,
        init_canvas_str=None,
        prefill_msg=None,
        stop=None,
        gen_mode="generation"
    ):
        """Query the model once, log the exchange and return the reply text.

        Args:
            msg: text of the new user message.
            system_message: text of the system prompt.
            msg_history: previous chat messages (defaults to no history).
            init_canvas_str: optional base64 image attached to the user turn.
            prefill_msg: in ``"completion"`` mode, text placed in a trailing
                assistant message for the model to continue; it is prepended
                to the returned text.
            stop: stop sequence(s) for the request; ``["</answer>"]`` is used
                when none is given.
            gen_mode: ``"generation"`` or ``"completion"``.

        Returns:
            The model reply, prefixed with ``prefill_msg`` when a prefill was
            actually sent.

        Side effects:
            When ``self.path2save`` is not ``None`` the full conversation
            (system prompt, history, user turn, reply) overwrites
            ``<path2save>/experiment_log.json``.
        """
        additional_args = {}
        if msg_history is None:
            msg_history = []

        # other_msg holds every message except the system prompt.
        other_msg = self.define_input_to_llm(msg_history, init_canvas_str, msg)

        # A prefill is only sent in completion mode and only when one was
        # supplied; everything that undoes the prefill below must use the
        # same condition, otherwise the user turn would be dropped from the
        # log and the reply would be prefixed with the string "None".
        prefilled = gen_mode == "completion" and bool(prefill_msg)
        if prefilled:
            other_msg = other_msg + [{"role": "assistant", "content": f"{prefill_msg}"}]

        additional_args["stop"] = stop if stop else ["</answer>"]

        content = self.call_llm(system_message, other_msg, additional_args)

        if prefilled:
            # Replace the partial assistant turn by the complete answer.
            other_msg = other_msg[:-1]
            content = f"{prefill_msg}{content}"

        if self.path2save is not None:
            system_message_json = [{"role": "system", "content": system_message}]
            new_msg_history = other_msg + [
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "text",
                            "text": content,
                        }
                    ],
                }
            ]
            with open(f"{self.path2save}/experiment_log.json", 'w') as json_file:
                json.dump(system_message_json + new_msg_history, json_file, indent=4)
            print(f"Data has been saved to [{self.path2save}/experiment_log.json]")
        return content

    def call_model_for_sketch_generation(self):
        """Request the sketch for ``self.input_prompt`` from the model.

        Returns:
            The model output with the closing ``</answer>`` tag re-appended
            (the tag is used as the stop sequence, so the reply ends before
            it).
        """
        print("Calling LLM...")

        all_llm_output = self.get_response_from_llm(
            msg=self.input_prompt,
            system_message=system_prompt.format(res=self.res),
            msg_history=[],
            init_canvas_str=None,
            gen_mode=self.gen_mode,
            stop="</answer>",
        )

        return all_llm_output + "</answer>"

    def parse_model_to_svg(self, model_rep_sketch):
        """Convert the model's textual sketch description into an SVG string.

        Args:
            model_rep_sketch: full model output, including ``</answer>``.

        Returns:
            The SVG document produced by ``utils.format_svg``.

        Raises:
            Whatever ``ast.literal_eval`` or the ``utils`` helpers raise on
            malformed model output (e.g. ``ValueError``/``SyntaxError``).
        """
        strokes_list_str, t_values_str = utils.parse_xml_string(model_rep_sketch, self.res)
        # literal_eval only accepts Python literals, so model output is never
        # executed as code.
        strokes_list = ast.literal_eval(strokes_list_str)
        t_values = ast.literal_eval(t_values_str)

        all_control_points = utils.get_control_points(strokes_list, t_values, self.cells_to_pixels_map)

        sketch_text_svg = utils.format_svg(all_control_points, dim=self.grid_size, stroke_width=self.stroke_width)
        return sketch_text_svg

    def generate_sketch(self):
        """Run the full pipeline and write the results to ``self.path2save``.

        Files written (``<c>`` is ``self.target_concept``):

        * ``<c>.svg``        -- the generated strokes.
        * ``<c>.png``        -- the SVG rendered on a white background.
        * ``<c>_canvas.png`` -- the strokes pasted onto the grid canvas.

        Note that the strokes are pasted into ``self.init_canvas`` itself, so
        that image is modified by this call.
        """
        sketching_commands = self.call_model_for_sketch_generation()
        model_strokes_svg = self.parse_model_to_svg(sketching_commands)

        svg_path = f"{self.path2save}/{self.target_concept}.svg"
        with open(svg_path, "w") as svg_file:
            svg_file.write(model_strokes_svg)

        # Stand-alone rendering of the sketch.
        cairosvg.svg2png(url=svg_path, write_to=f"{self.path2save}/{self.target_concept}.png", background_color="white")

        # Second rendering without a background colour; the image doubles as
        # its own paste mask (assumes the PNG carries an alpha channel --
        # confirm against cairosvg's output).
        output_png_path = f"{self.path2save}/{self.target_concept}_canvas.png"
        cairosvg.svg2png(url=svg_path, write_to=output_png_path)
        # Open the file once and close it before the same path is overwritten.
        with Image.open(output_png_path) as foreground:
            self.init_canvas.paste(foreground, (0, 0), foreground)
        self.init_canvas.save(output_png_path)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the options, then try to generate the sketch,
    # retrying a few times because the model output may fail to parse.
    args = call_argparse()
    sketch_app = SketchApp(args)
    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            sketch_app.generate_sketch()
        except Exception as e:
            # Top-level boundary: report the failure and move on to the next
            # attempt.
            print(f"An error has occurred: {e}")
            traceback.print_exc()
        else:
            # Success: stop retrying; the script then exits with status 0.
            break
    else:
        # Every attempt failed -- signal that to the caller instead of
        # falling through with a success exit status.
        raise SystemExit(1)