|
import base64
import functools
import io
import os
import zipfile

import torch
from flask import Flask, request, jsonify, render_template, send_file

from medgan.dcgan import Generator_DCGAN, generate_examples_DCGAN
from medgan.progan import Generator_ProGAN, generate_examples_ProGAN, seed_everything
from medgan.stylegan import Generator_SG2, MappingNetwork, generate_examples_SG2
from medgan.vit import TumorDetectionApp
from medgan.wgan import Generator_WGAN, generate_examples_WGAN
|
|
|
|
|
|
# Flask application object; every route below registers on it.
app = Flask(__name__)

# Seed RNGs (helper from medgan.progan) before any generator is constructed
# or any noise is drawn below.
seed_everything()

# Shared generator hyper-parameters.
Z_DIM = 256  # latent length; /generate draws noise of shape (N, Z_DIM, 1, 1)
FEATURES_GEN = 64
CHANNELS_IMG = 3
progan_steps = 6  # NOTE(review): not referenced in this file — confirm it is still needed

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
|
|
# Checkpoint file locations, keyed first by generator family and then by
# tumour class. StyleGAN2 needs two files per class (the generator and its
# mapping network); every other family needs a single file.
#
# NOTE(review): the DCGAN "Pituitary"/"Meningioma" entries point at each
# other's checkpoint, and WGAN "Meningioma" reuses the Pituitary checkpoint.
# This may be a deliberate workaround for mislabelled or missing files, or a
# copy-paste slip — confirm against the training runs before changing it.
model_paths = {
    "DCGAN": {
        "Glioma": "models/DCGAN-Glioma.pth",
        "Pituitary": "models/DCGAN-Meningioma.pth",
        "Meningioma": "models/DCGAN-Pituitary.pth",
    },
    "ProGAN": {
        tumour: f"models/ProGAN-{tumour}.pth"
        for tumour in ("Glioma", "Meningioma", "Pituitary")
    },
    "StyleGAN2": {
        tumour: {
            "generator": f"models/StyleGAN2-{tumour}.pth",
            "mapping": f"models/StyleGAN2-{tumour}-MappingNet.pth",
        }
        for tumour in ("Glioma", "Meningioma", "Pituitary")
    },
    "WGANs": {
        "Glioma": "models/WGAN-Glioma.pth",
        "Meningioma": "models/WGAN-Pituitary.pth",
        "Pituitary": "models/WGAN-Pituitary.pth",
    },
}
|
|
|
|
|
|
|
|
# Load one DCGAN generator per tumour class from model_paths["DCGAN"], on the
# CPU and in eval mode, keyed by class name.
dcgan_generators = {}
for label, path in model_paths["DCGAN"].items():
    # Use the shared hyper-parameter constants rather than repeating the
    # literals 256/64/3. This keeps the architecture in step with the
    # (N, Z_DIM, 1, 1) noise that /generate feeds these models.
    # NOTE(review): assumes positions 2-4 are latent size, feature width and
    # image channels (the values these constants replace). The leading 1 is
    # passed through unchanged — confirm against medgan.dcgan.
    model = Generator_DCGAN(1, Z_DIM, FEATURES_GEN, CHANNELS_IMG).to(torch.device('cpu'))
    model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    model.eval()
    dcgan_generators[label] = model
|
|
|
|
|
|
def _load_progan_generator(checkpoint_path):
    """Build a CPU ProGAN generator, load *checkpoint_path* into it, and return it in eval mode."""
    cpu = torch.device('cpu')
    net = Generator_ProGAN(256, 256, 3).to(cpu)
    net.load_state_dict(torch.load(checkpoint_path, map_location=cpu))
    net.eval()
    return net


# One ProGAN generator per tumour class, keyed by class name.
progan_generators = {
    label: _load_progan_generator(path)
    for label, path in model_paths["ProGAN"].items()
}
|
|
|
|
|
|
# One StyleGAN2 generator and its mapping network per tumour class, keyed by
# class name in two parallel dicts.
stylegan2_generators = {}
stylegan2_mapping_networks = {}
for label, paths in model_paths["StyleGAN2"].items():
    # Construction order (generator first, then mapping network) is kept as-is
    # in case construction consumes RNG state seeded at import time.
    synthesis_net = Generator_SG2(log_resolution=8, W_DIM=256)
    # NOTE(review): the mapping network is moved to DEVICE ("cuda" when
    # available), while the generator stays wherever Generator_SG2 puts it.
    # Every other generator in this file is explicitly on the CPU. On a CUDA
    # host the two networks may sit on different devices — confirm how
    # generate_examples_SG2 places its tensors.
    mapping_net = MappingNetwork(256, 256).to(DEVICE)

    cpu = torch.device('cpu')
    for net, key in ((synthesis_net, "generator"), (mapping_net, "mapping")):
        net.load_state_dict(torch.load(paths[key], map_location=cpu))
    synthesis_net.eval()
    mapping_net.eval()

    stylegan2_generators[label] = synthesis_net
    stylegan2_mapping_networks[label] = mapping_net
|
|
|
|
|
|
# Load one WGAN generator per tumour class from model_paths["WGANs"], on the
# CPU and in eval mode. A class whose checkpoint is missing or unloadable is
# reported and left out of the dict, so /generate rejects it as an invalid
# class instead of crashing at startup.
wgan_generators = {}
for label, path in model_paths["WGANs"].items():
    model = Generator_WGAN().to(torch.device('cpu'))
    try:
        state_dict = torch.load(path, map_location=torch.device('cpu'))
        # strict=False tolerates key mismatches, but any parameter not found in
        # the checkpoint silently keeps its random initial value. Capture the
        # mismatch report so a wrong checkpoint does not go unnoticed.
        incompatible = model.load_state_dict(state_dict, strict=False)
    except FileNotFoundError:
        print(f"Checkpoint file not found for {label}: {path}")
        continue
    except RuntimeError as e:
        print(f"Error loading WGAN model for {label}: {e}")
        continue

    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: WGAN checkpoint for {label} ({path}) does not fully match the model; "
            f"missing keys: {incompatible.missing_keys}, "
            f"unexpected keys: {incompatible.unexpected_keys}"
        )

    model.eval()
    wgan_generators[label] = model
|
|
|
|
|
|
|
|
@app.route("/")
def home():
    """Render the landing page template ``index.html``."""
    template_name = "index.html"
    return render_template(template_name)
|
|
|
|
@app.route("/about_us")
def about_us():
    """Render the ``About_us.html`` template."""
    template_name = "About_us.html"
    return render_template(template_name)
|
|
|
|
@app.route("/generate_info")
def generate_info():
    """Render the ``generate.html`` template (the page in front of POST /generate)."""
    template_name = "generate.html"
    return render_template(template_name)
|
|
|
|
@app.route("/contact")
def contact():
    """Render the ``contact.html`` template."""
    template_name = "contact.html"
    return render_template(template_name)
|
|
|
|
@app.route("/detect_info")
def detect_info():
    """Render the ``detect.html`` template (the page in front of POST /detect)."""
    template_name = "detect.html"
    return render_template(template_name)
|
|
|
|
@app.route("/generate", methods=["POST"])
def generate():
    """Generate synthetic tumour images with the requested GAN and render them.

    Form fields:
        model: "DCGANs", "Progressive GANs", "StyleGAN2" or "WGANs".
        class_name: tumour class key of that family (e.g. "Glioma").
        num_images: number of images to produce (default 1, at most 100).

    Returns:
        The rendered ``results.html`` page. Returns a JSON error with status
        400 for an unknown model type, unknown class name or invalid
        ``num_images``.

    Side effect:
        Stores the generated images, base64-encoded, in
        ``app.config["images_base64"]``, which ``/download_zip`` reads to
        build its archive.
    """
    data = request.form
    model_type = data.get("model")
    class_name = data.get("class_name")

    # num_images comes straight from the client. Reject non-integers (which
    # previously surfaced as a 500) and cap the count so a single request
    # cannot allocate an arbitrarily large noise batch. Raise the cap if the
    # UI needs more.
    max_images = 100
    try:
        num_images = int(data.get("num_images", 1))
    except (TypeError, ValueError):
        return jsonify({"error": "num_images must be an integer"}), 400
    if not 1 <= num_images <= max_images:
        return jsonify({"error": f"num_images must be between 1 and {max_images}"}), 400

    if model_type == "DCGANs":
        generators = dcgan_generators
        generation_function = generate_examples_DCGAN
    elif model_type == "Progressive GANs":
        generators = progan_generators
        generation_function = generate_examples_ProGAN
    elif model_type == "StyleGAN2":
        generators = stylegan2_generators
        generation_function = generate_examples_SG2
    elif model_type == "WGANs":
        generators = wgan_generators
        generation_function = generate_examples_WGAN
    else:
        return jsonify({"error": "Invalid model type"}), 400

    if class_name not in generators:
        return jsonify({"error": f"Invalid class name for {model_type}"}), 400

    generator = generators[class_name]
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        if model_type == "StyleGAN2":
            # StyleGAN2 draws its own latents through its mapping network.
            mapping_net = stylegan2_mapping_networks[class_name]
            images_base64, image_buffers = generation_function(generator, mapping_net, num_images)
        else:
            # DCGAN, ProGAN and WGAN all take a CPU latent batch of shape
            # (num_images, Z_DIM, 1, 1).
            noise = torch.randn(num_images, Z_DIM, 1, 1).to(torch.device('cpu'))
            images_base64, image_buffers = generation_function(generator, noise, num_images)

    # Publish the images for /download_zip, which decodes
    # app.config["images_base64"] into generated_image_<n>.png entries.
    # Previously a ZIP was built here and discarded, and this key was never
    # set, so the download was always empty.
    # NOTE: app.config is process-wide, so concurrent users share (and
    # overwrite) this slot; per-user storage would be needed for multi-user use.
    app.config["images_base64"] = [
        base64.b64encode(buf.getvalue()).decode("ascii")
        for buf in image_buffers
        if buf
    ]

    return render_template("results.html", images=images_base64, zip_file=True)
|
|
|
|
@app.route("/download_zip", methods=["GET"])
def download_zip():
    """Route to download the ZIP file containing all generated images.

    Each base64 string stored under ``app.config["images_base64"]`` is decoded
    and written as ``generated_image_<n>.png`` (numbered from 1) into an
    in-memory, deflate-compressed archive. That archive is sent as the
    attachment ``generated_images.zip``. If the key is absent, the archive is
    empty, so /generate must populate it for this route to return anything.
    """
    stored_images = app.config.get("images_base64", [])
    archive = io.BytesIO()
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as bundle:
        for number, encoded in enumerate(stored_images, start=1):
            bundle.writestr(f"generated_image_{number}.png", base64.b64decode(encoded))
    # Rewind so send_file streams the archive from its first byte.
    archive.seek(0)
    return send_file(
        archive,
        mimetype="application/zip",
        as_attachment=True,
        download_name="generated_images.zip",
    )
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_detection_app():
    """Build the ViT tumour detector once per process and reuse it.

    Previously the checkpoint was reloaded on every /detect request. If
    construction raises, nothing is cached and the next request retries.
    NOTE(review): sharing one instance assumes TumorDetectionApp.predict_image
    keeps no per-request state — confirm in medgan.vit.
    """
    model_path = "models/vit-35-Epochs-92-NTP-model.pth"
    # Local name avoids shadowing the module-level DEVICE string.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    detection_app = TumorDetectionApp(model_path=model_path, device=device)
    print("Detection app initialized.")
    return detection_app


@app.route("/detect", methods=["POST"])
def detect():
    """Classify an uploaded image with the ViT tumour detector.

    Expects a multipart upload in the "file" field. The upload is saved to
    static/temp_image.jpg and classified. The label (Glioma, Meningioma,
    No Tumor, Pituitary, or Unknown for any other class index) is rendered
    with ``results-detect.html``.

    Returns a JSON error with status 400 if no file was uploaded, or 500 if
    prediction returned None or any exception was raised.
    """
    try:
        file = request.files.get("file")
        if not file:
            print("No file uploaded.")
            return jsonify({"error": "No file uploaded"}), 400

        # NOTE(review): every request writes to the same path, so concurrent
        # uploads can overwrite each other before prediction or rendering.
        # results-detect.html is given this fixed name.
        file_path = os.path.join("static", "temp_image.jpg")
        os.makedirs("static", exist_ok=True)
        file.save(file_path)
        print(f"File saved to: {file_path}")

        detection_app = _get_detection_app()

        predicted_class = detection_app.predict_image(file_path)
        if predicted_class is None:
            print("Prediction failed.")
            return jsonify({"error": "Prediction failed"}), 500

        # Class index -> label, as produced by the detector.
        class_mapping = {
            0: "Glioma",
            1: "Meningioma",
            2: "No Tumor",
            3: "Pituitary"
        }
        result = class_mapping.get(predicted_class, "Unknown")
        print(f"Prediction successful. Result: {result}")

        return render_template("results-detect.html", images=["temp_image.jpg"], result=result)

    except Exception as e:
        # Top-level request boundary: report any failure as a JSON 500.
        print(f"Error in /detect route: {e}")
        return jsonify({"error": str(e)}), 500
|
|
|
|
if __name__ == "__main__":
    # Run Flask's built-in development server. Debug mode enables the Werkzeug
    # interactive debugger, which allows arbitrary code execution from the
    # browser, so it must never be reachable by untrusted clients. Debug stays
    # on by default (as before); set FLASK_DEBUG to 0, false or no to disable it.
    debug = os.environ.get("FLASK_DEBUG", "1").strip().lower() not in {"0", "false", "no"}
    app.run(debug=debug)
|
|
|