File size: 2,414 Bytes
9ae5cf2
 
 
 
 
 
 
 
 
01b885e
 
9ae5cf2
01b885e
9ae5cf2
 
 
 
01b885e
 
9ae5cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01b885e
9ae5cf2
 
 
 
01b885e
9ae5cf2
 
 
 
 
 
 
 
01b885e
9ae5cf2
 
01b885e
9ae5cf2
01b885e
 
 
 
 
 
 
 
 
9ae5cf2
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

from typing import List
from diffusers.modular_pipelines import (
    PipelineState,
    ModularPipelineBlocks,
    InputParam,
    OutputParam,
)
from PIL import Image
from google import genai
from io import BytesIO

# Shared google-genai client, created once when this module is imported.
# It lives at module level instead of on the NanoBanana block because storing
# the client on the block instance throws a pickling error.
# It is constructed with no arguments, so credentials presumably come from the
# environment (e.g. an API-key environment variable) — confirm against the
# google-genai SDK docs.
client = genai.Client()

class NanoBanana(ModularPipelineBlocks):
    """Pipeline block that generates or edits an image with a Gemini model.

    The ``prompt`` (plus the optional input ``image``) is sent to the model
    identified by ``model_id`` via the module-level genai ``client``. The image
    returned by the model is stored as ``output_image``, and the input image
    (or ``None``) is passed through as ``old_image``.
    """

    def __init__(self, model_id="gemini-2.5-flash-image-preview"):
        """Create the block.

        Args:
            model_id: Name of the Gemini model passed to ``generate_content``.
        """
        super().__init__()
        # The genai client is deliberately not created/stored here: doing so
        # throws a pickling error, so the module-level ``client`` is used.
        self.model_id = model_id

    @property
    def expected_components(self):
        """This block needs no pipeline components."""
        return []

    @property
    def inputs(self) -> List[InputParam]:
        """User inputs: a required ``prompt`` and an optional ``image``."""
        return [
            InputParam(
                "image",
                type_hint=Image.Image,
                required=False,
                description="Image to use"
            ),
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        """This block consumes no intermediate values from earlier blocks."""
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Values written to the pipeline state: generated and input image."""
        return [
            OutputParam(
                "output_image",
                type_hint=Image.Image,
                description="Output image",
            ),
            OutputParam(
                "old_image",
                type_hint=Image.Image,
                description="Old image (if) provided by the user",
            )
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Send the request to Gemini and store the resulting image in ``state``.

        Args:
            components: Pipeline components; not used by this block and
                returned unchanged.
            state: Pipeline state providing ``prompt`` and optional ``image``.

        Returns:
            The ``(components, state)`` pair, with ``output_image`` and
            ``old_image`` written into ``state``.

        Raises:
            ValueError: If the model response contains no image data.
        """
        block_state = self.get_block_state(state)

        old_image = block_state.image
        contents = [block_state.prompt]
        if old_image is not None:
            contents.append(old_image)

        response = client.models.generate_content(
            model=self.model_id, contents=contents
        )

        block_state.output_image = self._extract_image(response)
        # Pass the input image through unchanged (None when none was given).
        block_state.old_image = old_image

        self.set_block_state(state, block_state)

        return components, state

    @staticmethod
    def _extract_image(response) -> Image.Image:
        """Return the image carried in the first candidate of ``response``.

        If several inline-data parts are present, the last one wins. Text
        parts are collected only to make the error message informative.

        Raises:
            ValueError: If the response has no candidate/content/parts, or
                none of the parts carries inline image data.
        """
        # In the SDK these fields are optional; they can be missing, e.g.
        # when the request is blocked, so guard each level explicitly.
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []

        image = None
        texts = []
        for part in parts:
            if part.text is not None:
                texts.append(part.text)
            elif part.inline_data is not None and part.inline_data.data:
                image = Image.open(BytesIO(part.inline_data.data))

        if image is None:
            # Surface whatever the model said (often a refusal/explanation)
            # instead of failing later with a generic missing-output error.
            detail = " ".join(t.strip() for t in texts if t.strip())
            raise ValueError(
                "Gemini model returned no image"
                + (f"; model response: {detail}" if detail else ".")
            )
        return image