From 5c1df46a8668c109cd8822aa3c8bfd0de2f63479 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Tue, 5 Aug 2025 22:47:36 +0000 Subject: [PATCH 1/3] feat: Add GMPO support --- src/art/unsloth/train.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index f1fde190..8fc614ae 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -85,9 +85,9 @@ def compute_loss( next_input_ids = shift_tensor(inputs["tokens"], 0) chunk_size = _config.get("logprob_calculation_chunk_size", 1024) # Assert that sequence length is evenly divisible by the chunk size - assert seq_len % chunk_size == 0, ( - f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})" - ) + assert ( + seq_len % chunk_size == 0 + ), f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})" os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" new_logprobs, entropies = calculate_logprobs( autocast_dtype, @@ -154,10 +154,23 @@ def compute_loss( prob_ratio = torch.clamp( prob_ratio, max=max_negative_advantage_importance_sampling_weight ) - policy_loss = -torch.min( - prob_ratio * advantages, - torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages, - ) + if _config.get("gmpo", True): + advantage_signs = -torch.sign(advantages) + signed_logprob_diff = logprob_diff * advantage_signs + signed_logprob_diff_clamp = torch.clamp( + signed_logprob_diff, -epsilon, epsilon_high + ) + signed_logprob_diff_max = torch.max( + signed_logprob_diff, signed_logprob_diff_clamp + ) + logprobs_diff_max = advantage_signs * signed_logprob_diff_max + prob_ratio = torch.exp(logprob_diff) + policy_loss = -advantages * logprobs_diff_max + else: + policy_loss = -torch.min( + prob_ratio * advantages, + torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages, + ) if ref_logprobs is not None: kl_div = ( torch.exp(ref_logprobs - new_logprobs) @@ -323,7 +336,9 @@ def _calculate_logprobs( chunk_logits = torch.matmul(chunk_hs, lm_head_t) # [B, chunk_size, V] chunk_selected_logits = torch.gather( chunk_logits, dim=-1, index=chunk_input_ids.unsqueeze(-1) - ).squeeze(-1) # [B, chunk_size] + ).squeeze( + -1 + ) # [B, chunk_size] chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1) # [B, chunk_size] log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp From a543d378a132c102ca9b5e7723c4e5a06c670abf Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Wed, 6 Aug 2025 16:19:43 +0000 Subject: [PATCH 2/3] fix: GMPO issue --- src/art/unsloth/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index 8fc614ae..c734d9de 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -164,8 +164,8 @@ def compute_loss( signed_logprob_diff, signed_logprob_diff_clamp ) logprobs_diff_max = advantage_signs * signed_logprob_diff_max - prob_ratio = torch.exp(logprob_diff) - policy_loss = -advantages * logprobs_diff_max + prob_ratio = torch.exp(logprobs_diff_max) + policy_loss = -advantages * prob_ratio else: policy_loss = -torch.min( prob_ratio * advantages, From 5b7270242547e314ec8f8b7f6f5619339a462e04 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Tue, 12 Aug 2025 15:11:30 +0000 Subject: [PATCH 3/3] chore: Reformat train.py --- src/art/unsloth/train.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index c7344097..bc468766 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -86,9 +86,9 @@ def compute_loss( next_input_ids = shift_tensor(inputs["tokens"], 0) chunk_size = _config.get("logprob_calculation_chunk_size", 1024) # Assert that sequence length is evenly divisible by the chunk size - assert ( - seq_len % chunk_size == 0 - ), f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})" + assert seq_len % chunk_size == 0, ( + f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})" + ) os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" new_logprobs, entropies = calculate_logprobs( autocast_dtype, @@ -339,9 +339,7 @@ def _calculate_logprobs( chunk_logits = torch.matmul(chunk_hs, lm_head_t) # [B, chunk_size, V] chunk_selected_logits = torch.gather( chunk_logits, dim=-1, index=chunk_input_ids.unsqueeze(-1) - ).squeeze( - -1 - ) # [B, chunk_size] + ).squeeze(-1) # [B, chunk_size] chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1) # [B, chunk_size] log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp