faxnoprinter committed on
Commit
f275c93
·
verified ·
1 Parent(s): e18776a

Update sketch/gen_sketch.py

Browse files
Files changed (1) hide show
  1. sketch/gen_sketch.py +225 -0
sketch/gen_sketch.py CHANGED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from openai import OpenAI
3
+ import ast
4
+ import cairosvg
5
+ import json
6
+ import os
7
+ import utils
8
+ import traceback
9
+
10
+ from dotenv import load_dotenv
11
+ from PIL import Image
12
+ from prompts import sketch_first_prompt, system_prompt, gt_example
13
+
14
+
15
def call_argparse(argv=None):
    """Parse command-line options and prepare the output directory.

    Besides the parsed options, the returned namespace carries three derived
    values: ``grid_size`` (``(res + 1) * cell_size``), ``save_name`` (the
    concept with spaces replaced by underscores) and ``path2save`` extended
    with ``save_name`` as a sub-directory.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace: The parsed and augmented arguments.
    """
    parser = argparse.ArgumentParser(description='Process Arguments')

    # General
    parser.add_argument('--concept_to_draw', type=str, default="cat")
    parser.add_argument('--seed_mode', type=str, default='deterministic', choices=['deterministic', 'stochastic'])
    parser.add_argument('--path2save', type=str, default="results/test")
    parser.add_argument('--model', type=str, default='gpt-4o')
    parser.add_argument('--gen_mode', type=str, default='generation', choices=['generation', 'completion'])

    # Grid params
    parser.add_argument('--res', type=int, default=50, help="the resolution of the grid is set to 50x50")
    parser.add_argument('--cell_size', type=int, default=12, help="size of each cell in the grid")
    parser.add_argument('--stroke_width', type=float, default=7.0)

    args = parser.parse_args(argv)

    # One extra cell is added to the resolution when sizing the canvas.
    args.grid_size = (args.res + 1) * args.cell_size

    args.save_name = args.concept_to_draw.replace(" ", "_")
    args.path2save = f"{args.path2save}/{args.save_name}"

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.path2save, exist_ok=True)

    # Seed an empty experiment log only when none exists yet: exclusive-create
    # mode never truncates a log left behind by a previous run.
    try:
        with open(f"{args.path2save}/experiment_log.json", 'x') as json_file:
            json.dump([], json_file, indent=4)
    except FileExistsError:
        pass
    return args
40
+
41
+
42
class SketchApp:
    """Generate a sketch of a concept by prompting an OpenAI chat model.

    The model's textual answer is parsed into strokes, turned into an SVG and
    rasterised to PNG files inside ``path2save``.
    """

    def __init__(self, args):
        """Set up grid geometry, rendering parameters and the OpenAI client.

        Args:
            args: Namespace such as the one returned by ``call_argparse``.
                Reads ``path2save``, ``concept_to_draw``, ``res``,
                ``cell_size``, ``grid_size``, ``stroke_width``, ``gen_mode``,
                ``seed_mode`` and, when present, ``model``.
        """
        # General
        self.path2save = args.path2save
        self.target_concept = args.concept_to_draw

        # Grid related
        self.res = args.res
        self.num_cells = args.res
        self.cell_size = args.cell_size
        self.grid_size = (args.grid_size, args.grid_size)
        self.init_canvas, self.positions = utils.create_grid_image(
            res=args.res, cell_size=args.cell_size, header_size=args.cell_size
        )
        self.init_canvas_str = utils.image_to_str(self.init_canvas)
        self.cells_to_pixels_map = utils.cells_to_pixels(
            args.res, args.cell_size, header_size=args.cell_size
        )

        # SVG related
        self.stroke_width = args.stroke_width

        # LLM setup (you need to provide your OPENAI_API_KEY in your .env file)
        self.max_tokens = 3000
        load_dotenv()
        openai_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(api_key=openai_key)
        # Honour the --model option; fall back to the previous hard-coded
        # value for namespaces that do not define it.
        self.model = getattr(args, "model", "gpt-4o")
        self.input_prompt = sketch_first_prompt.format(
            concept=args.concept_to_draw, gt_sketches_str=gt_example
        )
        self.gen_mode = args.gen_mode
        # NOTE: seed_mode is stored but nothing in this class uses it to
        # change sampling; call_llm defaults the temperature to 0.0.
        self.seed_mode = args.seed_mode

    def call_llm(self, system_message, other_msg, additional_args):
        """Send one chat-completion request and return the reply text.

        Args:
            system_message: Content of the system message.
            other_msg: List of every message except the system prompt.
            additional_args: Dict that may hold ``temperature`` (defaults to
                0.0) and ``stop`` (defaults to ``None``).

        Returns:
            The ``content`` of the first choice's message.
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "system", "content": system_message}] + other_msg,
            max_tokens=self.max_tokens,
            temperature=additional_args.get("temperature", 0.0),
            stop=additional_args.get("stop", None),
        )
        return response.choices[0].message.content

    def define_input_to_llm(self, msg_history, init_canvas_str, msg):
        """Append a new user turn to the history and return the new list.

        Args:
            msg_history: Earlier messages (without the system prompt). The
                list is not modified.
            init_canvas_str: Base64 string of an image to attach, or ``None``
                for a text-only turn.
            msg: Text of the user message.

        Returns:
            A new list: ``msg_history`` followed by the user message.
        """
        content = []
        if init_canvas_str is not None:
            # The Chat Completions API documents image_url as an object with
            # a "url" field rather than a bare string.
            content.append({
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64," + init_canvas_str},
            })
        content.append({"type": "text", "text": msg})

        return msg_history + [{"role": "user", "content": content}]

    def get_response_from_llm(
        self,
        msg,
        system_message,
        msg_history=None,
        init_canvas_str=None,
        prefill_msg=None,
        stop=None,
        gen_mode="generation"
    ):
        """Query the model, log the exchange to JSON and return the answer.

        Args:
            msg: Text of the user message.
            system_message: Content of the system message.
            msg_history: Earlier messages without the system prompt;
                ``None`` means an empty history.
            init_canvas_str: Optional base64 image attached to the user turn.
            prefill_msg: Text used to start the assistant's reply. Only used
                when ``gen_mode`` is ``"completion"``.
            stop: Stop sequence(s); ``["</answer>"]`` is used when falsy.
            gen_mode: ``"generation"`` or ``"completion"``.

        Returns:
            The model's reply, prefixed with ``prefill_msg`` when a prefill
            was sent.
        """
        if msg_history is None:
            msg_history = []

        # other_msg should contain all messages without the system prompt
        other_msg = self.define_input_to_llm(msg_history, init_canvas_str, msg)

        # A prefill turn is only sent in completion mode and only if one was
        # actually provided.
        use_prefill = gen_mode == "completion" and bool(prefill_msg)
        if use_prefill:
            other_msg = other_msg + [{"role": "assistant", "content": f"{prefill_msg}"}]

        # In case of stroke by stroke generation a custom stop may be given
        additional_args = {"stop": stop if stop else ["</answer>"]}

        content = self.call_llm(system_message, other_msg, additional_args)

        if use_prefill:
            # Drop the assistant prefill turn (never the user turn) and merge
            # its text into the logged and returned answer.
            other_msg = other_msg[:-1]
            content = f"{prefill_msg}{content}"

        # saves to json
        if self.path2save is not None:
            system_message_json = [{"role": "system", "content": system_message}]
            new_msg_history = other_msg + [
                {
                    "role": "assistant",
                    "content": [
                        {
                            "type": "text",
                            "text": content,
                        }
                    ],
                }
            ]
            with open(f"{self.path2save}/experiment_log.json", 'w') as json_file:
                json.dump(system_message_json + new_msg_history, json_file, indent=4)
            print(f"Data has been saved to [{self.path2save}/experiment_log.json]")
        return content

    def call_model_for_sketch_generation(self):
        """Ask the model for the full sketch and return its raw answer.

        Generation stops at ``</answer>``; the tag is appended again before
        returning so the answer carries its closing tag.
        """
        print("Calling LLM...")

        all_llm_output = self.get_response_from_llm(
            msg=self.input_prompt,
            system_message=system_prompt.format(res=self.res),
            msg_history=[],
            init_canvas_str=None,  # the blank canvas image is not sent
            gen_mode=self.gen_mode,
            stop="</answer>",
        )

        all_llm_output += "</answer>"
        return all_llm_output

    def parse_model_to_svg(self, model_rep_sketch):
        """Convert the model's textual sketch into SVG text.

        Args:
            model_rep_sketch: Raw model answer, passed to
                ``utils.parse_xml_string``.

        Returns:
            The value returned by ``utils.format_svg``.
        """
        # Parse the model representation; both results are strings holding
        # Python literals.
        strokes_list_str, t_values_str = utils.parse_xml_string(model_rep_sketch, self.res)
        strokes_list = ast.literal_eval(strokes_list_str)
        t_values = ast.literal_eval(t_values_str)

        # extract control points from sampled lists
        all_control_points = utils.get_control_points(strokes_list, t_values, self.cells_to_pixels_map)

        # define SVG based on control points
        return utils.format_svg(all_control_points, dim=self.grid_size, stroke_width=self.stroke_width)

    def generate_sketch(self):
        """Run the pipeline and write the SVG plus two PNG renderings.

        Files written under ``path2save``: ``<concept>.svg``,
        ``<concept>.png`` (white background) and ``<concept>_canvas.png``
        (sketch pasted over the grid canvas).
        """
        sketching_commands = self.call_model_for_sketch_generation()
        model_strokes_svg = self.parse_model_to_svg(sketching_commands)

        svg_path = f"{self.path2save}/{self.target_concept}.svg"
        png_path = f"{self.path2save}/{self.target_concept}.png"
        output_png_path = f"{self.path2save}/{self.target_concept}_canvas.png"

        # save the SVG sketch
        with open(svg_path, "w", encoding="utf-8") as svg_file:
            svg_file.write(model_strokes_svg)

        # vector->pixel
        # save the sketch to png with blank background
        cairosvg.svg2png(url=svg_path, write_to=png_path, background_color="white")

        # save the sketch to png on the canvas
        cairosvg.svg2png(url=svg_path, write_to=output_png_path)
        with Image.open(output_png_path) as foreground:
            # Paste onto a copy so the blank canvas stays reusable
            # (assumes init_canvas is a PIL image).
            canvas = self.init_canvas.copy()
            # The rendering doubles as its own paste mask.
            canvas.paste(foreground, (0, 0), foreground)
        canvas.save(output_png_path)
212
+
213
+
214
+
215
# Initialize and run the SketchApp, retrying a failed generation up to
# three times.
if __name__ == '__main__':
    import sys

    args = call_argparse()
    sketch_app = SketchApp(args)
    for attempts in range(3):
        try:
            sketch_app.generate_sketch()
        except Exception as e:
            # Top-level boundary: report the failure and try again.
            print(f"An error has occurred: {e}")
            traceback.print_exc()
        else:
            sys.exit(0)
    # Every attempt failed: signal it through a non-zero exit status.
    sys.exit(1)