sayakpaul HF Staff commited on
Commit
9ae5cf2
·
verified ·
1 Parent(s): a29caea

Create nano_banana.py

Browse files
Files changed (1) hide show
  1. nano_banana.py +83 -0
nano_banana.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List
3
+ from diffusers.modular_pipelines import (
4
+ PipelineState,
5
+ ModularPipelineBlocks,
6
+ InputParam,
7
+ OutputParam,
8
+ )
9
+ from PIL import Image
10
+ import google.generativeai as genai
11
+ import os
12
+
13
+
14
+ class NanoBanana(ModularPipelineBlocks):
15
+ def __init__(self, model_id="gemini-2.5-flash-image-preview"):
16
+ super().__init__()
17
+ api_key = os.getenv("GEMINI_API_KEY")
18
+ if api_key is None:
19
+ raise ValueError("Must provide an API key for Gemini through the `GEMINI_API_KEY` env variable.")
20
+ genai.configure(api_key=api_key)
21
+ self.model = genai.GenerativeModel(model_name=model_id)
22
+
23
+ @property
24
+ def expected_components(self):
25
+ return []
26
+
27
+ @property
28
+ def inputs(self) -> List[InputParam]:
29
+ return [
30
+ InputParam(
31
+ "image",
32
+ type_hint=Image.Image,
33
+ required=False,
34
+ description="Image to use"
35
+ ),
36
+ InputParam(
37
+ "prompt",
38
+ type_hint=str,
39
+ required=True,
40
+ description="Prompt to use",
41
+ )
42
+ ]
43
+
44
+ @property
45
+ def intermediate_inputs(self) -> List[InputParam]:
46
+ return []
47
+
48
+ @property
49
+ def intermediate_outputs(self) -> List[OutputParam]:
50
+ return [
51
+ OutputParam(
52
+ "output_image",
53
+ type_hint=str,
54
+ description="Output image",
55
+ ),
56
+ OutputParam(
57
+ "old_image",
58
+ type_hint=str,
59
+ description="Old image (if) provided by the user",
60
+ )
61
+ ]
62
+
63
+
64
+ def __call__(self, components, state: PipelineState) -> PipelineState:
65
+ block_state = self.get_block_state(state)
66
+
67
+ old_image = block_state.image
68
+ prompt = block_state.state.prompt
69
+ contents = [prompt]
70
+ if old_image is not None:
71
+ contents.expand(old_image)
72
+
73
+ output = self.model.generate_content(contents=contents)
74
+ block_state.output_image = output
75
+
76
+ if old_image is not None:
77
+ block_state.old_image = old_image
78
+ else:
79
+ block_state.old_image = None
80
+
81
+ self.set_block_state(state, block_state)
82
+
83
+ return components, state