dx-reasoning-qwen2.5-grpo

A LoRA adapter for clinical diagnostic reasoning, fine-tuned from Qwen2.5-7B-Instruct using Group Relative Policy Optimisation (GRPO).

Model description

The adapter trains the model to write out step-by-step reasoning before committing to a diagnosis. Outputs follow a structured format: the reasoning appears inside <reasoning> tags, followed by the diagnosis inside <diagnosis> tags.
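Downstream code usually needs the two sections separately. The following is a minimal sketch of one way to split a completion, assuming the model closes each tag with </reasoning> and </diagnosis>; the card does not specify closing tags, and the helper name parse_output and the sample text are hypothetical.

```python
import re


def parse_output(text: str) -> dict:
    """Extract the reasoning and diagnosis sections from a completion.

    A section is None when its tags are missing or unclosed.
    Assumes XML-style closing tags, which the card does not confirm.
    """
    def grab(tag: str):
        match = re.search(rf"<{tag}>(.*?)</{tag}>", text, flags=re.DOTALL)
        return match.group(1).strip() if match else None

    return {"reasoning": grab("reasoning"), "diagnosis": grab("diagnosis")}


# Hypothetical completion, written by hand for illustration
sample = (
    "<reasoning>Fever with neck stiffness and photophobia suggests "
    "meningeal irritation.</reasoning>\n"
    "<diagnosis>Meningitis</diagnosis>"
)
parsed = parse_output(sample)
print(parsed["diagnosis"])
```

Returning None for a missing section, rather than raising, lets a caller count malformed completions without wrapping every call in error handling.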

Training details

  • Base model: Qwen/Qwen2.5-7B-Instruct
  • Training method: GRPO (Group Relative Policy Optimisation)
  • LoRA configuration:
    • Rank (r): 64
    • Alpha: 128
    • Dropout: 0.05
    • Target modules: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
  • Trainable parameters: 161M / 7.7B total (2.08%)
  • Training steps: 700 (2+ epochs)
  • Hardware: 2x NVIDIA H100 80GB
  • Training time: ~20 hours
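The trainable-parameter figure can be checked by hand: a LoRA adapter of rank r on a linear layer with d_in inputs and d_out outputs adds r × (d_in + d_out) weights. The sketch below does that arithmetic for the seven target modules. The layer dimensions are assumptions taken from the publicly available Qwen2.5-7B configuration, not from this card.

```python
# Assumed Qwen2.5-7B dimensions (public model config, not stated in this card)
HIDDEN = 3584          # hidden size
INTERMEDIATE = 18944   # MLP intermediate size
NUM_LAYERS = 28        # transformer blocks
KV_DIM = 4 * 128       # 4 key/value heads x head dimension 128

RANK = 64              # LoRA rank from the configuration above

# (in_features, out_features) of each targeted linear layer
TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """LoRA adds matrix A (rank x in) and matrix B (out x rank)."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in TARGETS.values())
total = per_layer * NUM_LAYERS
print(f"{total:,} trainable LoRA parameters")
# prints: 161,480,704 trainable LoRA parameters
```

Under these assumptions the result matches the 161M figure listed above.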

Reward function

The model was trained with a composite reward function:

  • Embedding similarity: Cosine similarity between generated diagnosis and ground truth using PubMedBERT embeddings
  • Reasoning quality: Bonus for including structured reasoning steps
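A rough sketch of how such a reward could be assembled and then turned into GRPO's group-relative advantages follows. The card does not give the weights, the exact bonus criterion, or how the two terms are combined, so the additive form, the reasoning_bonus value, and the function names are all assumptions. The toy_embed function is a stand-in so the example runs without downloading PubMedBERT.

```python
import math
import re
from statistics import mean, pstdev
from typing import Callable, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; 0.0 if either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def composite_reward(
    completion: str,
    truth: str,
    embed: Callable[[str], Sequence[float]],
    reasoning_bonus: float = 0.2,  # assumed value, not from the card
) -> float:
    """Similarity of the tagged diagnosis to the truth, plus a reasoning bonus."""
    diag = re.search(r"<diagnosis>(.*?)</diagnosis>", completion, re.DOTALL)
    similarity = cosine(embed(diag.group(1).strip()), embed(truth)) if diag else 0.0
    reasoning = re.search(r"<reasoning>(.*?)</reasoning>", completion, re.DOTALL)
    has_reasoning = bool(reasoning and reasoning.group(1).strip())
    return similarity + (reasoning_bonus if has_reasoning else 0.0)


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Normalise rewards within one group of completions for the same prompt.

    Uses the population standard deviation; implementations differ on this.
    """
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def toy_embed(text: str) -> List[int]:
    """Letter-frequency vector; a placeholder for a real sentence encoder."""
    text = text.lower()
    return [text.count(c) for c in "abcdefghijklmnopqrstuvwxyz"]


# Hypothetical group of three completions for one prompt
group = [
    "<reasoning>Unilateral throbbing pain with aura.</reasoning><diagnosis>migraine</diagnosis>",
    "<diagnosis>migraine</diagnosis>",
    "I am not sure.",
]
rewards = [composite_reward(c, "migraine", toy_embed) for c in group]
advantages = group_relative_advantages(rewards)
print(rewards)
print(advantages)
```

Because advantages are computed relative to the group mean, a completion is rewarded for being better than its siblings on the same prompt rather than for its absolute score.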

Dataset

Trained on gretelai/symptom_to_diagnosis:

  • 853 training samples
  • 200 evaluation samples

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo")

# Example inference
prompt = """You are a medical expert. Given the patient's symptoms, provide a diagnosis.

Patient symptoms: The patient presents with severe headache, sensitivity to light, neck stiffness, and fever.

First, provide your reasoning in <reasoning> tags, then give your diagnosis in <diagnosis> tags."""

messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

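# temperature only takes effect when sampling is enabled, so set do_sample explicitly
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
# Decode only the newly generated tokens, not the echoed prompt
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))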

Rollouts

Training rollouts are available in the rollouts/ directory, containing generation samples at each evaluation step (100, 200, 300, 400, 500, 600, 700). These can be used for per-diagnosis analysis of training progression.

Limitations

  • Trained on a relatively small dataset (853 samples)
  • Focused on symptom-to-diagnosis task; may not generalise to other medical reasoning tasks
  • Not for actual medical diagnosis; intended for research purposes only

Citation

If you use this model, please cite:

@misc{dx-reasoning-qwen2.5-grpo,
  author = {Chris von Csefalvay},
  title = {dx-reasoning-qwen2.5-grpo: Clinical Diagnostic Reasoning with GRPO},
  year = {2026},
  publisher = {Hugging Face},
  url = {https://huggingface.co/chrisvoncsefalvay/dx-reasoning-qwen2.5-grpo}
}

Training logs

Training was monitored via Weights & Biases. See the project for detailed metrics and training curves.
