Spaces:
Running
on
Zero
Running
on
Zero
import modal | |
import torch | |
from smolagents import AgentImage, Tool | |
from diffusers import StableDiffusionUpscalePipeline | |
from .app import app | |
from .image import image | |
class RemoteUpscalerModalApp:
    """Worker that runs the Stable Diffusion x4 upscaler on a CUDA device.

    NOTE(review): no Modal decorators are visible on this class here; presumably
    it is registered with ``app`` elsewhere, since it is looked up by class name
    through ``modal.Cls.from_name`` -- confirm.
    """

    # Hugging Face model id loaded by ``setup``.
    MODEL_ID = "stabilityai/stable-diffusion-x4-upscaler"
    # Longest input side (in pixels) handed to the pipeline; larger images are
    # shrunk to fit before upscaling.
    MAX_INPUT_SIDE = 512

    def setup(self):
        """Load the upscaling pipeline in float16 and move it to ``cuda``."""
        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            self.MODEL_ID, torch_dtype=torch.float16
        )
        self.pipeline = pipeline.to("cuda")

    @staticmethod
    def _fit_within(img, max_side):
        """Return a resized copy of *img* whose longest side is at most *max_side*.

        Both sides are scaled by the same factor so the aspect ratio is kept;
        clamping width and height independently would distort non-square images.
        """
        longest = max(img.width, img.height)
        if longest <= max_side:
            # Already small enough: keep the size (``resize`` still returns a copy).
            return img.resize((img.width, img.height))
        scale = max_side / longest
        new_size = (
            max(1, round(img.width * scale)),
            max(1, round(img.height * scale)),
        )
        return img.resize(new_size)

    def forward(self, low_res_imgs, prompts: list[str]):
        """Upscale a batch of images, passing the prompts to the pipeline.

        Args:
            low_res_imgs: Iterable of images exposing ``width``, ``height`` and
                ``resize`` (e.g. PIL images).
            prompts: Text prompts handed to the pipeline together with the images.

        Returns:
            The ``images`` attribute of the pipeline output, or an empty list
            when no images were given.
        """
        low_res_imgs = list(low_res_imgs)
        if not low_res_imgs:
            # Nothing to do; avoid invoking the pipeline with an empty batch.
            return []
        print(f"Upscaling {len(low_res_imgs)} image(s)")
        resized = [
            self._fit_within(img, self.MAX_INPUT_SIDE) for img in low_res_imgs
        ]
        return self.pipeline(prompt=prompts, image=resized).images
class RemoteUpscalerTool(Tool):
    """Agent tool that sends images to the deployed ``RemoteUpscalerModalApp``."""

    name = "upscaler"
    description = """
    Perform upscaling on images.
    The "low_res_imgs" are PIL images.
    The "prompts" are strings.
    The output is a list of PIL images.
    You can upscale multiple images at once.
    """
    inputs = {
        "low_res_imgs": {
            "type": "array",
            "description": "The low resolution images to upscale",
        },
        "prompts": {
            "type": "array",
            "description": "The prompts to upscale the images",
        },
    }
    output_type = "object"

    def __init__(self):
        """Look up the deployed Modal class by name and create a handle to it."""
        super().__init__()
        tool_class = modal.Cls.from_name(app.name, RemoteUpscalerModalApp.__name__)
        self.tool = tool_class()

    def forward(self, low_res_imgs: list[AgentImage], prompts: list[str]):
        """Upscale every image remotely and return the results as one flat list.

        Args:
            low_res_imgs: Images to upscale.
            prompts: One prompt per image, in the same order.

        Returns:
            A list with the upscaled images of all remote calls, in input order.

        Raises:
            ValueError: If the number of prompts differs from the number of images.
        """
        if len(low_res_imgs) != len(prompts):
            raise ValueError(
                f"Expected one prompt per image, got {len(low_res_imgs)} image(s) "
                f"and {len(prompts)} prompt(s)."
            )
        # ``map`` issues one remote call per element of its input iterables, while
        # the remote ``forward`` expects a list of images and a list of prompts.
        # Wrap each pair as a single-element batch so every call gets list-shaped
        # arguments instead of a bare image and a bare string.
        image_batches = [[img] for img in low_res_imgs]
        prompt_batches = [[prompt] for prompt in prompts]
        batch_results = self.tool.forward.map(image_batches, prompt_batches)
        # Each remote call returns a list; flatten them into a single list.
        return [img for batch in batch_results for img in batch]