warshanks committed on
Commit b60fb62 · 1 Parent(s): cb74b60
Files changed (5)
  1. README.md +10 -6
  2. app.py +210 -44
  3. requirements.txt +251 -1
  4. style.css +11 -0
  5. uv.lock +0 -0
README.md CHANGED
@@ -1,12 +1,16 @@
 ---
-title: Medgemma 4b
-emoji: 💬
-colorFrom: yellow
-colorTo: purple
+title: MedGemma 4B IT
+models: [google/medgemma-4b-it]
+preload_from_hub: google/medgemma-4b-it
+emoji: 🩻
+colorFrom: blue
+colorTo: green
 sdk: gradio
-sdk_version: 5.0.1
+sdk_version: 5.21.0
 app_file: app.py
 pinned: false
+thumbnail: >-
+  https://cdn-uploads.huggingface.co/production/uploads/67340377534ff3213928481b/f2kd9Zs0G-chH0ZwfDSOT.png
 ---

-An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
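The updated front matter pins Gradio 5.21.0 and preloads google/medgemma-4b-it. For orientation, here is a minimal sketch of querying that checkpoint directly with transformers; it is not part of this commit and assumes a bfloat16-capable GPU, a transformers build with Gemma 3 support, and a placeholder image path (`chest_xray.png`):

```python
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/medgemma-4b-it"
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)

# One user turn with a single local image; "chest_xray.png" is a placeholder path.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "chest_xray.png"},
            {"type": "text", "text": "Describe any notable findings in this X-ray."},
        ],
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(device=model.device, dtype=torch.bfloat16)

with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=256)

# Drop the prompt tokens so only the generated answer is decoded.
answer = processor.decode(output_ids[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(answer)
```

The new app.py below does essentially the same thing, with streaming output, chat history, and video support layered on top.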
app.py CHANGED
@@ -1,64 +1,230 @@
 import gradio as gr
-from huggingface_hub import InferenceClient

-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]

-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})

-    messages.append({"role": "user", "content": message})

-    response = ""

-    for message in client.chat_completion(
         messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content

-        response += token
-        yield response


 """
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 demo = gr.ChatInterface(
-    respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
     ],
 )

-
 if __name__ == "__main__":
     demo.launch()

+#!/usr/bin/env python
+
+import os
+import re
+import tempfile
+from collections.abc import Iterator
+from threading import Thread
+
+import cv2
 import gradio as gr
+import spaces
+import torch
+from loguru import logger
+from PIL import Image
+from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
+
+model_id = os.getenv("MODEL_ID", "google/medgemma-4b-it")
+processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
+model = Gemma3ForConditionalGeneration.from_pretrained(
+    model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
+)
+
+MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
+
+
+def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
+    image_count = 0
+    video_count = 0
+    for path in paths:
+        if path.endswith(".mp4"):
+            video_count += 1
+        else:
+            image_count += 1
+    return image_count, video_count
+
+
+def count_files_in_history(history: list[dict]) -> tuple[int, int]:
+    image_count = 0
+    video_count = 0
+    for item in history:
+        if item["role"] != "user" or isinstance(item["content"], str):
+            continue
+        if item["content"][0].endswith(".mp4"):
+            video_count += 1
+        else:
+            image_count += 1
+    return image_count, video_count
+
+
+def validate_media_constraints(message: dict, history: list[dict]) -> bool:
+    new_image_count, new_video_count = count_files_in_new_message(message["files"])
+    history_image_count, history_video_count = count_files_in_history(history)
+    image_count = history_image_count + new_image_count
+    video_count = history_video_count + new_video_count
+    if video_count > 1:
+        gr.Warning("Only one video is supported.")
+        return False
+    if video_count == 1:
+        if image_count > 0:
+            gr.Warning("Mixing images and videos is not allowed.")
+            return False
+        if "<image>" in message["text"]:
+            gr.Warning("Using <image> tags with video files is not supported.")
+            return False
+    if video_count == 0 and image_count > MAX_NUM_IMAGES:
+        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
+        return False
+    if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
+        gr.Warning("The number of <image> tags in the text does not match the number of images.")
+        return False
+    return True
+
+
+def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
+    vidcap = cv2.VideoCapture(video_path)
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)
+    frames: list[tuple[Image.Image, float]] = []
+
+    for i in range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval):
+        if len(frames) >= MAX_NUM_IMAGES:
+            break

+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))

+    vidcap.release()
+    return frames


+def process_video(video_path: str) -> list[dict]:
+    content = []
+    frames = downsample_video(video_path)
+    for frame in frames:
+        pil_image, timestamp = frame
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
+            pil_image.save(temp_file.name)
+            content.append({"type": "text", "text": f"Frame {timestamp}:"})
+            content.append({"type": "image", "url": temp_file.name})
+    logger.debug(f"{content=}")
+    return content


+def process_interleaved_images(message: dict) -> list[dict]:
+    logger.debug(f"{message['files']=}")
+    parts = re.split(r"(<image>)", message["text"])
+    logger.debug(f"{parts=}")
+
+    content = []
+    image_index = 0
+    for part in parts:
+        logger.debug(f"{part=}")
+        if part == "<image>":
+            content.append({"type": "image", "url": message["files"][image_index]})
+            logger.debug(f"file: {message['files'][image_index]}")
+            image_index += 1
+        elif part.strip():
+            content.append({"type": "text", "text": part.strip()})
+        elif isinstance(part, str) and part != "<image>":
+            content.append({"type": "text", "text": part})
+    logger.debug(f"{content=}")
+    return content
+
+
+def process_new_user_message(message: dict) -> list[dict]:
+    if not message["files"]:
+        return [{"type": "text", "text": message["text"]}]
+
+    if message["files"][0].endswith(".mp4"):
+        return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
+
+    if "<image>" in message["text"]:
+        return process_interleaved_images(message)
+
+    return [
+        {"type": "text", "text": message["text"]},
+        *[{"type": "image", "url": path} for path in message["files"]],
+    ]
+
+
+def process_history(history: list[dict]) -> list[dict]:
+    messages = []
+    current_user_content: list[dict] = []
+    for item in history:
+        if item["role"] == "assistant":
+            if current_user_content:
+                messages.append({"role": "user", "content": current_user_content})
+                current_user_content = []
+            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
+        else:
+            content = item["content"]
+            if isinstance(content, str):
+                current_user_content.append({"type": "text", "text": content})
+            else:
+                current_user_content.append({"type": "image", "url": content[0]})
+    return messages
+
+
+@spaces.GPU(duration=120)
+def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 2048) -> Iterator[str]:
+    if not validate_media_constraints(message, history):
+        yield ""
+        return
+
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+    messages.extend(process_history(history))
+    messages.append({"role": "user", "content": process_new_user_message(message)})
+
+    inputs = processor.apply_chat_template(
         messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    ).to(device=model.device, dtype=torch.bfloat16)

+    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        inputs,
+        max_new_tokens=max_new_tokens,
+        streamer=streamer,
+        temperature=1.0,
+        top_p=0.95,
+        top_k=64,
+        min_p=0.0,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()

+    output = ""
+    for delta in streamer:
+        output += delta
+        yield output

+
+DESCRIPTION = """\
+This is a demo of MedGemma, a Gemma 3 variant trained for performance on medical text and image comprehension.
+You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input.
 """
+
 demo = gr.ChatInterface(
+    fn=run,
+    type="messages",
+    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
+    textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True),
+    multimodal=True,
     additional_inputs=[
+        gr.Textbox(label="System Prompt", value=""),
+        gr.Slider(label="Max New Tokens", minimum=100, maximum=8192, step=10, value=2048),
     ],
+    stop_btn=False,
+    title="MedGemma 4B IT",
+    description=DESCRIPTION,
+    run_examples_on_click=False,
+    cache_examples=False,
+    css_paths="style.css",
+    delete_cache=(1800, 1800),
 )

 if __name__ == "__main__":
     demo.launch()
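Since `run` is wired into `gr.ChatInterface(type="messages", multimodal=True)`, it receives a multimodal message dict plus a flat history list from Gradio. A hedged sketch of the shapes the helper functions above assume (all paths and texts are made-up placeholders, not values from this commit):

```python
# Illustrative shapes only; the real values are supplied by Gradio at runtime.

# New user turn: the MultimodalTextbox submits the typed text plus uploaded file paths.
message = {
    "text": "Compare these two scans. <image> <image>",
    "files": ["/tmp/gradio/scan_a.png", "/tmp/gradio/scan_b.png"],  # placeholder paths
}

# Prior turns, as count_files_in_history/process_history expect them: each uploaded
# file appears as its own user entry whose content indexes to a path, text turns are
# plain strings, and assistant replies are strings. process_history() folds the user
# entries back into one multi-part chat message.
history = [
    {"role": "user", "content": ("/tmp/gradio/scan_a.png",)},
    {"role": "user", "content": "What does this scan show?"},
    {"role": "assistant", "content": "It appears to show ..."},
]

# validate_media_constraints(message, history) then checks that the number of <image>
# tags matches the newly uploaded images and that at most MAX_NUM_IMAGES images
# (or a single .mp4, alone) are in play before run() builds the prompt.
```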
requirements.txt CHANGED
@@ -1 +1,251 @@
-huggingface_hub==0.25.2
+# This file was autogenerated by uv via the following command:
+#    uv pip compile pyproject.toml -o requirements.txt
+accelerate==1.4.0
+    # via gemma-3-12b-it (pyproject.toml)
+aiofiles==23.2.1
+    # via gradio
+annotated-types==0.7.0
+    # via pydantic
+anyio==4.8.0
+    # via
+    #   gradio
+    #   httpx
+    #   starlette
+certifi==2025.1.31
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
+charset-normalizer==3.4.1
+    # via requests
+click==8.1.8
+    # via
+    #   typer
+    #   uvicorn
+exceptiongroup==1.2.2
+    # via anyio
+fastapi==0.115.11
+    # via gradio
+ffmpy==0.5.0
+    # via gradio
+filelock==3.17.0
+    # via
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+fsspec==2025.3.0
+    # via
+    #   gradio-client
+    #   huggingface-hub
+    #   torch
+gradio==5.21.0
+    # via
+    #   gemma-3-12b-it (pyproject.toml)
+    #   spaces
+gradio-client==1.7.2
+    # via gradio
+groovy==0.1.2
+    # via gradio
+h11==0.14.0
+    # via
+    #   httpcore
+    #   uvicorn
+hf-transfer==0.1.9
+    # via gemma-3-12b-it (pyproject.toml)
+httpcore==1.0.7
+    # via httpx
+httpx==0.28.1
+    # via
+    #   gradio
+    #   gradio-client
+    #   safehttpx
+    #   spaces
+huggingface-hub==0.29.2
+    # via
+    #   accelerate
+    #   gradio
+    #   gradio-client
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   anyio
+    #   httpx
+    #   requests
+jinja2==3.1.6
+    # via
+    #   gradio
+    #   torch
+loguru==0.7.3
+    # via gemma-3-12b-it (pyproject.toml)
+markdown-it-py==3.0.0
+    # via rich
+markupsafe==2.1.5
+    # via
+    #   gradio
+    #   jinja2
+mdurl==0.1.2
+    # via markdown-it-py
+mpmath==1.3.0
+    # via sympy
+networkx==3.4.2
+    # via torch
+numpy==2.2.3
+    # via
+    #   accelerate
+    #   gradio
+    #   opencv-python-headless
+    #   pandas
+    #   transformers
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via torch
+nvidia-cufft-cu12==11.0.2.54
+    # via torch
+nvidia-curand-cu12==10.3.2.106
+    # via torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-nccl-cu12==2.20.5
+    # via torch
+nvidia-nvjitlink-cu12==12.8.93
+    # via
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via torch
+opencv-python-headless==4.11.0.86
+    # via gemma-3-12b-it (pyproject.toml)
+orjson==3.10.15
+    # via gradio
+packaging==24.2
+    # via
+    #   accelerate
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   spaces
+    #   transformers
+pandas==2.2.3
+    # via gradio
+pillow==11.1.0
+    # via gradio
+protobuf==6.30.0
+    # via gemma-3-12b-it (pyproject.toml)
+psutil==5.9.8
+    # via
+    #   accelerate
+    #   spaces
+pydantic==2.10.6
+    # via
+    #   fastapi
+    #   gradio
+    #   spaces
+pydantic-core==2.27.2
+    # via pydantic
+pydub==0.25.1
+    # via gradio
+pygments==2.19.1
+    # via rich
+python-dateutil==2.9.0.post0
+    # via pandas
+python-multipart==0.0.20
+    # via gradio
+pytz==2025.1
+    # via pandas
+pyyaml==6.0.2
+    # via
+    #   accelerate
+    #   gradio
+    #   huggingface-hub
+    #   transformers
+regex==2024.11.6
+    # via transformers
+requests==2.32.3
+    # via
+    #   huggingface-hub
+    #   spaces
+    #   transformers
+rich==13.9.4
+    # via typer
+ruff==0.9.10
+    # via gradio
+safehttpx==0.1.6
+    # via gradio
+safetensors==0.5.3
+    # via
+    #   accelerate
+    #   transformers
+semantic-version==2.10.0
+    # via gradio
+sentencepiece==0.2.0
+    # via gemma-3-12b-it (pyproject.toml)
+shellingham==1.5.4
+    # via typer
+six==1.17.0
+    # via python-dateutil
+sniffio==1.3.1
+    # via anyio
+spaces==0.32.0
+    # via gemma-3-12b-it (pyproject.toml)
+starlette==0.46.1
+    # via
+    #   fastapi
+    #   gradio
+sympy==1.13.3
+    # via torch
+tokenizers==0.21.0
+    # via transformers
+tomlkit==0.13.2
+    # via gradio
+torch==2.4.0
+    # via
+    #   gemma-3-12b-it (pyproject.toml)
+    #   accelerate
+tqdm==4.67.1
+    # via
+    #   huggingface-hub
+    #   transformers
+transformers @ git+https://github.com/huggingface/transformers@2829013d2d00e63d75a1f6f7a3f003bc60cc69af
+    # via gemma-3-12b-it (pyproject.toml)
+triton==3.0.0
+    # via torch
+typer==0.15.2
+    # via gradio
+typing-extensions==4.12.2
+    # via
+    #   anyio
+    #   fastapi
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   pydantic
+    #   pydantic-core
+    #   rich
+    #   spaces
+    #   torch
+    #   typer
+    #   uvicorn
+tzdata==2025.1
+    # via pandas
+urllib3==2.3.0
+    # via requests
+uvicorn==0.34.0
+    # via gradio
+websockets==15.0.1
+    # via gradio-client
style.css ADDED
@@ -0,0 +1,11 @@
+h1 {
+  text-align: center;
+  display: block;
+}
+
+#logo {
+  display: block;
+  margin: 0 auto;
+  width: 40%;
+  object-fit: contain;
+}
uv.lock ADDED
The diff for this file is too large to render. See raw diff