Update inference_coz_single.py

inference_coz_single.py CHANGED (+59 -75)

@@ -25,66 +25,77 @@ def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
 # Helper: Generate a single VLM prompt for recursive_multiscale
 # -------------------------------------------------------------------
 def _generate_vlm_prompt(
-    vlm_model,
-    vlm_processor,
-    process_vision_info,
+    vlm_model: Qwen2_5_VLForConditionalGeneration,
+    vlm_processor: AutoProcessor,
+    process_vision_info,   # this is your helper that turns “messages” → image_inputs / video_inputs
+    prev_pil: Image.Image,     # <– pass PIL instead of path
+    zoomed_pil: Image.Image,   # <– pass PIL instead of path
     device: str = "cuda"
 ) -> str:
     """
-    Given two
-    Returns a string like “cat on sofa, pet, indoor, living room”, etc.
+    Given two PIL.Image inputs:
+      - prev_pil: the “full” image at the previous recursion.
+      - zoomed_pil: the cropped+resized (zoom) image for this step.
+    Returns a single “recursive_multiscale” prompt string.
     """
+
+    # (1) System message
     message_text = (
         "The second image is a zoom-in of the first image. "
         "Based on this knowledge, what is in the second image? "
         "Give me a set of words."
     )

     # (2) Build the two-image “chat” payload
+    #
+    # Instead of passing a filename, we pass the actual PIL.Image.
+    # The processor’s `process_vision_info` should know how to turn
+    # a message of the form {"type":"image","image": PIL_IMAGE} into tensors.
     messages = [
         {"role": "system", "content": message_text},
         {
             "role": "user",
             "content": [
-                {"type": "image", "image":
-                {"type": "image", "image":
+                {"type": "image", "image": prev_pil},
+                {"type": "image", "image": zoomed_pil},
             ],
         },
     ]

-    # (3)
+    # (3) Now run the “chat” through the VL processor
+    #
+    # - `apply_chat_template` will build the tokenized prompt (without running it yet).
+    # - `process_vision_info` should inspect the same `messages` list and return
+    #   `image_inputs` and `video_inputs` (tensors) for any attached PIL images.
     text = vlm_processor.apply_chat_template(
         messages,
+        tokenize=False,
+        add_generation_prompt=True
     )
     image_inputs, video_inputs = process_vision_info(messages)
+
     inputs = vlm_processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
         padding=True,
         return_tensors="pt",
     ).to(device)

-    # (4) Generate
+    # (4) Generate and decode
     generated = vlm_model.generate(**inputs, max_new_tokens=128)
-    # strip off the prompt tokens from each generated sequence:
     trimmed = [
-        out_ids[len(in_ids)
+        out_ids[len(in_ids):]
+        for in_ids, out_ids in zip(inputs.input_ids, generated)
     ]
     out_text = vlm_processor.batch_decode(
         trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]

-    # (5) Return exactly the bare words (no extra “,” if no additional user prompt)
     return out_text.strip()


+
 # -------------------------------------------------------------------
 # Main Function: recursive_multiscale_sr (with multiple centers)
 # -------------------------------------------------------------------
@@ -203,88 +214,61 @@ def recursive_multiscale_sr(
     ###############################
     # 6. Prepare the very first “full” image
     ###############################
-    # 6.1 Load + center crop → first_image
+    # (6.1) Load + center crop → first_image (512×512)
     img0 = Image.open(input_png_path).convert("RGB")
     img0 = resize_and_center_crop(img0, process_size)

-    img0.save(prev_path)
+    # Note: we no longer need to write “prev.png” to disk. Just keep it in memory.
+    prev_pil = img0.copy()

-    # We will maintain lists of PIL outputs and prompts:
     sr_pil_list: list[Image.Image] = []
-    prompt_list:
+    prompt_list: list[str] = []

-    ###############################
-    # 7. Recursion loop (now up to rec_num times)
-    ###############################
     for rec in range(rec_num):
-        # (A)
+        # (A) Compute low-res crop window on prev_pil
+        w, h = prev_pil.size          # (512×512)
+        new_w, new_h = w // upscale, h // upscale

-        # (1) Compute the “low-res” window size:
-        new_w, new_h = w // upscale, h // upscale  # e.g. 128×128 for upscale=4
-        # (2) Map normalized center → pixel center, then clamp so crop stays in bounds:
         cx_norm, cy_norm = centers[rec]
         cx = int(cx_norm * w)
         cy = int(cy_norm * h)
-        half_w = new_w // 2
-        top = cy - half_h
-        # clamp left ∈ [0, w - new_w], top ∈ [0, h - new_h]
-        left = max(0, min(left, w - new_w))
-        top = max(0, min(top, h - new_h))
-        right = left + new_w
-        bottom = top + new_h
+        half_w, half_h = new_w // 2, new_h // 2
+
+        left = max(0, min(cx - half_w, w - new_w))
+        top = max(0, min(cy - half_h, h - new_h))
+        right, bottom = left + new_w, top + new_h

         cropped = prev_pil.crop((left, top, right, bottom))

-        # (B)
-        zoom_path = os.path.join(td, f"step{rec+1}_zoom.png")
-        zoomed.save(zoom_path)
+        # (B) Upsample that crop back to (512×512)
+        zoomed_pil = cropped.resize((w, h), Image.BICUBIC)

-        # (C) Generate
+        # (C) Generate VLM prompt by passing PILs directly:
         prompt_tag = _generate_vlm_prompt(
             vlm_model=vlm_model,
             vlm_processor=vlm_processor,
             process_vision_info=process_vision_info,
+            prev_pil=prev_pil,      # <– PIL
+            zoomed_pil=zoomed_pil,  # <– PIL
             device=device,
         )
-        # (By default, no extra user prompt is appended.)

-        # (D) Prepare
+        # (D) Prepare “zoomed_pil” → tensor in [−1, 1]
         to_tensor = transforms.ToTensor()
-        lq = to_tensor(
+        lq = to_tensor(zoomed_pil).unsqueeze(0).to(device)   # (1,3,512,512)
         lq = (lq * 2.0) - 1.0

-        # (E)
+        # (E) Run SR inference
         with torch.no_grad():
             out_tensor = model_test(lq, prompt=prompt_tag)[0]
             out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
-        # back to PIL in [0,1]:
         out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

-        # (F)
-        out_pil.save(out_path)
-        prev_path = out_path
+        # (F) Bookkeeping: set prev_pil = out_pil for next iteration
+        prev_pil = out_pil

-        # (G) Append
+        # (G) Append to results
         sr_pil_list.append(out_pil)
         prompt_list.append(prompt_tag)

-    # end for(rec)
-
-    ###############################
-    # 8. Return the SR outputs & prompts
-    ###############################
-    # The list sr_pil_list = [ SR1, SR2, …, SR_rec_num ] in order.
     return sr_pil_list, prompt_list
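
The rewritten loop folds the old center-and-clamp arithmetic into one expression per axis. A small self-contained restatement of that window computation, with a quick check at the edge of the frame; the function name exists only for this example.

# Standalone restatement of the clamped crop-window arithmetic used in the loop.
def crop_window(w: int, h: int, cx_norm: float, cy_norm: float, upscale: int):
    new_w, new_h = w // upscale, h // upscale          # low-res window size
    cx, cy = int(cx_norm * w), int(cy_norm * h)        # normalized → pixel center
    half_w, half_h = new_w // 2, new_h // 2
    left = max(0, min(cx - half_w, w - new_w))         # clamp so the box stays inside
    top = max(0, min(cy - half_h, h - new_h))
    return left, top, left + new_w, top + new_h

# A center near the right edge is clamped so the crop never leaves the image:
print(crop_window(512, 512, 0.95, 0.5, upscale=4))     # (384, 192, 512, 320)
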
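
Steps (D) and (E) rely on a fixed value-range convention: `ToTensor` produces [0, 1], the SR model consumes [-1, 1], and the output is mapped back before `ToPILImage`. A minimal sketch of that round trip, with a dummy tensor standing in for the model output.

import torch
from PIL import Image
from torchvision import transforms

img = Image.new("RGB", (512, 512), color=(120, 60, 30))   # stand-in for zoomed_pil
lq = transforms.ToTensor()(img).unsqueeze(0)               # (1, 3, 512, 512) in [0, 1]
lq = lq * 2.0 - 1.0                                        # to [-1, 1] for the SR model

out_tensor = lq[0].clamp(-1.0, 1.0).cpu()                  # stand-in for model_test(...)[0]
out_pil = transforms.ToPILImage()(out_tensor * 0.5 + 0.5)  # back to [0, 1], then PIL
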