File size: 8,324 Bytes
f275c93 ffd0ad7 2d4137f f275c93 ed629f0 f275c93 ed629f0 f275c93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import argparse
from openai import OpenAI
import ast
import cairosvg
import json
import os
import utils
import traceback
from dotenv import load_dotenv
from PIL import Image
from prompts import sketch_first_prompt, system_prompt, gt_example
def call_argparse(argv=None):
    """Parse command-line options and prepare the per-concept output directory.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list lets the parser
            be driven programmatically.

    Returns:
        argparse.Namespace with the parsed options plus derived fields:
        ``grid_size`` (``(res + 1) * cell_size`` pixels), ``save_name`` (the
        concept with spaces replaced by underscores) and ``path2save``
        rewritten to ``<path2save>/<save_name>``.

    Side effects:
        Creates ``args.path2save`` if it does not exist and (re)writes
        ``experiment_log.json`` inside it with an empty JSON list.
    """
    parser = argparse.ArgumentParser(description='Process Arguments')
    # General
    parser.add_argument('--concept_to_draw', type=str, default="cat")
    parser.add_argument('--path2save', type=str, default="results/test")
    parser.add_argument('--temperature', type=float, default=0.3)
    # NOTE(review): SketchApp hard-codes its model and does not read this value.
    parser.add_argument('--model', type=str, default='o3')
    parser.add_argument('--gen_mode', type=str, default='generation', choices=['generation', 'completion'])
    # Grid params
    parser.add_argument('--res', type=int, default=50, help="the resolution of the grid is set to 50x50")
    parser.add_argument('--cell_size', type=int, default=12, help="size of each cell in the grid")
    parser.add_argument('--stroke_width', type=float, default=7.0)
    args = parser.parse_args(argv)

    # One extra cell: presumably the header strip (header_size=cell_size) that
    # SketchApp passes to utils.create_grid_image.
    args.grid_size = (args.res + 1) * args.cell_size
    args.save_name = args.concept_to_draw.replace(" ", "_")
    args.path2save = f"{args.path2save}/{args.save_name}"

    # exist_ok avoids the check-then-create race of an os.path.exists() guard.
    os.makedirs(args.path2save, exist_ok=True)
    # Start every run with a fresh, empty experiment log.
    with open(f"{args.path2save}/experiment_log.json", 'w', encoding="utf-8") as json_file:
        json.dump([], json_file, indent=4)
    return args
class SketchApp:
    """Ask an OpenAI chat model to sketch a concept on a grid and render the result.

    The model's textual answer (terminated by ``</answer>``) is parsed with the
    project ``utils`` helpers into control points, turned into an SVG, and
    rasterised to PNG with cairosvg.
    """

    def __init__(self, args):
        """Set up grid geometry, rendering options, and the OpenAI client.

        Args:
            args: Namespace as produced by ``call_argparse`` (reads
                ``path2save``, ``concept_to_draw``, ``res``, ``cell_size``,
                ``grid_size``, ``stroke_width``, ``gen_mode``, ``temperature``).
        """
        # General
        self.path2save = args.path2save
        self.target_concept = args.concept_to_draw
        # Grid related
        self.res = args.res
        self.num_cells = args.res
        self.cell_size = args.cell_size
        self.grid_size = (args.grid_size, args.grid_size)
        self.init_canvas, self.positions = utils.create_grid_image(
            res=args.res, cell_size=args.cell_size, header_size=args.cell_size
        )
        self.init_canvas_str = utils.image_to_str(self.init_canvas)
        self.cells_to_pixels_map = utils.cells_to_pixels(
            args.res, args.cell_size, header_size=args.cell_size
        )
        # SVG related
        self.stroke_width = args.stroke_width
        # LLM setup: OPENAI_API_KEY is expected in a .env file. load_dotenv()
        # copies it into os.environ (without overriding variables already set);
        # without this call the .env file is never read.
        self.max_tokens = 3000
        load_dotenv()
        openai_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=openai_key)
        # NOTE(review): the --model option (args.model, default 'o3') is ignored.
        # Switching to it is not a drop-in change: reasoning models such as o3
        # are documented to reject `max_tokens`/custom `temperature` as sent by
        # call_llm — confirm before wiring it up.
        self.model = "gpt-4o"
        self.input_prompt = sketch_first_prompt.format(concept=args.concept_to_draw, gt_sketches_str=gt_example)
        self.gen_mode = args.gen_mode
        self.temperature = args.temperature

    def call_llm(self, system_message, other_msg, additional_args):
        """Send one Chat Completions request and return the reply text.

        Args:
            system_message: System prompt, sent as the first message.
            other_msg: Remaining messages (user/assistant) in API format.
            additional_args: Extra options; only ``"stop"`` is read.

        Returns:
            The ``content`` of the first choice's message.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "system", "content": system_message}] + other_msg,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            stop=additional_args.get("stop"),
        )
        return response.choices[0].message.content

    def define_input_to_llm(self, msg_history, init_canvas_str, msg):
        """Build the non-system message list for one request.

        Args:
            msg_history: Previous messages; not modified.
            init_canvas_str: Optional base64-encoded canvas image. When given it
                is attached before the text as an ``image_url`` content part.
            msg: Text of the new user turn.

        Returns:
            A new list: ``msg_history`` followed by a single user message.
        """
        content = []
        if init_canvas_str is not None:
            # Chat Completions expects image_url to be an object with a "url" key.
            # NOTE(review): assumes utils.image_to_str encodes JPEG — confirm.
            content.append({
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64," + init_canvas_str},
            })
        content.append({"type": "text", "text": msg})
        return msg_history + [{"role": "user", "content": content}]

    def get_response_from_llm(
        self,
        msg,
        system_message,
        msg_history=None,
        init_canvas_str=None,
        prefill_msg=None,
        stop=None,
        gen_mode="generation",
    ):
        """Query the model and log the full exchange to ``experiment_log.json``.

        Args:
            msg: Text of the new user turn.
            system_message: System prompt text.
            msg_history: Previous messages (``None`` means no history).
            init_canvas_str: Optional base64 image attached to the user turn.
            prefill_msg: In ``"completion"`` mode, text sent as a partial
                assistant turn and prepended to the returned content.
            stop: Stop sequence(s); defaults to ``["</answer>"]``.
            gen_mode: ``"generation"`` or ``"completion"``.

        Returns:
            The model's reply (prefixed with ``prefill_msg`` when a prefill was used).
        """
        if msg_history is None:
            msg_history = []
        # other_msg holds every message except the system prompt.
        other_msg = self.define_input_to_llm(msg_history, init_canvas_str, msg)

        # Only add (and later strip/prepend) a prefill turn when there is one;
        # otherwise completion mode would drop the user turn and emit "None".
        use_prefill = gen_mode == "completion" and bool(prefill_msg)
        if use_prefill:
            other_msg = other_msg + [{"role": "assistant", "content": f"{prefill_msg}"}]

        # Callers may pass their own stop sequence(s), e.g. for stroke-by-stroke generation.
        additional_args = {"stop": stop if stop else ["</answer>"]}

        content = self.call_llm(system_message, other_msg, additional_args)

        if use_prefill:
            other_msg = other_msg[:-1]  # remove the prefill turn; it is merged into content
            content = f"{prefill_msg}{content}"

        # Save the whole conversation (overwriting any previous log).
        if self.path2save is not None:
            log_path = f"{self.path2save}/experiment_log.json"
            conversation = (
                [{"role": "system", "content": system_message}]
                + other_msg
                + [{"role": "assistant", "content": [{"type": "text", "text": content}]}]
            )
            with open(log_path, 'w', encoding="utf-8") as json_file:
                json.dump(conversation, json_file, indent=4)
            print(f"Data has been saved to [{self.path2save}/experiment_log.json]")
        return content

    def call_model_for_sketch_generation(self):
        """Request a full sketch of ``self.target_concept`` and return the raw answer.

        Generation stops at ``</answer>``. Stop sequences are not included in
        the returned text, so the tag is re-appended to close the answer.
        """
        print("Calling LLM...")
        all_llm_output = self.get_response_from_llm(
            msg=self.input_prompt,
            system_message=system_prompt.format(res=self.res),
            msg_history=[],
            init_canvas_str=None,  # the grid image (self.init_canvas_str) is currently not sent
            gen_mode=self.gen_mode,
            stop="</answer>",
        )
        return all_llm_output + "</answer>"

    def parse_model_to_svg(self, model_rep_sketch):
        """Convert the model's textual sketch into an SVG document string."""
        # Extract the stroke and t-value lists (as Python-literal strings) from the XML answer.
        strokes_list_str, t_values_str = utils.parse_xml_string(model_rep_sketch, self.res)
        # literal_eval accepts only literals, so model output cannot execute code here.
        strokes_list = ast.literal_eval(strokes_list_str)
        t_values = ast.literal_eval(t_values_str)
        # Extract control points from the sampled lists.
        all_control_points = utils.get_control_points(strokes_list, t_values, self.cells_to_pixels_map)
        # Define the SVG from the control points.
        return utils.format_svg(all_control_points, dim=self.grid_size, stroke_width=self.stroke_width)

    def generate_sketch(self):
        """Generate a sketch and save it as SVG and two PNGs under ``self.path2save``.

        Writes ``<concept>.svg``, ``<concept>.png`` (white background), and
        ``<concept>_canvas.png`` (sketch composited over the grid canvas).
        ``self.init_canvas`` is modified in place by the composite.
        """
        sketching_commands = self.call_model_for_sketch_generation()
        model_strokes_svg = self.parse_model_to_svg(sketching_commands)

        base_path = f"{self.path2save}/{self.target_concept}"
        svg_path = f"{base_path}.svg"
        # Save the SVG sketch.
        with open(svg_path, "w", encoding="utf-8") as svg_file:
            svg_file.write(model_strokes_svg)

        # vector -> pixel: the sketch on a white background.
        cairosvg.svg2png(url=svg_path, write_to=f"{base_path}.png", background_color="white")

        # Render without a background colour, then composite onto the grid canvas.
        output_png_path = f"{base_path}_canvas.png"
        cairosvg.svg2png(url=svg_path, write_to=output_png_path)
        # Open once and close before overwriting the same path.
        with Image.open(output_png_path) as foreground:
            # Using the image as its own mask pastes it through its alpha channel.
            self.init_canvas.paste(foreground, (0, 0), foreground)
        self.init_canvas.save(output_png_path)
# Initialize and run the SketchApp
if __name__ == '__main__':
    import sys

    # generate_sketch calls the LLM and parses its free-form answer, either of
    # which can fail, so a failed attempt is retried a few times.
    MAX_ATTEMPTS = 3
    args = call_argparse()
    sketch_app = SketchApp(args)
    for _ in range(MAX_ATTEMPTS):
        try:
            sketch_app.generate_sketch()
        except Exception as e:
            # Top-level boundary: report the failure with its traceback, then retry.
            print(f"An error has occurred: {e}")
            traceback.print_exc()
        else:
            sys.exit(0)
    # Every attempt failed: exit non-zero so scripts can detect the failure.
    sys.exit(f"Sketch generation failed after {MAX_ATTEMPTS} attempts.")