# Install dependencies (in a notebook):
!pip install unsloth peft bitsandbytes accelerate transformers

# Or, equivalently, from a plain Python script:
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "unsloth", "peft", "bitsandbytes", "accelerate", "transformers"])


# Import unsloth (the tokenizer is returned by FastLanguageModel.from_pretrained)
from unsloth import FastLanguageModel

# Define the MedQA prompt
medqa_prompt = """You are a medical QA system. Answer the following medical question clearly and in detail with complete sentences.
### Question:
{}
### Answer:
"""

# Load the model and tokenizer using unsloth
model_name = "Vijayendra/Phi4-MedQA" 
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=2048,
    dtype=None,  # Use default precision
    load_in_4bit=True,  # Enable 4-bit quantization
    device_map="auto"  # Automatically map model to available devices
)

# Enable faster inference
FastLanguageModel.for_inference(model)

# Prepare the medical question
medical_question = "What are the common symptoms of diabetes?"  # Replace with your medical question
inputs = tokenizer(
    [medqa_prompt.format(medical_question)],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=1024
).to("cuda")  # Ensure inputs are on the GPU

# Generate the output
outputs = model.generate(
    **inputs,
    max_new_tokens=512,  # Allow for detailed responses
    use_cache=True  # Speeds up generation
)
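
# Optional: stream the answer token-by-token instead of waiting for the full
# generation. This is a minimal sketch using transformers' TextStreamer and is
# not part of the original snippet; it re-runs generation purely to demonstrate
# the streaming call, so skip it if the blocking call above is enough.
from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
_ = model.generate(
    **inputs,
    streamer=streamer,       # prints tokens to stdout as they are produced
    max_new_tokens=512,
    use_cache=True
)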

# Decode and clean the response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Extract and print the generated answer
answer_text = response.split("### Answer:")[1].strip() if "### Answer:" in response else response.strip()

print(f"Question: {medical_question}")
print(f"Answer: {answer_text}")