Spaces:
Runtime error
Runtime error
File size: 2,153 Bytes
7e327f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import modal
import torch
from smolagents import AgentImage, Tool
from diffusers import StableDiffusionUpscalePipeline
from .app import app
from .image import image
# Longest side, in pixels, an input image may have before it is shrunk
# (presumably to bound the 4x output size and GPU memory use).
_MAX_INPUT_SIDE = 512


def _fit_for_upscaler(img, max_side=_MAX_INPUT_SIDE):
    """Return an RGB copy of *img* scaled down so neither side exceeds *max_side*.

    Scaling is uniform, so the aspect ratio is preserved. Images already
    within the limit keep their original size.
    """
    # Normalise the mode so RGBA / grayscale inputs don't reach the pipeline
    # with an unexpected channel count.
    img = img.convert("RGB")
    scale = min(1.0, max_side / img.width, max_side / img.height)
    if scale < 1.0:
        img = img.resize(
            (max(1, round(img.width * scale)), max(1, round(img.height * scale)))
        )
    return img


@app.cls(gpu="T4", image=image, scaledown_window=60 * 5)
class RemoteUpscalerModalApp:
    """Modal GPU worker running the Stable Diffusion x4 upscaler.

    The pipeline is loaded once per container in ``setup`` and reused by
    every batched ``forward`` call. Idle containers scale down after
    ``60 * 5`` seconds.
    """

    @modal.enter()
    def setup(self):
        """Load the upscaler pipeline in float16 and move it to the GPU."""
        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        self.pipeline = self.pipeline.to("cuda")

    @modal.batched(max_batch_size=4, wait_ms=1000)
    def forward(self, low_res_imgs, prompts: list[str]):
        """Upscale a batch of images, each guided by its text prompt.

        ``@modal.batched`` gathers up to 4 individual calls (waiting at most
        1000 ms) and passes them here as parallel lists. The returned list
        has one output per input, in the same order.

        Args:
            low_res_imgs: Input images (PIL images).
            prompts: One text prompt per image.

        Returns:
            The upscaled PIL images, aligned with ``low_res_imgs``.
        """
        prepared = [_fit_for_upscaler(img) for img in low_res_imgs]

        # NOTE: diffusers' preprocessing sizes a list of PIL images after the
        # first one (verify for the pinned diffusers version), so mixed sizes
        # in one call would be warped. Run one pipeline call per distinct
        # size and stitch the results back into input order.
        groups: dict[tuple[int, int], list[int]] = {}
        for idx, img in enumerate(prepared):
            groups.setdefault(img.size, []).append(idx)

        upscaled = [None] * len(prepared)
        for indices in groups.values():
            outputs = self.pipeline(
                prompt=[prompts[i] for i in indices],
                image=[prepared[i] for i in indices],
            ).images
            for i, out in zip(indices, outputs):
                upscaled[i] = out
        return upscaled
class RemoteUpscalerTool(Tool):
    """smolagents tool that forwards upscaling requests to the Modal worker.

    Each (image, prompt) pair becomes an individual remote call to
    ``RemoteUpscalerModalApp.forward``. The ``@modal.batched`` decorator on
    that method groups concurrent calls into GPU batches.
    """

    name = "upscaler"
    description = """
Perform upscaling on images.
The "low_res_imgs" are PIL images.
The "prompts" are strings, one per image (a single string is reused for every image).
The output is a list of PIL images, in the same order as "low_res_imgs".
You can upscale multiple images at once.
"""
    inputs = {
        "low_res_imgs": {
            "type": "array",
            "description": "The low resolution images to upscale",
        },
        "prompts": {
            "type": "array",
            "description": "Text prompts guiding the upscaling, one per image",
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Look up the deployed Modal class by app name + class name and
        # create a handle to a remote instance.
        tool_class = modal.Cls.from_name(app.name, RemoteUpscalerModalApp.__name__)
        self.tool = tool_class()

    def forward(self, low_res_imgs: list[AgentImage], prompts: list[str]):
        """Upscale ``low_res_imgs`` remotely and return the results in order.

        Args:
            low_res_imgs: Images to upscale.
            prompts: One prompt per image; a bare string is reused for all images.

        Returns:
            A list of upscaled PIL images aligned with ``low_res_imgs``.

        Raises:
            ValueError: If the number of prompts differs from the number of images.
        """
        images = list(low_res_imgs)
        if isinstance(prompts, str):
            # Otherwise map() would iterate the string character by character.
            prompts = [prompts] * len(images)
        else:
            prompts = list(prompts)
        # map() zips its iterables, so a length mismatch would silently drop
        # work; fail loudly instead.
        if len(prompts) != len(images):
            raise ValueError(
                f"Expected one prompt per image, got {len(prompts)} prompts "
                f"for {len(images)} images."
            )
        if not images:
            return []
        # map() issues one remote call per pair and yields results in input
        # order; batching happens on the remote side.
        return list(self.tool.forward.map(images, prompts))
|