# dx-reasoning-qwen2.5-grpo

A LoRA adapter for clinical diagnostic reasoning, fine-tuned from Qwen2.5-7B-Instruct using Group Relative Policy Optimisation (GRPO).
## Model description

This model was trained to improve clinical diagnostic reasoning by learning to generate step-by-step reasoning before providing a diagnosis. It uses a structured format with `<reasoning>` and `<diagnosis>` tags.
## Training details

- Base model: Qwen/Qwen2.5-7B-Instruct
- Training method: GRPO (Group Relative Policy Optimisation)
- LoRA configuration:
  - Rank (r): 64
  - Alpha: 128
  - Dropout: 0.05
  - Target modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
- Trainable parameters: 161M of 7.7B total (2.08%)
- Training steps: 700 (2+ epochs)
- Hardware: 2× NVIDIA H100 80GB
- Training time: ~20 hours
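The LoRA hyperparameters above can be expressed as a `peft.LoraConfig`; this is a sketch reconstructed from the list, with `task_type` assumed from the causal-LM base model rather than stated in the card:

```python
from peft import LoraConfig

# LoRA hyperparameters as listed above.
# task_type is an assumption matching the causal-LM base model.
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)
```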
## Reward function
The model was trained with a composite reward function:
- Embedding similarity: Cosine similarity between generated diagnosis and ground truth using PubMedBERT embeddings
- Reasoning quality: Bonus for including structured reasoning steps
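The composite reward can be sketched as follows. This is an illustration only: the cosine similarity is computed here over precomputed embedding vectors (in training these came from PubMedBERT), and the bonus weight of 0.2 is a hypothetical value, not the one used in training:

```python
import math
import re


def embedding_similarity(vec_a, vec_b):
    """Cosine similarity between two embedding vectors
    (plain lists of floats for illustration)."""
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    norm_a = math.sqrt(sum(a * a for a in vec_a))
    norm_b = math.sqrt(sum(b * b for b in vec_b))
    return dot / (norm_a * norm_b)


def reasoning_bonus(completion, bonus=0.2):
    """Fixed bonus if the completion contains non-empty
    <reasoning>...</reasoning> and <diagnosis>...</diagnosis> blocks.
    The bonus value is an assumption for illustration."""
    has_reasoning = re.search(r"<reasoning>.+?</reasoning>", completion, re.DOTALL)
    has_diagnosis = re.search(r"<diagnosis>.+?</diagnosis>", completion, re.DOTALL)
    return bonus if (has_reasoning and has_diagnosis) else 0.0


def composite_reward(pred_emb, gold_emb, completion):
    """Similarity to the ground-truth diagnosis plus a structure bonus."""
    return embedding_similarity(pred_emb, gold_emb) + reasoning_bonus(completion)
```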
## Dataset
Trained on gretelai/symptom_to_diagnosis:
- 853 training samples
- 200 evaluation samples
## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo")

# Example inference
prompt = """You are a medical expert. Given the patient's symptoms, provide a diagnosis.
Patient symptoms: The patient presents with severe headache, sensitivity to light, neck stiffness, and fever.
First, provide your reasoning in <reasoning> tags, then give your diagnosis in <diagnosis> tags."""

messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# do_sample=True is required for temperature to take effect
outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
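Once decoded, the structured output can be split back into its parts. The helper below is an illustration, not part of the released adapter; it returns `None` for a field whose tags are missing from the completion:

```python
import re


def parse_response(text):
    """Extract the <reasoning> and <diagnosis> blocks from a model
    completion. Returns None for any field whose tags are missing."""
    def grab(tag):
        match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        return match.group(1).strip() if match else None

    return {"reasoning": grab("reasoning"), "diagnosis": grab("diagnosis")}
```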
## Rollouts

Training rollouts are available in the `rollouts/` directory, containing generation samples at each evaluation step (100, 200, 300, 400, 500, 600, 700). These can be used for per-diagnosis analysis of training progression.
## Limitations

- Trained on a relatively small dataset (853 samples)
- Focused on the symptom-to-diagnosis task; may not generalise to other medical reasoning tasks
- Not suitable for actual medical diagnosis; intended for research purposes only
## Citation

If you use this model, please cite:

```bibtex
@misc{dx-reasoning-qwen2.5-grpo,
  author = {Chris von Csefalvay},
  title = {dx-reasoning-qwen2.5-grpo: Clinical Diagnostic Reasoning with GRPO},
  year = {2026},
  publisher = {Hugging Face},
  url = {https://huggingface.co/chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo}
}
```
## Training logs

Training was monitored via Weights & Biases. See the project for detailed metrics and training curves.