AmeyaKawthalkar committed on
Commit c4fd269 · verified · 1 Parent(s): 075c1af

Update app.py

Files changed (1)
  1. app.py +79 -54
app.py CHANGED
@@ -1,7 +1,7 @@
 # app.py
 # FastAPI backend for a Hugging Face Space (CPU tier)
 # • Only MedGemma-4B-IT, no Parakeet, no tool-calling
-# • Reads HF_TOKEN from Space secrets, uses /data/.cache as writable cache
+# • Reads HF_TOKEN from Space secrets, uses /tmp for writable cache
 # • /chat endpoint expects {"messages":[{"role":"user","content": "..."}]}
 
 import os, pathlib, uuid
@@ -17,16 +17,16 @@ from transformers import pipeline
 # ------------------------------------------------------------
 # 1. Configure cache + authentication BEFORE loading models
 # ------------------------------------------------------------
-HOME_DIR = pathlib.Path.home()
-CACHE_DIR = HOME_DIR / ".cache" / "huggingface"
-CACHE_DIR.mkdir(parents=True, exist_ok=True)  # ← always writable
 
-os.environ["HF_HOME"] = str(CACHE_DIR)
+# Use /tmp for cache in HF Spaces (always writable)
+CACHE_DIR = pathlib.Path("/tmp/hf_cache")
+CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+os.environ["HF_HOME"] = str(CACHE_DIR)
 os.environ["TRANSFORMERS_CACHE"] = str(CACHE_DIR)
 
 HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")  # fine-grained read token
 
-
 # ------------------------------------------------------------
 # 2. Simple Pydantic request model
 # ------------------------------------------------------------
@@ -47,17 +47,21 @@ medgemma_pipe = None
 def get_medgemma():
     global medgemma_pipe
     if medgemma_pipe is None:
-        print("Loading MedGemma-4B-IT …")
-        medgemma_pipe = pipeline(
-            "text-generation",
-            model="google/medgemma-4b-it",
-            torch_dtype=DTYPE,
-            device=0 if torch.cuda.is_available() else -1,
-            token=HF_TOKEN,  # authenticate to gated repo :contentReference[oaicite:5]{index=5}
-            cache_dir=CACHE_DIR,
-            trust_remote_code=True,
-        )
-        print("✅ MedGemma loaded")
+        try:
+            print("Loading MedGemma-4B-IT …")
+            medgemma_pipe = pipeline(
+                "text-generation",
+                model="google/medgemma-4b-it",
+                torch_dtype=DTYPE,
+                device=0 if torch.cuda.is_available() else -1,
+                token=HF_TOKEN,  # authenticate to gated repo
+                cache_dir=CACHE_DIR,
+                trust_remote_code=True,
+            )
+            print("✅ MedGemma loaded successfully")
+        except Exception as e:
+            print(f"❌ Error loading MedGemma: {e}")
+            medgemma_pipe = None
     return medgemma_pipe
 
 # ------------------------------------------------------------
@@ -87,67 +91,88 @@ SYSTEM_PROMPT = (
 # ------------------------------------------------------------
 @app.post("/chat")
 async def chat(request: Request):
-    body = await request.json()
-    payload = ChatCompletionRequest(**body)
-    user_msg = payload.messages[-1].content or ""
-    prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"
+    try:
+        body = await request.json()
+        payload = ChatCompletionRequest(**body)
+        user_msg = payload.messages[-1].content or ""
+        prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"
+
+        pipe = get_medgemma()
+        if pipe is None:
+            return JSONResponse(
+                {
+                    "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
+                    "choices": [{
+                        "message": {
+                            "role": "assistant",
+                            "content": "MedGemma model is unavailable. "
+                                       "Check your gated-model access and HF_TOKEN.",
+                        }
+                    }],
+                },
+                status_code=503,
+            )
+
+        try:
+            result = pipe(
+                prompt,
+                max_new_tokens=256,
+                do_sample=True,
+                temperature=0.7,
+                pad_token_id=pipe.tokenizer.eos_token_id,
+                return_full_text=False,
+            )
+            assistant_text = result[0]["generated_text"].strip() if result else "No response."
+        except Exception as e:
+            print("Generation error:", e)
+            assistant_text = "Error generating response. Please retry later."
 
-    pipe = get_medgemma()
-    if pipe is None:
         return JSONResponse(
             {
                 "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                 "choices": [{
                     "message": {
                         "role": "assistant",
-                        "content": "MedGemma model is unavailable. "
-                                   "Check your gated-model access and HF_TOKEN.",
+                        "content": assistant_text,
                     }
-                }],
-            },
-            status_code=503,
+                }]
+            }
         )
-
-    try:
-        result = pipe(
-            prompt,
-            max_new_tokens=256,
-            do_sample=True,
-            temperature=0.7,
-            pad_token_id=pipe.tokenizer.eos_token_id,
-            return_full_text=False,
-        )
-        assistant_text = result[0]["generated_text"].strip() if result else "No response."
     except Exception as e:
-        print("Generation error:", e)
-        assistant_text = "Error generating response. Please retry later."
-
-    return JSONResponse(
-        {
-            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
-            "choices": [{
-                "message": {
-                    "role": "assistant",
-                    "content": assistant_text,
-                }
-            }]
-        }
-    )
+        print(f"Chat endpoint error: {e}")
+        return JSONResponse(
+            {
+                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
+                "choices": [{
+                    "message": {
+                        "role": "assistant",
+                        "content": "Server error. Please try again later.",
+                    }
+                }]
+            },
+            status_code=500
+        )
 
 # ------------------------------------------------------------
 # 7. Health endpoint
 # ------------------------------------------------------------
+@app.get("/")
+async def root():
+    return {"status": "healthy", "message": "MedGemma API is running"}
+
 @app.get("/health")
 async def health():
     return {
         "status": "ok",
         "model_loaded": medgemma_pipe is not None,
         "hf_token_present": bool(HF_TOKEN),
+        "cache_dir": str(CACHE_DIR),
     }
 
 # ------------------------------------------------------------
-# 8. For local dev (wont run inside Space runtime)
+# 8. For local dev (won't run inside Space runtime)
 # ------------------------------------------------------------
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
+
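
For reference, a minimal client-side sketch of how the updated endpoints could be exercised once the app is running locally (the commit itself only changes app.py). The base URL below is a placeholder assuming the local-dev uvicorn server on port 7860, and the example assumes the requests package is available; the payload shape and response fields follow the handlers shown in the diff above.

# client_example.py — hypothetical usage sketch, not part of the commit
import requests

BASE_URL = "http://localhost:7860"  # placeholder; a deployed Space would use its own URL

# /health (and the new "/" route) report whether MedGemma has been loaded
print(requests.get(f"{BASE_URL}/health", timeout=30).json())

# /chat expects {"messages": [{"role": "user", "content": "..."}]}
payload = {
    "messages": [
        {"role": "user", "content": "Draft a report for a frontal chest X-ray with a right lower lobe opacity."}
    ]
}
resp = requests.post(f"{BASE_URL}/chat", json=payload, timeout=300)

# The endpoint returns an OpenAI-style completion object (HTTP 503 if the
# model could not be loaded); the report text is in choices[0].message.content.
print(resp.status_code)
print(resp.json()["choices"][0]["message"]["content"])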