Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "crrl"
version = "0.1.0"
description = "Finetuning LLMs with RL to improve program repair capabilities"
readme = "README.md"
requires-python = ">=3.11,<3.12"
requires-python = ">=3.11,<3.13"
dependencies = [
# Our fork of trl
"trl @ git+https://github.com/ASSERT-KTH/trl.git@dev",
Expand Down Expand Up @@ -40,9 +40,12 @@ dependencies = [

[project.optional-dependencies]
# vLLM is required, but its install is brittle and needs torch to be installed first
vllm = ["vllm==0.10.1; platform_system != 'Darwin'"]
flash = ["flash-attn==2.8.3; platform_system != 'Darwin'", "flashinfer-python==0.3.1; platform_system != 'Darwin'"]
# Flash attention packages require a GPU at compile time; install separately
vllm = ["vllm==0.11.0; platform_system != 'Darwin'"]
# Flash attention packages: flash-attn is needed by transformers for training, flashinfer by vLLM
flash = [
"flash-attn==2.8.3; platform_system != 'Darwin'",
"flashinfer-python==0.5.0; platform_system != 'Darwin'",
]
dev = ["pytest>=8.3.4", "matplotlib>=3.10.0", "scikit-learn>=1.6.1"]
# Different agents
# aider = ["aider-chat @ git+https://github.com/BjarniHaukur/aider.git"]
Expand Down
4 changes: 2 additions & 2 deletions scripts/grpo/large_grpo_lora_train_job.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ apptainer exec $APPT_COMMON --env CUDA_VISIBLE_DEVICES=2,3,4,5 crrl.sif accelera
grpo.max_prompt_length=$MAX_PROMPT_LENGTH \
grpo.max_completion_length=$MAX_COMPLETION_LENGTH \
grpo.num_generations=4 \
grpo.generation_batch_size=8 \
grpo.generation_batch_size=32 \
grpo.per_device_train_batch_size=1 \
grpo.gradient_accumulation_steps=4 \
grpo.gradient_accumulation_steps=8 \
grpo.optim="adamw_torch" \
"$@" # pass any additional arguments

12 changes: 6 additions & 6 deletions scripts/train_container.def
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Bootstrap: docker
From: vllm/vllm-openai:v0.10.0
From: vllm/vllm-openai:v0.11.0

%environment
# Keep container largely stateless; runtime env is passed via --env from host
Expand All @@ -15,10 +15,10 @@ From: vllm/vllm-openai:v0.10.0

pip install --upgrade pip
pip install -e .

pip install torch==2.7.1+cu128 --index-url https://download.pytorch.org/whl/cu128

# we pin the version since we need prebuilt binaries to be able to install
pip install --no-build-isolation --prefer-binary flash-attn==2.8.0.post2
pip install --no-build-isolation flashinfer-python
# flash-attn is needed by transformers for attn_implementation="flash_attention_2" during training
# (vLLM only bundles it for inference, not as an importable package)
# Using --no-cache-dir to avoid cross-device link errors in container builds
pip install --no-build-isolation --no-cache-dir flash-attn==2.8.3
pip install --no-build-isolation --no-cache-dir flashinfer-python==0.5.0