add dpop training #5339

Open · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/finetuning_args.py
@@ -148,6 +148,10 @@ class RLHFArguments:
        default=0.0,
        metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
    )
    dpop_lambda: float = field(
        default=0.0,
        metadata={"help": "The weight factor of the penalty term in DPOP training."},
    )
    kto_chosen_weight: float = field(
        default=1.0,
        metadata={"help": "The weight factor of the desirable losses in KTO training."},
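The new `dpop_lambda` field is the weight λ of the DPO-Positive (DPOP) penalty, which discourages the policy from assigning the chosen response a lower log probability than the reference model does. As orientation before the trainer diff below, here is a minimal standalone sketch of that objective for the sigmoid loss; the function name and signature are illustrative and are not part of this PR:

import torch
import torch.nn.functional as F


def dpop_sigmoid_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float,
    dpop_lambda: float,
) -> torch.Tensor:
    """Sigmoid DPO loss with the DPOP penalty term (illustrative sketch, not the trainer code)."""
    # Standard DPO: difference of policy and reference log-ratios.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    # DPOP penalty: non-zero only when the policy assigns the chosen response
    # a lower log probability than the reference model does.
    penalty = torch.clamp(reference_chosen_logps - policy_chosen_logps, min=0)
    return -F.logsigmoid(beta * (logits - dpop_lambda * penalty))

Setting `dpop_lambda` to 0 (the default above) recovers the plain sigmoid DPO loss, which is why the trainer change below only applies the penalty when `dpop_lambda > 1e-6`.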
114 changes: 114 additions & 0 deletions src/llamafactory/train/dpo/trainer.py
@@ -75,6 +75,7 @@ def __init__(
        self.ftx_gamma = finetuning_args.pref_ftx
        self.label_smoothing = finetuning_args.dpo_label_smoothing
        self.simpo_gamma = finetuning_args.simpo_gamma
        self.dpop_lambda = finetuning_args.dpop_lambda

        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
@@ -136,7 +137,120 @@ def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor
        logits = pi_logratios - gamma_logratios
        simpo_loss = -F.logsigmoid(self.beta * logits)
        return simpo_loss

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the DPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        if self.reference_free:
            ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
        else:
            ref_logratios = reference_chosen_logps - reference_rejected_logps

        pi_logratios = pi_logratios.to(self.accelerator.device)
        ref_logratios = ref_logratios.to(self.accelerator.device)
        logits = pi_logratios - ref_logratios

        if self.dpop_lambda > 1e-6:
            # DPOP penalty: penalize the policy whenever its log probability of the chosen
            # response falls below that of the reference model.
            penalty_term = torch.maximum(torch.zeros_like(policy_chosen_logps), reference_chosen_logps - policy_chosen_logps)
            logits -= self.dpop_lambda * penalty_term

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty
        # about the labels and calculates a conservative DPO loss.
        if self.loss_type == "sigmoid":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "robust":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                + F.logsigmoid(-self.beta * logits) * self.label_smoothing
            ) / (1 - 2 * self.label_smoothing)
        elif self.loss_type == "hinge":
            losses = torch.relu(1 - self.beta * logits)
        elif self.loss_type == "ipo":
            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
            losses = (logits - 1 / (2 * self.beta)) ** 2
        elif self.loss_type == "kto_pair":
            # eqn (7) of the HALOs paper
            chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
            rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
            losses = torch.cat(
                (
                    1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
                    1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
                ),
                0,
            )
        elif self.loss_type == "bco_pair":
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            rejected_logratios = policy_rejected_logps - reference_rejected_logps

            chosen_rewards = self.beta * chosen_logratios
            rejected_rewards = self.beta * rejected_logratios
            rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
            self.running.update(rewards)
            delta = self.running.mean

            losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
                -(self.beta * rejected_logratios - delta)
            )
        elif self.loss_type == "sppo_hard":
            # In the paper (https://arxiv.org/pdf/2405.00675), SPPO employs a soft probability approach,
            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1
            # is set to 1 for the winner and 0 for the loser.
            a = policy_chosen_logps - reference_chosen_logps
            b = policy_rejected_logps - reference_rejected_logps

            losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2
        elif self.loss_type == "nca_pair":
            chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * self.beta
            rejected_rewards = (policy_rejected_logps - reference_rejected_logps) * self.beta
            losses = (
                -F.logsigmoid(chosen_rewards)
                - 0.5 * F.logsigmoid(-chosen_rewards)
                - 0.5 * F.logsigmoid(-rejected_rewards)
            )
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair', 'sppo_hard', 'nca_pair', 'robust']"
            )

        chosen_rewards = (
            self.beta
            * (
                policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
            ).detach()
        )
        rejected_rewards = (
            self.beta
            * (
                policy_rejected_logps.to(self.accelerator.device)
                - reference_rejected_logps.to(self.accelerator.device)
            ).detach()
        )

        return losses, chosen_rewards, rejected_rewards

    def compute_preference_loss(
        self,
        policy_chosen_logps: "torch.Tensor",
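For intuition about what the `dpop_lambda > 1e-6` branch in `dpo_loss` does, here is a small toy check; the log-probability values and the λ of 50 are invented for illustration and none of these variables come from the PR. The penalty is zero for a pair where the policy still matches or exceeds the reference's log probability on the chosen response, and it dominates the loss for a pair where the policy has drifted below it:

import torch
import torch.nn.functional as F

# Toy batch of two preference pairs (illustrative numbers only).
policy_chosen = torch.tensor([-12.0, -20.0])     # second pair: policy fell below the reference
policy_rejected = torch.tensor([-15.0, -18.0])
reference_chosen = torch.tensor([-14.0, -14.0])
reference_rejected = torch.tensor([-15.0, -18.0])
beta, dpop_lambda = 0.1, 50.0

logits = (policy_chosen - policy_rejected) - (reference_chosen - reference_rejected)
penalty = torch.clamp(reference_chosen - policy_chosen, min=0)   # tensor([0., 6.])

dpo_loss = -F.logsigmoid(beta * logits)
dpop_loss = -F.logsigmoid(beta * (logits - dpop_lambda * penalty))

print(dpo_loss)   # ~[0.60, 1.04]: plain DPO barely reacts to the degraded chosen response
print(dpop_loss)  # ~[0.60, 30.6]: DPOP leaves the first pair untouched and heavily penalizes the second

Note that because the penalty is folded into `logits` before the `loss_type` dispatch, in this PR it also affects the "robust", "hinge", and "ipo" branches, which reuse the same `logits`; the remaining branches compute their own log-ratios and are unaffected.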