File size: 13,839 Bytes
20dcaab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
import os
import gc
import threading
import time
import torch
import numpy as np
from PIL import Image
from datetime import datetime
import gradio as gr

# Import utility modules
from utils.text_model import generate_text_with_transformers, generate_text_with_llamacpp, check_llamacpp_available
from utils.vision_model import process_medical_image, load_vision_model

# Hugging Face model identifiers used by the text and image features
TEXT_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # small instruct model, chosen for speed
VISION_MODEL = "flaviagiammarino/medsam-vit-base"

# Shared flag set by the Cancel button and handed to the text-generation helpers
cancel_event = threading.Event()

# Module-level caches: filled on first use so later requests reuse the loaded objects
vision_model_cache = dict(model=None, processor=None)
transformers_model_cache = dict(model=None, tokenizer=None)

# Probe once at import time whether the llama.cpp backend can be offered
llamacpp_available = check_llamacpp_available()

# Function for text generation
def generate_pharmaceutical_response(query, use_llamacpp=False, progress=gr.Progress()):
    """Answer a free-text pharmaceutical question with the selected text backend.

    Args:
        query: The user's question. ``None`` and blank strings are rejected
            with a prompt to enter a question.
        use_llamacpp: When true *and* llama.cpp was detected at import time,
            generate with ``generate_text_with_llamacpp``; otherwise use the
            transformers model named by ``TEXT_MODEL``.
        progress: Gradio progress tracker used to report loading/generation
            status to the UI.

    Returns:
        The generated answer, or a human-readable message when the input is
        empty or the backend raised. Exceptions are reported as text rather
        than propagated so the message lands in the response pane.
    """
    # Start every request with a fresh cancellation flag.
    cancel_event.clear()

    # `query` may be None or whitespace only; either way there is nothing to ask.
    if not query or not query.strip():
        return "Please enter a question about medications or medical conditions."

    # Choose generation method based on user preference and availability
    if use_llamacpp and llamacpp_available:
        progress(0.1, desc="Loading llama.cpp model...")
        try:
            result = generate_text_with_llamacpp(
                query=query,
                max_tokens=512,
                temperature=0.7,
                top_p=0.9,
                cancel_event=cancel_event,
                # Map the helper's 0..1 progress onto the 10%..90% band of the bar.
                progress_callback=lambda p, d: progress(0.1 + 0.8 * p, desc=d),
            )
            progress(1.0, desc="Done!")
            return result
        except Exception as e:
            return f"Error generating response with llama.cpp: {e}"

    # Fallback to transformers if llamacpp is not available or not chosen
    progress(0.1, desc="Loading transformers model...")
    try:
        # Reuse the cached model/tokenizer; load and cache them on first use.
        model = transformers_model_cache["model"]
        tokenizer = transformers_model_cache["tokenizer"]
        if model is None or tokenizer is None:
            from utils.text_model import load_text_model
            model, tokenizer = load_text_model(TEXT_MODEL, quantize=torch.cuda.is_available())
            transformers_model_cache["model"] = model
            transformers_model_cache["tokenizer"] = tokenizer

        # Delegate to the shared helper, called exactly like the
        # clinical-interpretation path does, so cancellation and progress
        # reporting go through one code path instead of a hand-rolled
        # `model.generate` call.
        result = generate_text_with_transformers(
            model=model,
            tokenizer=tokenizer,
            query=query,
            max_tokens=512,
            temperature=0.7,
            cancel_event=cancel_event,
            progress_callback=lambda p, d: progress(0.1 + 0.8 * p, desc=d),
        )
        progress(1.0, desc="Done!")
        return result
    except Exception as e:
        return f"Error generating response with transformers: {e}"

# Function for image analysis
def _analysis_markdown_to_html(analysis_text):
    """Convert the small Markdown subset used in analysis text to HTML."""
    import re

    # Heading lines ("# ...", "## ...") become a properly closed <h3> element.
    # Only spaces/tabs are trimmed so the newlines used below stay intact.
    html = re.sub(r"(?m)^[ \t]*#{1,6}[ \t]*(.+?)[ \t]*$", r"<h3>\1</h3>", analysis_text)
    # Paired ** markers become <b>...</b>; the clinical-interpretation step
    # searches the resulting HTML for "<b>Label</b>:" sequences.
    html = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", html)
    # Paragraph breaks and list items. The bullet marker text is kept exactly
    # as-is because the interpretation step looks for this same sequence.
    return html.replace("\n\n", "<br><br>").replace("\n- ", "<br>β€’ ")


def analyze_medical_image_gradio(image, progress=gr.Progress()):
    """Segment/analyze an uploaded image and return the result for display.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker used to report status to the UI.

    Returns:
        A ``(result_image, html)`` pair. ``result_image`` is ``None`` when no
        image was supplied or when loading/processing failed; in those cases
        the second item is a plain message describing the problem.
    """
    if image is None:
        return None, "Please upload a medical image to analyze."

    progress(0.1, desc="Loading vision model...")
    try:
        # Reuse the cached model/processor; load and cache them on first use.
        model = vision_model_cache["model"]
        processor = vision_model_cache["processor"]
        if model is None or processor is None:
            model, processor = load_vision_model(VISION_MODEL)
            vision_model_cache["model"] = model
            vision_model_cache["processor"] = processor
    except Exception as e:
        return None, f"Error loading vision model: {e}"

    progress(0.3, desc="Processing image...")
    try:
        # process_medical_image returns three values; the middle one is not
        # needed for display.
        result_image, _metadata, analysis_text = process_medical_image(
            image,
            model=model,
            processor=processor
        )

        progress(0.9, desc="Finalizing results...")

        # Format the analysis text for Gradio
        analysis_html = _analysis_markdown_to_html(analysis_text)

        progress(1.0, desc="Done!")
        return result_image, analysis_html
    except Exception as e:
        import traceback
        # Keep the full traceback in the server log; show only the summary in the UI.
        error_details = traceback.format_exc()
        print(f"Error analyzing image: {e}\n{error_details}")
        return None, f"Error analyzing image: {e}"

# Function for clinical interpretation
def generate_clinical_interpretation(image, analysis_html, use_llamacpp=False, progress=gr.Progress()):
    """Generate a narrative clinical interpretation from a finished image analysis.

    Args:
        image: The analyzed image; only checked for presence.
        analysis_html: HTML produced by the image-analysis step. Image type,
            region, laterality and findings are extracted from it with
            regular expressions; a fixed default is used for any field that
            is not found.
        use_llamacpp: When true *and* llama.cpp was detected at import time,
            generate with ``generate_text_with_llamacpp``; otherwise use the
            transformers model named by ``TEXT_MODEL``.
        progress: Gradio progress tracker used to report status to the UI.

    Returns:
        The generated interpretation, or a human-readable message when no
        analysis is available or the backend raised.
    """
    import re

    # A cancel request left over from an earlier run must not abort this one;
    # the drug-information handler resets the flag the same way.
    cancel_event.clear()

    if image is None or not analysis_html:
        return "Please analyze an image first."

    def _extract(pattern, default):
        """Return the stripped first capture of `pattern`, or `default`."""
        match = re.search(pattern, analysis_html)
        value = match.group(1).strip() if match else ""
        return value or default

    # Extract information from the analysis HTML
    image_type = _extract(r"<b>Image Type</b>: ([^<]+)", "Medical X-ray")
    region = _extract(r"<b>Region</b>: ([^<]+)", "Unknown region")
    laterality = _extract(r"<b>Laterality</b>: ([^<]+)", "Unknown laterality")
    findings = _extract(r"<b>Findings</b>:[^β€’]*β€’ ([^<]+)", "No findings detected")

    # Create a detailed prompt
    prompt = f"""

    Provide a detailed clinical interpretation of a {image_type} with the following characteristics:

    - Segmented region: {region}

    - Laterality: {laterality}

    - Findings: {findings}

    

    Describe potential clinical significance, differential diagnoses, and recommendations. Include educational information about normal anatomy in this region and common pathologies that might be found. Be thorough but concise.

    """

    # Choose generation method based on user preference and availability
    if use_llamacpp and llamacpp_available:
        progress(0.1, desc="Loading llama.cpp model...")
        try:
            result = generate_text_with_llamacpp(
                query=prompt,
                max_tokens=768,
                temperature=0.7,
                top_p=0.9,
                cancel_event=cancel_event,
                # Map the helper's 0..1 progress onto the 10%..90% band of the bar.
                progress_callback=lambda p, d: progress(0.1 + 0.8 * p, desc=d)
            )
            progress(1.0, desc="Done!")
            return result
        except Exception as e:
            return f"Error generating interpretation with llama.cpp: {e}"

    # Fallback to transformers if llamacpp is not available or not chosen
    progress(0.1, desc="Loading transformers model...")
    try:
        # Reuse the cached model/tokenizer; load and cache them on first use.
        model = transformers_model_cache["model"]
        tokenizer = transformers_model_cache["tokenizer"]
        if model is None or tokenizer is None:
            from utils.text_model import load_text_model
            model, tokenizer = load_text_model(TEXT_MODEL, quantize=torch.cuda.is_available())
            transformers_model_cache["model"] = model
            transformers_model_cache["tokenizer"] = tokenizer

        result = generate_text_with_transformers(
            model=model,
            tokenizer=tokenizer,
            query=prompt,
            max_tokens=768,
            temperature=0.7,
            cancel_event=cancel_event,
            progress_callback=lambda p, d: progress(0.1 + 0.8 * p, desc=d)
        )
        progress(1.0, desc="Done!")
        return result
    except Exception as e:
        return f"Error generating interpretation with transformers: {e}"

# Function to cancel generation
def cancel_generation():
    """Set the shared cancel event and return the status text for the response pane."""
    status_message = "Generation cancelled."
    cancel_event.set()
    return status_message

# Create Gradio Interface
with gr.Blocks(title="AI Pharma App", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 💊 AI Pharma Assistant")
    gr.Markdown("This AI-powered application helps you access pharmaceutical information and analyze medical images.")
    
    # Set up model choice
    with gr.Accordion("Model Settings", open=False):
        # The checkbox is pre-ticked and editable only when llama.cpp was detected.
        use_llamacpp = gr.Checkbox(
            label="Use Llama.cpp for faster inference", 
            value=llamacpp_available,
            interactive=llamacpp_available
        )
        if not llamacpp_available:
            gr.Markdown("⚠️ Llama.cpp is not available. Using transformers backend.")
            gr.Markdown("To enable llama.cpp, run: pip install llama-cpp-python --no-cache-dir")
            gr.Markdown("Then run: python download_model.py")
    
    with gr.Tab("Drug Information"):
        gr.Markdown("### 📋 Drug Information Assistant")
        gr.Markdown("Ask questions about medications, treatments, side effects, drug interactions, or any other pharmaceutical information.")
        
        with gr.Row():
            with gr.Column():
                drug_query = gr.Textbox(
                    label="Type your question about medications here:",
                    placeholder="Example: What are the common side effects of Lisinopril? How does Metformin work?",
                    lines=4
                )
                
                with gr.Row():
                    drug_submit = gr.Button("Ask AI", variant="primary")
                    drug_cancel = gr.Button("Cancel", variant="secondary")
                
            drug_response = gr.Markdown(label="AI Response")
        
        # "Ask AI" runs the text backend; "Cancel" sets the shared cancel event
        # and writes its status message into the same response pane.
        drug_submit.click(
            generate_pharmaceutical_response,
            inputs=[drug_query, use_llamacpp],
            outputs=drug_response
        )
        drug_cancel.click(
            cancel_generation,
            inputs=[],
            outputs=drug_response
        )
    
    with gr.Tab("Medical Image Analysis"):
        gr.Markdown("### 🔍 Medical Image Analyzer")
        gr.Markdown("Upload medical images such as X-rays, CT scans, MRIs, ultrasound images, or other diagnostic visuals. The AI will automatically analyze important structures.")
        
        with gr.Row():
            with gr.Column(scale=2):
                image_input = gr.Image(
                    label="Upload a medical image",
                    type="pil"
                )
                image_submit = gr.Button("Analyze Image", variant="primary")
                
            with gr.Column(scale=3):
                with gr.Tabs():
                    with gr.TabItem("Visualization"):
                        image_output = gr.Image(label="Segmentation Result")
                    with gr.TabItem("Analysis"):
                        analysis_output = gr.HTML(label="Analysis Results")
                
                with gr.Accordion("Clinical Interpretation", open=False):
                    interpret_button = gr.Button("Generate Clinical Interpretation")
                    interpretation_output = gr.Markdown()
                    
        image_submit.click(
            analyze_medical_image_gradio,
            inputs=[image_input],
            outputs=[image_output, analysis_output]
        )
        
        # The interpretation step reads back the HTML shown in the Analysis tab.
        interpret_button.click(
            generate_clinical_interpretation,
            inputs=[image_input, analysis_output, use_llamacpp],
            outputs=interpretation_output
        )
    
    with gr.Accordion("System Information", open=False):
        # These values are computed once, when the interface is built.
        gr.Markdown(f"**Date:** {datetime.now().strftime('%Y-%m-%d')}")
        gr.Markdown(f"**Device:** {'GPU' if torch.cuda.is_available() else 'CPU'}")
        if torch.cuda.is_available():
            gr.Markdown(f"**GPU:** {torch.cuda.get_device_name(0)}")
        gr.Markdown(f"**Text Model:** {TEXT_MODEL}")
        gr.Markdown(f"**Vision Model:** {VISION_MODEL}")
        gr.Markdown("**Llama.cpp:** " + ("Available" if llamacpp_available else "Not available"))
        gr.Markdown("""

        **Note:** This application is for educational purposes only and should not be used for medical diagnosis.

        The analysis provided is automated and should be reviewed by a qualified healthcare professional.

        """)

    gr.Markdown("""

    <div style="text-align: center; margin-top: 20px; padding: 10px; border-top: 1px solid #ddd;">

        <p>AI Pharma Assistant — Built with open-source models</p>

        <p>Powered by Phi-3 and MedSAM | © 2025</p>

    </div>

    """, elem_id="footer")

# Launch the app with customization for Hugging Face Spaces
if __name__ == "__main__":
    # Sharing links are never requested; only the bind address differs.
    launch_options = {"share": False}
    if os.environ.get('SPACE_ID'):
        # SPACE_ID is set: treat as Hugging Face Spaces and listen on all interfaces
        launch_options["server_name"] = "0.0.0.0"
    # Without SPACE_ID this is local development and Gradio's default host is used
    demo.launch(**launch_options)