mubashirhussaindev committed on
Commit
3cd3c35
·
verified ·
1 Parent(s): 6911c3c

Create app.py

Files changed (1)
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoImageProcessor, AutoModelForImageClassification
+ from PIL import Image
+ import json
+ import re
+ import pandas as pd
+ from datetime import datetime
+ import plotly.express as px
+
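+ # Dependencies (inferred from the imports above; pin versions in requirements.txt as needed):
+ #   gradio, torch, transformers, pillow, pandas, plotly
+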
+ # Class labels; the classification heads below are sized to match these lists
+ text_labels = ["Positive", "Negative", "Neutral", "Informative"]  # for text analysis
+ image_labels = ["Normal", "Abnormal"]  # for X-ray or skin images
+
+ # Load text model
+ # NOTE: this checkpoint ships custom code (its model card loads it with
+ # trust_remote_code=True) and carries no sequence-classification head, so
+ # transformers attaches a randomly initialized head here -- predictions are
+ # placeholders until the model is fine-tuned on text_labels.
+ text_model_name = "microsoft/BiomedVLP-CXR-BERT-specialized"
+ text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True)
+ text_model = AutoModelForSequenceClassification.from_pretrained(text_model_name, num_labels=len(text_labels), trust_remote_code=True)
+
+ # Load image model
+ # NOTE: aehrc/cxrmate-tf is a chest X-ray report-generation model, so this load
+ # may fail; substitute a checkpoint that exposes an image-classification head
+ # (e.g. a ViT fine-tuned for two classes) for the Normal/Abnormal setup below.
+ image_model_name = "aehrc/cxrmate-tf"  # Replace with a skin-disease model if needed
+ image_processor = AutoImageProcessor.from_pretrained(image_model_name)
+ image_model = AutoModelForImageClassification.from_pretrained(image_model_name)
+
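+ # Optional sanity check: if a head's output size doesn't match its label list,
+ # the argmax indices used later could run past the list, so fail fast here.
+ assert text_model.config.num_labels == len(text_labels)
+ assert image_model.config.num_labels == len(image_labels)
+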
+ # Store conversation state (shared across calls)
+ conversation_state = {
+     "history": [],
+     "texts": [],
+     "image_uploaded": False,
+     "last_analysis": None,
+     "analysis_log": []
+ }
+
+ # Extract key medical terms via a keyword whitelist
+ def extract_key_terms(text):
+     terms = re.findall(r'\b(fever|cough|fatigue|headache|sore throat|chest pain|shortness of breath|rash|lesion|study|treatment|trial|astronaut|microgravity)\b', text, re.IGNORECASE)
+     return terms
+
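+ # Example: extract_key_terms("Persistent cough and fever after the trial")
+ # returns ['cough', 'fever', 'trial'] (matching is case-insensitive).
+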
+ # Generate context-aware follow-up questions
+ def generate_follow_up(terms, history):
+     if not terms:
+         return "Please provide medical text (e.g., symptoms, abstract) or upload an image."
+     lowered = [t.lower() for t in terms]
+     if "astronaut" in lowered or "microgravity" in lowered:
+         return "Are you researching space medicine? Please describe physiological data or symptoms in microgravity."
+     if len(terms) < 3:
+         return "Can you provide more details (e.g., duration of symptoms or study context)?"
+     if not conversation_state["image_uploaded"]:
+         return "Would you like to upload an image (e.g., X-ray or skin photo) for analysis?"
+     return "Would you like to analyze another text or image, or export the analysis log?"
+
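+ # Example: generate_follow_up(["fever"], []) asks for more detail, while
+ # generate_follow_up(["astronaut"], []) steers toward the space-medicine prompt.
+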
+ # Main analysis function
+ def analyze_medical_input(user_input, image=None, chat_history=None, export_format="None"):
+     global conversation_state
+     if not chat_history:
+         chat_history = []
+
+     # Process text input
+     text_response = ""
+     text_chart = ""
+     if user_input and user_input.strip():
+         terms = extract_key_terms(user_input)
+         conversation_state["texts"].extend(terms)
+         inputs = text_tokenizer(user_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
+         with torch.no_grad():
+             outputs = text_model(**inputs)
+         logits = outputs.logits
+         predicted_class_idx = logits.argmax(-1).item()
+         confidence = torch.softmax(logits, dim=-1)[0][predicted_class_idx].item()
+         scores = torch.softmax(logits, dim=-1)[0].tolist()
+         conversation_state["last_analysis"] = {
+             "type": "text",
+             "label": text_labels[predicted_class_idx],
+             "confidence": confidence,
+             "scores": scores,
+             "input": user_input,
+             "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         }
+         conversation_state["analysis_log"].append(conversation_state["last_analysis"])
+         text_response = f"Text Analysis: {text_labels[predicted_class_idx]} (Confidence: {confidence:.2%})"
+
+         # Text visualization (Chart.js)
+         chart_data = {
+             "type": "bar",
+             "data": {
+                 "labels": text_labels,
+                 "datasets": [{
+                     "label": "Confidence Scores",
+                     "data": scores,
+                     "backgroundColor": ["#4CAF50", "#F44336", "#2196F3", "#FF9800"],
+                     "borderColor": ["#388E3C", "#D32F2F", "#1976D2", "#F57C00"],
+                     "borderWidth": 1
+                 }]
+             },
+             "options": {
+                 "scales": {
+                     "y": {"beginAtZero": True, "max": 1, "title": {"display": True, "text": "Confidence"}},
+                     "x": {"title": {"display": True, "text": "Text Categories"}}
+                 },
+                 "plugins": {"title": {"display": True, "text": "Text Analysis Confidence"}}
+             }
+         }
+         text_chart = f"""
+         <canvas id='textChart' width='400' height='200'></canvas>
+         <script src='https://cdn.jsdelivr.net/npm/chart.js'></script>
+         <script>
+         new Chart(document.getElementById('textChart'), {json.dumps(chart_data)});
+         </script>
+         """
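+         # NOTE (caveat): gr.HTML injects markup via innerHTML, and browsers do not
+         # execute <script> tags inserted that way, so the Chart.js snippets here and
+         # in the image branch below may not render inside Gradio; gr.Plot or a
+         # pre-rendered image is the more reliable route.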
+
+     # Process image input
+     image_response = ""
+     image_chart = ""
+     if image is not None:
+         conversation_state["image_uploaded"] = True
+         inputs = image_processor(images=image, return_tensors="pt")
+         with torch.no_grad():
+             outputs = image_model(**inputs)
+         logits = outputs.logits
+         predicted_class_idx = logits.argmax(-1).item()
+         confidence = torch.softmax(logits, dim=-1)[0][predicted_class_idx].item()
+         scores = torch.softmax(logits, dim=-1)[0].tolist()
+         conversation_state["last_analysis"] = {
+             "type": "image",
+             "label": image_labels[predicted_class_idx],
+             "confidence": confidence,
+             "scores": scores,
+             "input": "image",
+             "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         }
+         conversation_state["analysis_log"].append(conversation_state["last_analysis"])
+         image_response = f"Image Analysis: {image_labels[predicted_class_idx]} (Confidence: {confidence:.2%})"
+
+         # Image visualization (Chart.js)
+         chart_data = {
+             "type": "bar",
+             "data": {
+                 "labels": image_labels,
+                 "datasets": [{
+                     "label": "Confidence Scores",
+                     "data": scores,
+                     "backgroundColor": ["#4CAF50", "#F44336"],
+                     "borderColor": ["#388E3C", "#D32F2F"],
+                     "borderWidth": 1
+                 }]
+             },
+             "options": {
+                 "scales": {
+                     "y": {"beginAtZero": True, "max": 1, "title": {"display": True, "text": "Confidence"}},
+                     "x": {"title": {"display": True, "text": "Image Categories"}}
+                 },
+                 "plugins": {"title": {"display": True, "text": "Image Analysis Confidence"}}
+             }
+         }
+         image_chart = f"""
+         <canvas id='imageChart' width='400' height='200'></canvas>
+         <script src='https://cdn.jsdelivr.net/npm/chart.js'></script>
+         <script>
+         new Chart(document.getElementById('imageChart'), {json.dumps(chart_data)});
+         </script>
+         """
+
+     # Generate trend visualization (Plotly) once more than one analysis is logged
+     trend_html = ""
+     if len(conversation_state["analysis_log"]) > 1:
+         df = pd.DataFrame(conversation_state["analysis_log"])
+         fig = px.line(
+             df, x="timestamp", y="confidence", color="type",
+             title="Analysis Confidence Over Time",
+             labels={"confidence": "Confidence Score", "timestamp": "Time"}
+         )
+         trend_html = fig.to_html(full_html=False)
+
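+     # NOTE (caveat): fig.to_html embeds its own <script> payload, so the same
+     # innerHTML limitation applies; passing the figure object to a gr.Plot
+     # component is the Gradio-native alternative.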
+     # Combine responses (the output component is gr.HTML, so use <br> for line breaks)
+     response = "<br>".join([r for r in [text_response, image_response] if r])
+     if not response:
+         response = "No analysis yet. Please provide text or upload an image."
+     response += f"<br><br>Follow-Up: {generate_follow_up(conversation_state['texts'], conversation_state['history'])}"
+     response += f"<br><br>{text_chart}<br>{image_chart}<br>{trend_html}"
+
+     # Add the disclaimer and record history before any return path
+     disclaimer = "⚠️ This tool is for research purposes only and does not provide medical diagnoses. Consult a healthcare professional for medical advice."
+     response += f"<br><br>{disclaimer}"
+     conversation_state["history"].append((user_input, response))
+
+     # Handle export: gr.File expects a filepath, so write the log to disk
+     if export_format != "None" and conversation_state["analysis_log"]:
+         df = pd.DataFrame(conversation_state["analysis_log"])
+         if export_format == "JSON":
+             export_path = "analysis_log.json"
+             df.to_json(export_path, orient="records")
+         else:  # CSV
+             export_path = "analysis_log.csv"
+             df.to_csv(export_path, index=False)
+         return response, export_path
+
+     # The click handler binds two outputs, so always return a pair
+     return response, None
+
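+ # Quick local smoke test (hypothetical input; exercises the pipeline without the UI):
+ #   response, file_path = analyze_medical_input("Patient reports fever and cough", None, [], "None")
+ #   print(response[:200], file_path)
+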
+ # Custom CSS for a professional UI
+ css = """
+ body { background-color: #f0f2f5; font-family: 'Segoe UI', Arial, sans-serif; }
+ .gradio-container { max-width: 900px; margin: auto; padding: 30px; background: white; border-radius: 10px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); }
+ h1 { color: #1a3c5e; text-align: center; font-size: 2em; }
+ input, textarea { border-radius: 8px; border: 1px solid #ccc; padding: 10px; }
+ button { background: linear-gradient(90deg, #3498db, #2980b9); color: white; border-radius: 8px; padding: 12px; font-weight: bold; }
+ button:hover { background: linear-gradient(90deg, #2980b9, #1a6ea6); }
+ #export_dropdown { width: 150px; margin-top: 10px; }
+ """
+
+ # Create the Gradio interface
+ with gr.Blocks(css=css) as iface:
+     gr.Markdown("# Ultra-Advanced Medical Research Chatbot")
+     gr.Markdown("Analyze medical texts or images for research purposes. Supports symptom analysis, literature review, and space-medicine research. Not for medical diagnosis.")
+     with gr.Row():
+         with gr.Column(scale=2):
+             text_input = gr.Textbox(lines=5, placeholder="Enter symptoms, a medical abstract, or space-medicine data...")
+             image_input = gr.Image(type="pil", label="Upload X-ray or Skin Image")
+             # elem_id lets the #export_dropdown rule in the CSS above apply
+             export_dropdown = gr.Dropdown(choices=["None", "JSON", "CSV"], label="Export Log", value="None", elem_id="export_dropdown")
+             submit_button = gr.Button("Analyze")
+         with gr.Column(scale=3):
+             output = gr.HTML()
+             file_output = gr.File(label="Exported Log")  # receives the export filepath
+     submit_button.click(
+         fn=analyze_medical_input,
+         inputs=[text_input, image_input, gr.State([]), export_dropdown],
+         outputs=[output, file_output]
+     )
+
+ # Launch the interface
+ iface.launch()