13 changes: 13 additions & 0 deletions README_Training.md
@@ -0,0 +1,13 @@
# Training

## Build Dataset

```
uv run src/build_dataset.py --output ../data/
```
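
`src/build_dataset.py` also reads `args.dataset` and `args.split`, folds them into the output directory name, and writes both a `train.parquet` and a 100-row `validation.parquet`. A fuller invocation might therefore look like the sketch below; `--dataset`/`--split` are inferred from the script's argument usage, and the dataset id is a placeholder rather than a confirmed default:

```
# flags inferred from args.dataset / args.split in src/build_dataset.py; <hf-dataset-id> is a placeholder
uv run src/build_dataset.py --dataset <hf-dataset-id> --split train --output ../data/
# expected outputs (per the script, with "/" in the dataset id replaced by "__"):
#   ../data/<hf-dataset-id>_train/train.parquet
#   ../data/<hf-dataset-id>_train/validation.parquet
```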

## Train Model

```
bash scripts/run_training.sh -m Qwen/Qwen3-0.6B -d <Absolute Path to Data>
```
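
These flags map onto the `getopts ":m:n:d:s:"` block in `scripts/run_training.sh`: `-m` is the model path, `-n` the number of rollouts per prompt (defaults to 8), `-d` the directory holding `train.parquet`/`validation.parquet`, and `-s` an optional checkpoint subdirectory. A fuller call might look like this sketch, with the data path as a placeholder:

```
bash scripts/run_training.sh \
  -m Qwen/Qwen3-0.6B \
  -n 8 \
  -d /abs/path/to/data/swe_smith \
  -s ckpts/Qwen-Qwen3-0.6B
```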
2 changes: 1 addition & 1 deletion agent-sdk
59 changes: 39 additions & 20 deletions pyproject.toml
@@ -1,45 +1,64 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src"]

[project]
name = "agentic-code-search-oss"
version = "0.1.0"
description = "An open-source implementation of a low-latency agent for code localization."
name = "src"
version = "0.0.1"
readme = "README.md"
license = "MIT"
classifiers = [
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
requires-python = ">=3.12"
license = { "text" = "MIT" }
dependencies = [
"datasets>=4.2.0",
"skyrl-train @ git+https://github.com/NovaSky-AI/SkyRL.git@main#subdirectory=skyrl-train",
"openhands-tools",
"openhands-agent-server",
"openhands-workspace",
"vllm",
"verifiers>=0.1.6.post0",
]

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["."]
include = ["src", "agent-sdk"]

[dependency-groups]
dev = [
"pytest>=8.4.2",
"pre-commit>=4.3.0",
"psutil>=7.0.0",
"pyright>=1.1.405",
"ruff>=0.12.10",
"pycodestyle>=2.12.0",
]

[tool.uv]
override-dependencies = [
"flash-attn",
]
extra-build-variables = { flash-attn = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" } }

[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = ["E", "F", "I"]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
[tool.uv.extra-build-dependencies]
flash-attn = ["torch"]

[tool.uv.sources]
flash-attn = {url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp312-cp312-linux_x86_64.whl"}
openhands-sdk = { workspace = true }
openhands-tools = { workspace = true }
openhands-workspace = { workspace = true }
openhands-agent-server = { workspace = true }
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }
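
These uv settings pin `torch`/`torchvision` to the cu128 index and take `flash-attn` as a prebuilt wheel instead of compiling it (consistent with `FLASH_ATTENTION_SKIP_CUDA_BUILD` above). A minimal sanity check of the resolved environment, assuming the usual `flash_attn` import name, could be:

```
uv sync
uv run python -c "import torch, flash_attn; print(torch.__version__, torch.version.cuda)"
```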

[tool.uv.workspace]
members = [
91 changes: 91 additions & 0 deletions scripts/run_training.sh
@@ -0,0 +1,91 @@
#!/bin/bash
#SBATCH --job-name=cso
#SBATCH --output=logs/%j.out
#SBATCH --error=logs/%j.out
#SBATCH --partition=preempt
#SBATCH --gres=gpu:L40S:8
#SBATCH --nodes=1
#SBATCH --time=2-00:00:00
#SBATCH --mem=64G
#SBATCH --cpus-per-task=32
#SBATCH --ntasks-per-node=1

. .env

while getopts ":m:n:d:s:" opt; do
  case ${opt} in
    m ) MODEL=$OPTARG;;
    n ) N_ROLLOUTS=$OPTARG;;
    d ) DATA_PATH=$OPTARG;;
    s ) CKPT_PATH=$OPTARG;;
    # \? ) echo "Usage: run_training.sh [-m model] [-n n_rollouts] [-d data_path] [-s ckpt_path]";;
  esac
done

MODEL_ALIAS=$(echo $MODEL | sed 's/\//-/g')
# Get number of GPUs available
NUM_GPUS=$(nvidia-smi -L | wc -l)
N_ROLLOUTS="${N_ROLLOUTS:-8}"
MAX_LENGTH=8192
RUN_NAME="code_search_${MODEL_ALIAS}"
set -x

DATA_PATH="${DATA_PATH:-data/swe_smith}"
CKPT_PATH="/datadrive/lsutawik/cso/${CKPT_PATH:-ckpts/${MODEL_ALIAS}}"
mkdir -p $CKPT_PATH

NNODES=1
NUM_INFERENCE_ENGINES=4
TP_SIZE=1
LOGGER=wandb

# We use a small batch size here for demonstration
# NOTE (sumanthrh): The `generator.max_turns` here is actually unused, and we use the `step_limit` from the `swebench.yaml` file.
uv run --isolated -m src.train \
data.train_data="['$DATA_PATH/train.parquet']" \
data.val_data="['$DATA_PATH/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.policy.model.path=${MODEL} \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
trainer.placement.policy_num_nodes=$NNODES \
trainer.placement.ref_num_nodes=$NNODES \
trainer.policy.sequence_parallel_size=$NUM_GPUS \
generator.num_inference_engines=$NUM_INFERENCE_ENGINES \
generator.inference_engine_tensor_parallel_size=$TP_SIZE \
+generator.traj_dir=$CKPT_PATH/trajectories/ \
+generator.engine_init_kwargs="{enable_auto_tool_choice:true,tool_call_parser:hermes}" \
trainer.epochs=20 \
trainer.eval_batch_size=100 \
trainer.eval_before_train=false \
trainer.eval_interval=10 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=8 \
trainer.policy_mini_batch_size=8 \
trainer.micro_forward_batch_size_per_gpu=2 \
trainer.micro_train_batch_size_per_gpu=2 \
trainer.dump_data_batch=true \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=4096 \
generator.sampling_params.max_generate_length=${MAX_LENGTH} \
generator.max_input_length=30720 \
generator.max_turns=20 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.backend=vllm \
generator.run_engines_locally=True \
generator.enable_http_endpoint=True \
generator.http_endpoint_host='0.0.0.0' \
generator.http_endpoint_port=8080 \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
generator.n_samples_per_prompt=${N_ROLLOUTS} \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="code_search" \
trainer.run_name=${RUN_NAME} \
trainer.resume_mode=null \
trainer.ckpt_path="$CKPT_PATH"
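
Given the `#SBATCH` directives at the top, the script can be submitted to Slurm as well as run directly; a submission might look like the sketch below (the data path is a placeholder, and `logs/` has to exist for the `--output`/`--error` files):

```
mkdir -p logs
sbatch scripts/run_training.sh -m Qwen/Qwen3-0.6B -d /abs/path/to/data/swe_smith
```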
20 changes: 17 additions & 3 deletions src/build_dataset.py
@@ -18,12 +18,26 @@ def main():
lambda row: f"{extract_functions_from_patch(row["patch"])}", axis=1
)

dataset["prompt"] = dataset.apply(
lambda row: [{"role": "user", "content": row["problem_statement"]}], axis=1
)

# shuffle dataset
dataset = dataset.sample(frac=1).reset_index(drop=True)

# train_size = int(0.975 * len(dataset))
train_dataset = dataset.iloc[:-100]
validation_dataset = dataset.iloc[-100:]

# if output does not exist, create it
output_dir = os.path.join(args.output, args.dataset.replace("/", "__") + "_" + args.split)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_path = os.path.join(output_dir, "train.parquet")
dataset.to_parquet(output_path)

train_dataset.to_parquet(output_path)

output_path = os.path.join(output_dir, "validation.parquet")
validation_dataset.to_parquet(output_path)

if __name__ == "__main__":
main()
main()
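
The `train.parquet`/`validation.parquet` written here are what `run_training.sh` feeds into `data.train_data` and `data.val_data`, so the `-d` argument should point at this output directory. A quick check that the `prompt` column came out as chat messages, assuming the data was written to `data/swe_smith/` (the default `DATA_PATH` in `run_training.sh`), might be:

```
uv run python -c "import pandas as pd; df = pd.read_parquet('data/swe_smith/train.parquet'); print(len(df), df.loc[0, 'prompt'])"
```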