# LoRA Fine-Tuned Mistral-7B on MTS-Dialog

This repository contains a LoRA fine-tuned version of `mistralai/Mistral-7B-v0.1` for medical dialogue summarization, trained on the MTS-Dialog dataset.
## Resources

- Training Notebook & Code: GitHub Repository
## Model Summary

- Base Model: `mistralai/Mistral-7B-v0.1`
- Fine-tuning Method: LoRA (Low-Rank Adaptation)
- Frameworks: 🤗 Transformers, PEFT, bitsandbytes
- Quantization: 4-bit
- Task: Medical dialogue summarization
- Dataset: MTS-Dialog
## Task Description

This model is trained to summarize doctor-patient conversations into concise clinical notes, categorized by sections such as `GENHX`, `HPI`, `ROS`, etc. These summaries assist with EHR documentation and clinical decision-making.
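As a rough illustration, a section-specific prompt can be assembled from a dialogue and a section label as sketched below. The `build_prompt` helper is hypothetical; the format simply mirrors the example prompt shown later in this card.

```python
# Hypothetical helper mirroring the prompt format used elsewhere in this card.
def build_prompt(section: str, dialogue: str) -> str:
    return f"Summarize the following dialogue for section: {section}\n{dialogue}\nSummary:"

prompt = build_prompt(
    "GENHX",
    "Doctor: What brings you back into the clinic today, miss?\n"
    "Patient: I've had chest pain for the last few days.",
)
```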
## Training Configuration

| Parameter | Value |
|---|---|
| LoRA Rank | 4 |
| Epochs | 3 |
| Batch Size | 4 (×4 gradient accumulation) |
| Learning Rate | 3e-4 |
| Device | CUDA:0 |
| Quantization | 4-bit (bitsandbytes) |
> **Note:** Due to limited GPU resources (office laptop), training was constrained to 3 epochs and a small LoRA rank. Performance is expected to improve with extended training and better hardware.
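The full training script lives in the linked notebook; the snippet below is only a minimal sketch of how the configuration in the table could be expressed with PEFT and bitsandbytes. Values not listed in the table (`lora_alpha`, `lora_dropout`, `target_modules`) are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, get_peft_model

# 4-bit quantized base model (bitsandbytes), placed on CUDA:0 as in the table
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
base = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", quantization_config=bnb_config, device_map={"": 0}
)

# LoRA adapter with rank 4; alpha, dropout, and target modules are assumed values
lora_config = LoraConfig(
    r=4,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)

# Effective batch size 16: per-device batch 4 × 4 gradient accumulation steps
training_args = TrainingArguments(
    output_dir="lora-mistral-mts-dialog",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    fp16=True,
)
```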
## Evaluation Metrics

| Metric | Score |
|---|---|
| ROUGE-1 | 0.1318 |
| ROUGE-2 | 0.0456 |
| ROUGE-L | 0.0900 |
| BLEU | 0.0260 |
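The scores above are self-reported from the training notebook. For reference, a minimal sketch of how ROUGE and BLEU can be computed with the 🤗 `evaluate` library is shown below; the example strings are placeholders, not dataset samples.

```python
import evaluate

rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

predictions = ["The patient reports chest pain for several days."]        # model outputs
references = ["Patient presents with chest pain for the past few days."]  # gold summaries

rouge_scores = rouge.compute(predictions=predictions, references=references)
bleu_score = bleu.compute(predictions=predictions, references=references)
print(rouge_scores["rouge1"], rouge_scores["rouge2"], rouge_scores["rougeL"], bleu_score["bleu"])
```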
## Example Prompt

```text
Summarize the following dialogue for section: GENHX
Doctor: What brings you back into the clinic today, miss?
Patient: I've had chest pain for the last few days.
Doctor: When did it start?
Summary:
```
## Inference Code

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Load the 4-bit quantized base model, then attach the LoRA adapter
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", load_in_4bit=True, device_map="auto")
model = PeftModel.from_pretrained(model, "Imsachinsingh00/Fine_tuned_LoRA_Mistral_MTSDialog_Summarization")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("Imsachinsingh00/Fine_tuned_LoRA_Mistral_MTSDialog_Summarization")

prompt = "Summarize the following dialogue for section: HPI\nDoctor: Hello, what brings you in?\nPatient: I've been dizzy for two days.\nSummary:"

# Tokenize, generate, and decode the result
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=150)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
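Note that `generate` returns the prompt tokens followed by the completion. If you want only the generated summary, you can slice off the prompt before decoding; a small follow-up to the snippet above:

```python
# Decode only the newly generated tokens (everything after the prompt)
prompt_len = inputs["input_ids"].shape[1]
summary = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
print(summary.strip())
```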
## Included Files

- `config.json` – PEFT configuration for LoRA
- `adapter_model.bin` – LoRA adapter weights
- `tokenizer/` – Tokenizer files
- `README.md` – This model card
## Notes

- This is not a fully optimized clinical model; it is only a proof of concept.
- Consider training longer (e.g. `epochs=10`, `rank=8`) on GPUs with more VRAM for better results.