# nano-banana-modular / nano_banana.py
# Author: sayakpaul (HF Staff)
# Last commit: 01b885e "Update nano_banana.py"
from typing import List
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
OutputParam,
)
from PIL import Image
from google import genai
from io import BytesIO
# Shared google-genai client, created once when this module is imported.
# It is kept at module level rather than stored on the NanoBanana block,
# because keeping it on the block instance throws a pickling error.
# NOTE(review): genai.Client() presumably picks up API credentials from the
# environment, so importing this module requires them to be set — confirm.
client = genai.Client()
class NanoBanana(ModularPipelineBlocks):
    """Modular pipeline block that generates or edits an image with a Gemini model.

    The block sends ``prompt`` (followed by ``image``, when one is given) to
    ``client.models.generate_content`` and decodes the inline image bytes of the
    response into a ``PIL.Image.Image`` stored as ``output_image``. The input
    image is passed through unchanged as ``old_image``.
    """

    def __init__(self, model_id="gemini-2.5-flash-image-preview"):
        """Store the model id used for generation.

        Args:
            model_id: Model name passed as ``model=`` to ``generate_content``.
        """
        super().__init__()
        # The genai client is deliberately not created/stored here: keeping it
        # on the instance throws a pickling error, so the module-level
        # ``client`` is used in ``__call__`` instead.
        self.model_id = model_id

    @property
    def expected_components(self):
        # This block declares no pipeline components.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Image.Image,
                required=False,
                description="Image to use"
            ),
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "output_image",
                type_hint=Image.Image,
                description="Output image",
            ),
            OutputParam(
                "old_image",
                type_hint=Image.Image,
                description="Old image (if) provided by the user",
            )
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Generate an image from the prompt (and optional input image).

        Reads ``prompt`` and ``image`` from the block state, calls the model,
        and writes ``output_image`` and ``old_image`` back into ``state``.

        Returns:
            The ``(components, state)`` pair (note: a tuple, despite the
            annotation).

        Raises:
            ValueError: If the model response contains no inline image data.
        """
        block_state = self.get_block_state(state)
        old_image = block_state.image

        # The prompt always comes first; the user's image (if any) follows it.
        contents = [block_state.prompt]
        if old_image is not None:
            contents.append(old_image)

        response = client.models.generate_content(
            model=self.model_id, contents=contents
        )
        block_state.output_image = self._decode_output_image(response)
        # Pass the input image through unchanged (None when none was given).
        block_state.old_image = old_image

        self.set_block_state(state, block_state)
        return components, state

    @staticmethod
    def _decode_output_image(response) -> Image.Image:
        """Return the image decoded from the first candidate's inline-data parts.

        If several parts carry inline data, the last one is used. Text parts are
        collected only so they can be reported when no image is present.

        Raises:
            ValueError: If there are no candidates, no content/parts, or no part
                with non-empty inline data (e.g. a text-only answer).
        """
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts or []) if content is not None else []

        image = None
        texts = []
        for part in parts:
            if part.text is not None:
                texts.append(part.text)
            elif part.inline_data is not None and part.inline_data.data:
                image = Image.open(BytesIO(part.inline_data.data))

        if image is None:
            detail = " ".join(texts).strip() or "<no text returned>"
            raise ValueError(
                "The model response contained no image data. "
                f"Text response: {detail}"
            )
        return image