baliddeki committed
Commit 79c27a2 · verified
1 Parent(s): 6613891

Upload 5 files

Files changed (5)
  1. Dockerfile +16 -0
  2. README.md +7 -6
  3. app.py +129 -0
  4. model.py +56 -0
  5. requirements.txt +16 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Use an official Python runtime as a base image
+ FROM python:3.9-slim
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy the current directory contents into the container at /app
+ COPY . /app
+
+ # Install the dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Expose port 7860
+ EXPOSE 7860
+ # CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Phronesis Ml Endpoint
- emoji: 🌍
+ title: Phronesis
+ emoji: 🌖
  colorFrom: green
- colorTo: red
- sdk: docker
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.4.0
+ app_file: app.py
  pinned: false
- license: mit
- short_description: ML endpoints
+ short_description: 'REPORT GEN AND CLASSIFICATION MODEL '
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,129 @@
+ # app.py
+ import os
+ import io
+ import uvicorn
+
+ import torch
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Request
+ from fastapi.responses import JSONResponse
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from torchvision import models, transforms
+ from PIL import Image
+ import numpy as np
+ from huggingface_hub import hf_hub_download
+ import pydicom
+ import gc
+ from model import CombinedModel, ImageToTextProjector
+ from fastapi.middleware.cors import CORSMiddleware
+
+
+ app = FastAPI()
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @app.get("/")
+ async def root(request: Request):
+     return {"message": "Welcome to Phronesis"}
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def dicom_to_png(dicom_data):
+     try:
+         # dcmread expects a path or file-like object, so wrap the raw bytes
+         dicom_file = pydicom.dcmread(io.BytesIO(dicom_data))
+         if not hasattr(dicom_file, 'PixelData'):
+             raise HTTPException(status_code=400, detail="No pixel data in DICOM file.")
+
+         pixel_array = dicom_file.pixel_array.astype(np.float32)
+         pixel_array = ((pixel_array - pixel_array.min()) / (pixel_array.ptp())) * 255.0
+         pixel_array = pixel_array.astype(np.uint8)
+
+         img = Image.fromarray(pixel_array).convert("L")
+         return img
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error converting DICOM to PNG: {e}")
+
+ # Set up secure model initialization
+ HF_TOKEN = os.getenv('HF_TOKEN')
+ if not HF_TOKEN:
+     raise ValueError("Missing Hugging Face token in environment variables.")
+
+ try:
+     report_generator_tokenizer = AutoTokenizer.from_pretrained(
+         "KYAGABA/combined-multimodal-model",
+         token=HF_TOKEN
+     )
+     video_model = models.video.r3d_18(weights="KINETICS400_V1")
+     video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
+     report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
+     projector = ImageToTextProjector(512, report_generator.config.d_model)
+     num_classes = 4
+     combined_model = CombinedModel(video_model, report_generator, num_classes, projector, report_generator_tokenizer)
+     model_file = hf_hub_download("KYAGABA/combined-multimodal-model", "pytorch_model.bin", token=HF_TOKEN)
+     state_dict = torch.load(model_file, map_location=device)
+     combined_model.load_state_dict(state_dict)
+     combined_model.to(device)
+     combined_model.eval()
+ except Exception as e:
+     raise SystemExit(f"Error loading models: {e}")
+
+ image_transform = transforms.Compose([
+     transforms.Resize((112, 112)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989])
+ ])
+
+ class_names = ["acute", "normal", "chronic", "lacunar"]
+
+ @app.post("/predict/")
+ async def predict(files: list[UploadFile]):
+     print(f"Received {len(files)} files")
+     n_frames = 16
+     images = []
+
+     for file in files:
+         ext = file.filename.split('.')[-1].lower()
+         try:
+             if ext in ['dcm', 'ima']:
+                 dicom_img = dicom_to_png(await file.read())
+                 images.append(dicom_img.convert("RGB"))
+             elif ext in ['png', 'jpeg', 'jpg']:
+                 img = Image.open(io.BytesIO(await file.read())).convert("RGB")
+                 images.append(img)
+             else:
+                 raise HTTPException(status_code=400, detail="Unsupported file type.")
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")
+
+     if not images:
+         return JSONResponse(content={"error": "No valid images provided."}, status_code=400)
+
+     # Evenly subsample to n_frames, or pad by repeating the last frame
+     if len(images) >= n_frames:
+         images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
+     else:
+         images_sampled = images + [images[-1]] * (n_frames - len(images))
+
+     # (T, C, H, W) -> (C, T, H, W), then add a batch dimension for r3d_18
+     image_tensors = [image_transform(img) for img in images_sampled]
+     images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         class_outputs, generated_report, _ = combined_model(images_tensor)
+         predicted_class = torch.argmax(class_outputs, dim=1).item()
+         predicted_class_name = class_names[predicted_class]
+
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return {
+         "predicted_class": predicted_class_name,
+         "generated_report": generated_report[0] if generated_report else "No report generated."
+     }
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
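The /predict/ endpoint expects one or more uploads under the repeated `files` form field and returns the predicted class plus the generated report. A minimal client sketch (hypothetical file names; assumes the requests package and a server reachable on localhost:7860):

import requests  # assumption: requests is available in the client environment

# Repeat the "files" field once per upload; scan1.png / scan2.png are placeholders
uploads = [
    ("files", ("scan1.png", open("scan1.png", "rb"), "image/png")),
    ("files", ("scan2.png", open("scan2.png", "rb"), "image/png")),
]
resp = requests.post("http://localhost:7860/predict/", files=uploads)
print(resp.json())  # {"predicted_class": "...", "generated_report": "..."}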
model.py ADDED
@@ -0,0 +1,56 @@
+ # model.py
+
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModelForSeq2SeqLM
+
+ class ImageToTextProjector(nn.Module):
+     def __init__(self, image_embedding_dim, text_embedding_dim):
+         super(ImageToTextProjector, self).__init__()
+         self.fc = nn.Linear(image_embedding_dim, text_embedding_dim)
+         self.activation = nn.ReLU()
+         self.dropout = nn.Dropout(p=0.5)
+
+     def forward(self, x):
+         x = self.fc(x)
+         x = self.activation(x)
+         x = self.dropout(x)
+         return x
+
+ class CombinedModel(nn.Module):
+     def __init__(self, video_model, report_generator, num_classes, projector, tokenizer):
+         super(CombinedModel, self).__init__()
+         self.video_model = video_model
+         self.report_generator = report_generator
+         self.classifier = nn.Linear(512, num_classes)
+         self.projector = projector
+         self.dropout = nn.Dropout(p=0.5)
+         self.tokenizer = tokenizer  # Kept so generated ids can be decoded in forward()
+
+     def forward(self, images, labels=None):
+         video_embeddings = self.video_model(images)
+         video_embeddings = self.dropout(video_embeddings)
+         class_outputs = self.classifier(video_embeddings)
+         projected_embeddings = self.projector(video_embeddings)
+         # Treat the projected clip embedding as a one-token encoder sequence
+         encoder_inputs = projected_embeddings.unsqueeze(1)
+
+         if labels is not None:
+             outputs = self.report_generator(
+                 inputs_embeds=encoder_inputs,
+                 labels=labels
+             )
+             gen_loss = outputs.loss
+             generated_report = None
+         else:
+             generated_report_ids = self.report_generator.generate(
+                 inputs_embeds=encoder_inputs,
+                 max_length=512,
+                 num_beams=4,
+                 early_stopping=True
+             )
+             generated_report = self.tokenizer.batch_decode(
+                 generated_report_ids, skip_special_tokens=True
+             )
+             gen_loss = None
+
+         return class_outputs, generated_report, gen_loss
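CombinedModel consumes a 5-D clip tensor shaped (batch, channels, frames, height, width), which is exactly what app.py builds with stack/permute/unsqueeze. A standalone smoke-test sketch with randomly initialized video weights and the base BioBART tokenizer (an assumption here; app.py loads its tokenizer from the fine-tuned repo), so outputs are meaningless but shapes and wiring can be checked:

import torch
from torchvision import models
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from model import CombinedModel, ImageToTextProjector

tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-v2-base")
report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
video_model = models.video.r3d_18(weights=None)  # untrained: shape check only
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
projector = ImageToTextProjector(512, report_generator.config.d_model)
model = CombinedModel(video_model, report_generator, 4, projector, tokenizer)
model.eval()

clip = torch.randn(1, 3, 16, 112, 112)  # (batch, channels, frames, H, W)
with torch.no_grad():
    class_outputs, report, _ = model(clip)
print(class_outputs.shape)  # torch.Size([1, 4])
print(report[0][:80])  # gibberish with untrained weights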
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ # Core dependencies
+ torch==2.0.1
+ torchvision==0.15.2
+ transformers==4.44.2
+ gradio==5.0
+ numpy==1.26.2
+ Pillow==10.0.1
+ fastapi
+ uvicorn  # serves the FastAPI app (imported by app.py and used in the Dockerfile CMD)
+ python-multipart  # required by FastAPI for multipart file uploads
+ # Additional dependencies
+ huggingface_hub==0.25.1  # Compatible with both transformers and gradio
+ torchmetrics==1.5.1
+ nltk==3.8.1
+ scikit-learn==1.3.0
+ tqdm==4.66.1
+ sentencepiece==0.1.99
+ pydicom==2.4.1