import modal
import torch
from diffusers import StableDiffusionUpscalePipeline
from smolagents import AgentImage, Tool

from .app import app
from .image import image

# Longest side (in pixels) an input image may have before it is handed to the
# x4 upscaler; larger inputs are shrunk first. Presumably chosen to bound GPU
# memory on the T4 (the output is 4x this size per side) — TODO confirm.
MAX_INPUT_SIDE = 512


def _cap_image_size(img, max_side=MAX_INPUT_SIDE):
    """Return a copy of *img* whose longest side is at most *max_side*.

    Scaling is uniform, so the aspect ratio is preserved. Images that already
    fit are returned as a same-size ``resize`` copy, unchanged in content.
    """
    width, height = img.size
    longest = max(width, height)
    if longest <= max_side:
        return img.resize((width, height))
    scale = max_side / longest
    # max(1, ...) guards very thin images from collapsing to a zero-sized side.
    return img.resize((max(1, round(width * scale)), max(1, round(height * scale))))


@app.cls(gpu="T4", image=image, scaledown_window=60 * 5)
class RemoteUpscalerModalApp:
    """Modal GPU worker wrapping the Stable Diffusion x4 upscaler.

    Runs on a T4 GPU; idle containers are kept for ``scaledown_window``
    (5 minutes) before scaling down. Concurrent calls to ``forward`` are
    grouped by ``@modal.batched`` into batches of up to 4.
    """

    @modal.enter()
    def setup(self):
        """Load the fp16 upscaler pipeline onto the GPU once per container."""
        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        self.pipeline = pipeline.to("cuda")

    @modal.batched(max_batch_size=4, wait_ms=1000)
    def forward(self, low_res_imgs, prompts: list[str]):
        """Upscale a batch of images, each guided by its text prompt.

        ``@modal.batched`` delivers the queued calls' arguments as parallel
        lists and expects a list of results in the same order.

        Args:
            low_res_imgs: PIL images to upscale.
            prompts: One text prompt per image.

        Returns:
            List of upscaled PIL images, one per input image.
        """
        # Cap the input size while keeping the aspect ratio; clamping width
        # and height independently would stretch non-square images.
        capped_imgs = [_cap_image_size(img) for img in low_res_imgs]
        # NOTE(review): the pipeline processes the batch as one call; if the
        # batched images differ in size, confirm how diffusers reconciles
        # their shapes before relying on per-image output dimensions.
        return self.pipeline(prompt=prompts, image=capped_imgs).images


class RemoteUpscalerTool(Tool):
    """smolagents tool that sends upscaling requests to the deployed Modal app."""

    name = "upscaler"
    description = """
    Perform upscaling on images.
    The "low_res_imgs" are PIL images.
    The "prompts" are strings.
    The output is a list of PIL images.
    You can upscale multiple images at once.
    """
    inputs = {
        "low_res_imgs": {
            "type": "array",
            "description": "The low resolution images to upscale",
        },
        "prompts": {
            "type": "array",
            "description": "The prompts to upscale the images",
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Look the class up by app/class name so calls go to the deployed app
        # rather than to the local class definition.
        tool_class = modal.Cls.from_name(app.name, RemoteUpscalerModalApp.__name__)
        self.tool = tool_class()

    def forward(self, low_res_imgs: list[AgentImage], prompts: list[str]):
        """Upscale each image using the prompt at the same position.

        Args:
            low_res_imgs: Images to upscale.
            prompts: Text prompts, one per image.

        Returns:
            List of upscaled PIL images in input order.

        Raises:
            ValueError: If the number of images and prompts differ.
        """
        low_res_imgs = list(low_res_imgs)
        prompts = list(prompts)
        # .map pairs its iterables positionally; a length mismatch would leave
        # some images without a prompt (or vice versa), so reject it up front.
        if len(low_res_imgs) != len(prompts):
            raise ValueError(
                f"Expected one prompt per image, got {len(low_res_imgs)} "
                f"image(s) and {len(prompts)} prompt(s)."
            )
        if not low_res_imgs:
            return []
        # Each pair is submitted as its own call; the remote @modal.batched
        # method groups concurrent calls into GPU batches. .map yields lazily,
        # so materialize the results into a list.
        return list(self.tool.forward.map(low_res_imgs, prompts))