Spaces:
Sleeping
Sleeping
Update sketch/gen_sketch.py
Browse files- sketch/gen_sketch.py +225 -0
sketch/gen_sketch.py
CHANGED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from openai import OpenAI
|
3 |
+
import ast
|
4 |
+
import cairosvg
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import utils
|
8 |
+
import traceback
|
9 |
+
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
from PIL import Image
|
12 |
+
from prompts import sketch_first_prompt, system_prompt, gt_example
|
13 |
+
|
14 |
+
|
15 |
+
def call_argparse():
    """Parse CLI arguments, derive grid/save settings, and reset the log.

    Returns:
        argparse.Namespace with the parsed options plus derived fields:
        `grid_size` (canvas side in pixels), `save_name` (filesystem-safe
        concept name), and `path2save` extended with a per-concept subdir.

    Side effects:
        Creates the output directory and writes an empty
        `experiment_log.json` into it.
    """
    parser = argparse.ArgumentParser(description='Process Arguments')

    # General
    parser.add_argument('--concept_to_draw', type=str, default="cat")
    parser.add_argument('--seed_mode', type=str, default='deterministic', choices=['deterministic', 'stochastic'])
    parser.add_argument('--path2save', type=str, default="results/test")
    parser.add_argument('--model', type=str, default='gpt-4o')
    parser.add_argument('--gen_mode', type=str, default='generation', choices=['generation', 'completion'])

    # Grid params
    parser.add_argument('--res', type=int, default=50, help="the resolution of the grid is set to 50x50")
    parser.add_argument('--cell_size', type=int, default=12, help="size of each cell in the grid")
    parser.add_argument('--stroke_width', type=float, default=7.0)

    args = parser.parse_args()
    # Canvas is (res + 1) cells wide: one extra cell for the grid header.
    args.grid_size = (args.res + 1) * args.cell_size

    # Each concept gets its own subdirectory; spaces are not path-friendly.
    args.save_name = args.concept_to_draw.replace(" ", "_")
    args.path2save = f"{args.path2save}/{args.save_name}"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.path2save, exist_ok=True)
    # Start each run with a fresh, empty experiment log.
    with open(f"{args.path2save}/experiment_log.json", 'w') as json_file:
        json.dump([], json_file, indent=4)
    return args
|
40 |
+
|
41 |
+
|
42 |
+
class SketchApp:
    """Generate a sketch of a concept by prompting an LLM for strokes on a
    coordinate grid, then rendering the parsed strokes to SVG and PNG."""

    def __init__(self, args):
        # General
        self.path2save = args.path2save
        self.target_concept = args.concept_to_draw

        # Grid related
        self.res = args.res
        self.num_cells = args.res
        self.cell_size = args.cell_size
        self.grid_size = (args.grid_size, args.grid_size)
        self.init_canvas, self.positions = utils.create_grid_image(res=args.res, cell_size=args.cell_size, header_size=args.cell_size)
        self.init_canvas_str = utils.image_to_str(self.init_canvas)
        self.cells_to_pixels_map = utils.cells_to_pixels(args.res, args.cell_size, header_size=args.cell_size)

        # SVG related
        self.stroke_width = args.stroke_width

        # LLM Setup (you need to provide your OPENAI_API_KEY in your .env file)
        self.max_tokens = 3000
        load_dotenv()
        openai_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=openai_key)
        # Honor the --model CLI flag (previously hard-coded to "gpt-4o",
        # which silently ignored the parsed argument).
        self.model = args.model
        self.input_prompt = sketch_first_prompt.format(concept=args.concept_to_draw, gt_sketches_str=gt_example)
        self.gen_mode = args.gen_mode
        self.seed_mode = args.seed_mode

    def call_llm(self, system_message, other_msg, additional_args):
        """Make one Chat Completions call and return the assistant's text.

        additional_args may carry "temperature" (defaults to 0.0 for
        reproducibility) and "stop" sequences.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "system", "content": system_message}] + other_msg,
            max_tokens=self.max_tokens,
            temperature=additional_args.get("temperature", 0.0),
            stop=additional_args.get("stop", None)
        )
        return response.choices[0].message.content

    def define_input_to_llm(self, msg_history, init_canvas_str, msg):
        """Append one user turn (optional canvas image + text) to the history.

        Returns the message list to send, excluding the system prompt.
        """
        content = []
        if init_canvas_str is not None:
            # The Chat Completions API requires image parts shaped as
            # {"type": "image_url", "image_url": {"url": ...}}; passing the
            # data URL as a bare string is rejected by the API.
            content.append({
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64," + init_canvas_str}
            })
        content.append({"type": "text", "text": msg})

        return msg_history + [{"role": "user", "content": content}]

    def get_response_from_llm(
        self,
        msg,
        system_message,
        msg_history=None,
        init_canvas_str=None,
        prefill_msg=None,
        stop=None,
        gen_mode="generation"
    ):
        """Send one user turn to the LLM and return the assistant's text.

        In "completion" mode, `prefill_msg` is sent as the beginning of the
        assistant's answer and prepended to the returned content. The full
        exchange (system prompt included) is logged to experiment_log.json
        when `path2save` is set.
        """
        # Avoid a mutable default argument; behavior is unchanged for callers.
        if msg_history is None:
            msg_history = []
        additional_args = {}

        # other_msg contains every message except the system prompt.
        other_msg = self.define_input_to_llm(msg_history, init_canvas_str, msg)

        if gen_mode == "completion":
            if prefill_msg:
                other_msg = other_msg + [{"role": "assistant", "content": f"{prefill_msg}"}]

        # In case of stroke-by-stroke generation a custom stop is supplied;
        # otherwise stop at the closing answer tag.
        if stop:
            additional_args["stop"] = stop
        else:
            additional_args["stop"] = ["</answer>"]

        response = self.call_llm(system_message, other_msg, additional_args)
        content = response

        if gen_mode == "completion":
            other_msg = other_msg[:-1]  # remove initial assistant prompt
            content = f"{prefill_msg}{content}"

        # Persist the conversation for inspection/debugging.
        if self.path2save is not None:
            system_message_json = [{"role": "system", "content": system_message}]
            new_msg_history = other_msg + [
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "text",
                            "text": content,
                        }
                    ],
                }
            ]
            with open(f"{self.path2save}/experiment_log.json", 'w') as json_file:
                json.dump(system_message_json + new_msg_history, json_file, indent=4)
            print(f"Data has been saved to [{self.path2save}/experiment_log.json]")
        return content

    def call_model_for_sketch_generation(self):
        """Ask the LLM for a full sketch and return its raw answer.

        The closing </answer> tag is consumed by the stop sequence, so it is
        re-appended before returning for the downstream XML parser.
        """
        print("Calling LLM...")

        add_args = {"stop": "</answer>"}

        msg_history = []
        init_canvas_str = None  # self.init_canvas_str

        all_llm_output = self.get_response_from_llm(
            msg=self.input_prompt,
            system_message=system_prompt.format(res=self.res),
            msg_history=msg_history,
            init_canvas_str=init_canvas_str,
            gen_mode=self.gen_mode,
            **add_args
        )

        all_llm_output += "</answer>"
        return all_llm_output

    def parse_model_to_svg(self, model_rep_sketch):
        """Convert the LLM's XML answer into an SVG string.

        Raises on malformed model output (parse/literal_eval failures), which
        the caller's retry loop is expected to handle.
        """
        # Parse model_rep with xml
        strokes_list_str, t_values_str = utils.parse_xml_string(model_rep_sketch, self.res)
        strokes_list, t_values = ast.literal_eval(strokes_list_str), ast.literal_eval(t_values_str)

        # extract control points from sampled lists
        all_control_points = utils.get_control_points(strokes_list, t_values, self.cells_to_pixels_map)

        # define SVG based on control point
        return utils.format_svg(all_control_points, dim=self.grid_size, stroke_width=self.stroke_width)

    def generate_sketch(self):
        """Run one full generation: query the LLM, convert to SVG, and save
        the SVG plus two PNG renderings (white background, and composited
        onto the grid canvas)."""
        sketching_commands = self.call_model_for_sketch_generation()
        model_strokes_svg = self.parse_model_to_svg(sketching_commands)

        # save the SVG sketch
        svg_path = f"{self.path2save}/{self.target_concept}.svg"
        with open(svg_path, "w") as svg_file:
            svg_file.write(model_strokes_svg)

        # vector->pixel: render the sketch to PNG with a blank background
        cairosvg.svg2png(url=svg_path, write_to=f"{self.path2save}/{self.target_concept}.png", background_color="white")

        # render again (transparent) and paste the sketch onto the grid canvas
        output_png_path = f"{self.path2save}/{self.target_concept}_canvas.png"
        cairosvg.svg2png(url=svg_path, write_to=output_png_path)
        # Open the rendering once; it serves as both paste source and alpha
        # mask (the original opened the same file twice).
        foreground = Image.open(output_png_path)
        self.init_canvas.paste(foreground, (0, 0), foreground)
        self.init_canvas.save(output_png_path)
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
# Initialize and run the SketchApp
if __name__ == '__main__':
    args = call_argparse()
    sketch_app = SketchApp(args)
    # The LLM occasionally returns unparsable output; retry a few times.
    for attempt in range(3):
        try:
            sketch_app.generate_sketch()
            break  # success: fall through and exit with status 0
        except Exception as e:
            print(f"An error has occurred: {e}")
            traceback.print_exc()
    else:
        # Every attempt failed. Previously the script fell through and
        # exited 0 here; signal failure to the caller instead.
        raise SystemExit(1)
|