|
|
|
from typing import List |
|
from diffusers.modular_pipelines import ( |
|
PipelineState, |
|
ModularPipelineBlocks, |
|
InputParam, |
|
OutputParam, |
|
) |
|
from PIL import Image |
|
from google import genai |
|
from io import BytesIO |
|
|
|
# Shared google-genai client, created once at import time and reused by every
# block call in this module. It is constructed with no explicit arguments.
# NOTE(review): presumably credentials are picked up from the environment
# (API key variable) -- confirm against the deployment setup.
client: genai.Client = genai.Client()
|
|
|
class NanoBanana(ModularPipelineBlocks):
    """Pipeline block that generates or edits an image via ``generate_content``.

    The block sends the user's ``prompt`` (plus the optional input ``image``)
    to the module-level google-genai ``client`` and stores the image found in
    the response as ``output_image``. The input image, if any, is passed
    through unchanged as ``old_image``.
    """

    def __init__(self, model_id="gemini-2.5-flash-image-preview"):
        """Create the block.

        Args:
            model_id: Model identifier forwarded as ``model=`` to
                ``client.models.generate_content``.
        """
        super().__init__()

        self.model_id = model_id

    @property
    def expected_components(self):
        """Return an empty list: this block declares no pipeline components."""
        return []

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the user inputs: an optional ``image`` and a required ``prompt``."""
        return [
            InputParam(
                "image",
                type_hint=Image.Image,
                required=False,
                description="Image to use",
            ),
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            ),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """Return an empty list: this block declares no intermediate inputs."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the values written to the state: ``output_image`` and ``old_image``."""
        return [
            OutputParam(
                "output_image",
                type_hint=Image.Image,
                description="Output image",
            ),
            OutputParam(
                "old_image",
                type_hint=Image.Image,
                description="Old image (if) provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Run the model on the prompt (and optional image) and store the result.

        Args:
            components: Pipeline components object; returned untouched.
            state: Pipeline state from which ``image`` and ``prompt`` are read
                and to which ``output_image`` and ``old_image`` are written.

        Returns:
            The ``(components, state)`` pair. The ``PipelineState`` return
            annotation is kept as in the original signature even though a
            tuple is returned.

        Raises:
            ValueError: If the response contains no image part (no candidates,
                no content parts, or text-only parts).
        """
        block_state = self.get_block_state(state)

        old_image = block_state.image
        contents = [block_state.prompt]
        if old_image is not None:
            contents.append(old_image)

        response = client.models.generate_content(
            model=self.model_id, contents=contents
        )

        # Guard each level of the response: candidates, the first candidate's
        # content, or its parts may be missing/empty, in which case there is
        # simply nothing to iterate over.
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []

        output_image = None
        returned_text = []
        for part in parts:
            if part.text is not None:
                # Text parts carry no image; keep them only for the error
                # message raised below when no image is produced.
                returned_text.append(part.text)
                continue
            if part.inline_data is not None:
                # If several image parts are present the last one wins.
                output_image = Image.open(BytesIO(part.inline_data.data))

        if output_image is None:
            # Fail with a descriptive error instead of leaving the declared
            # ``output_image`` output unset on the block state.
            message = f"Model '{self.model_id}' did not return an image."
            detail = " ".join(returned_text).strip()
            if detail:
                message += f" Model response text: {detail}"
            raise ValueError(message)

        block_state.output_image = output_image
        # Pass the input image through as-is (``None`` when none was given).
        block_state.old_image = old_image

        self.set_block_state(state, block_state)

        return components, state