
Commit 6bf923c

Speculative decoding support for DMR provider
Signed-off-by: Christopher Petito <[email protected]>
1 parent cda41d5 commit 6bf923c

6 files changed, +238 −15 lines changed
README.md

Lines changed: 8 additions & 1 deletion
@@ -264,7 +264,7 @@ in the `/examples/` directory.
 ### DMR (Docker Model Runner) provider options
 
 When using the `dmr` provider, you can use the `provider_opts` key for DMR
-runtime-specific (e.g. llama.cpp) options:
+runtime-specific (e.g. llama.cpp/vllm) options and speculative decoding:
 
 ```yaml
 models:
@@ -273,7 +273,12 @@ models:
     model: ai/qwen3
     max_tokens: 8192
     provider_opts:
+      # general flags passed to the underlying model runtime
       runtime_flags: ["--ngl=33", "--repeat-penalty=1.2", ...] # or comma/space-separated string
+      # speculative decoding for faster inference
+      speculative_draft_model: ai/qwen3:1B
+      speculative_num_tokens: 5
+      speculative_acceptance_rate: 0.8
 ```
 
 The default base_url `cagent` will use for DMR providers is
@@ -283,6 +288,8 @@ settings](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-doc
 on MacOS and Windows, and via command line on [Docker CE on
 Linux](https://docs.docker.com/ai/model-runner/get-started/#enable-dmr-in-docker-engine).
 
+See the [DMR Provider documentation](docs/USAGE.md#dmr-docker-model-runner-provider-usage) for more details on runtime flags and speculative decoding options.
+
 ## Quickly generate agents and agent teams with `cagent new`
 
 Using the command `cagent new` you can quickly generate agents or multi-agent

docs/PROVIDERS.md

Lines changed: 21 additions & 1 deletion
@@ -25,7 +25,7 @@ var ProviderAliases = map[string]Alias{
 
 ## Add custom config if needed (optional)
 
-If your provider requires custom config, like Azure's `api_version`
+If your provider requires custom config, like Azure's `api_version` or DMR's speculative decoding options:
 
 ```yaml
 models:
@@ -41,6 +41,14 @@ models:
     model: gpt-4o
     provider_opts:
       your_custom_option: your_custom_value
+  # DMR with speculative decoding
+  dmr_model:
+    provider: dmr
+    model: ai/qwen3:14B
+    provider_opts:
+      speculative_draft_model: ai/qwen3:1B
+      speculative_num_tokens: 5
+      speculative_acceptance_rate: 0.8
 ```
 
 edit [`pkg/model/provider/openai/client.go`](https://github.com/docker/cagent/blob/main/pkg/model/provider/openai/client.go)
@@ -63,3 +71,15 @@ switch cfg.Provider { //nolint:gocritic
 	}
 }
 ```
+
+## DMR Provider Specific Options
+
+The DMR provider supports speculative decoding for faster inference. Configure it using `provider_opts`:
+
+- `speculative_draft_model` (string): Model to use for draft predictions
+- `speculative_num_tokens` (int): Number of tokens to generate speculatively
+- `speculative_acceptance_rate` (float): Acceptance rate threshold for speculative tokens
+
+All three options are passed to `docker model configure` as command-line flags.
+
+You can also pass any flag of the underlying model runtime (llama.cpp or vllm) using the `runtime_flags` option.
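For reference, here is a minimal, self-contained Go sketch of how these options turn into `docker model configure` flags. It mirrors the unexported helper in `pkg/model/provider/dmr/client.go`; the names `speculativeOpts` and `buildArgs` are illustrative, not part of the package API:

```go
package main

import (
	"fmt"
	"strconv"
)

// speculativeOpts mirrors the unexported speculativeDecodingOpts struct in
// pkg/model/provider/dmr/client.go (illustrative only).
type speculativeOpts struct {
	draftModel     string
	numTokens      int
	acceptanceRate float64
}

// buildArgs assembles the `docker model configure` argument vector the same
// way the provider does: context size and speculative flags first, then the
// model name, then runtime flags after a `--` separator.
func buildArgs(model string, contextSize int, runtimeFlags []string, spec *speculativeOpts) []string {
	args := []string{"model", "configure"}
	if contextSize > 0 {
		args = append(args, "--context-size="+strconv.Itoa(contextSize))
	}
	if spec != nil {
		if spec.draftModel != "" {
			args = append(args, "--speculative-draft-model="+spec.draftModel)
		}
		if spec.numTokens > 0 {
			args = append(args, "--speculative-num-tokens="+strconv.Itoa(spec.numTokens))
		}
		if spec.acceptanceRate > 0 {
			args = append(args, "--speculative-acceptance-rate="+strconv.FormatFloat(spec.acceptanceRate, 'f', -1, 64))
		}
	}
	args = append(args, model)
	if len(runtimeFlags) > 0 {
		args = append(args, "--")
		args = append(args, runtimeFlags...)
	}
	return args
}

func main() {
	spec := &speculativeOpts{draftModel: "ai/qwen3:1B", numTokens: 5, acceptanceRate: 0.8}
	fmt.Println(buildArgs("ai/qwen3:14B", 8192, []string{"--threads", "8"}, spec))
	// [model configure --context-size=8192 --speculative-draft-model=ai/qwen3:1B
	//  --speculative-num-tokens=5 --speculative-acceptance-rate=0.8 ai/qwen3:14B -- --threads 8]
}
```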

docs/USAGE.md

Lines changed: 24 additions & 1 deletion
@@ -382,7 +382,30 @@ models:
       runtime_flags: "--ngl=33 --repeat-penalty=1.2" # string accepted as well
 ```
 
-Troubleshooting:
+##### Speculative Decoding
+
+DMR supports speculative decoding for faster inference by using a smaller draft model to predict tokens ahead. Configure speculative decoding using `provider_opts`:
+
+```yaml
+models:
+  qwen-with-speculative:
+    provider: dmr
+    model: ai/qwen3:14B
+    max_tokens: 8192
+    provider_opts:
+      speculative_draft_model: ai/qwen3:0.6B-F16 # Draft model for predictions
+      speculative_num_tokens: 16 # Number of tokens to generate speculatively
+      speculative_acceptance_rate: 0.8 # Acceptance rate threshold
+```
+
+All three speculative decoding options are passed to `docker model configure` as flags:
+- `speculative_draft_model` → `--speculative-draft-model`
+- `speculative_num_tokens` → `--speculative-num-tokens`
+- `speculative_acceptance_rate` → `--speculative-acceptance-rate`
+
+These options work alongside `max_tokens` (which sets `--context-size`) and `runtime_flags`.
+
+##### Troubleshooting:
 
 - Plugin not found: cagent will log a debug message and use the default base URL
 - Endpoint empty in status: ensure the Model Runner is running, or set `base_url` manually
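As a rough illustration, the `qwen-with-speculative` example above would translate into an invocation along these lines (inferred from the flag mapping; not captured CLI output):

```
docker model configure \
  --context-size=8192 \
  --speculative-draft-model=ai/qwen3:0.6B-F16 \
  --speculative-num-tokens=16 \
  --speculative-acceptance-rate=0.8 \
  ai/qwen3:14B
```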

examples/dmr.yaml

Lines changed: 13 additions & 0 deletions
@@ -3,12 +3,25 @@
 agents:
   root:
     model: qwen
+    # model: qwen_speculative
     description: "Pirate-themed AI assistant"
     instruction: Talk like a pirate
+    commands:
+      demo: "Hey tell me a story about docker containers"
 
 models:
   qwen:
     provider: dmr
     model: ai/qwen3
     # base_url defaults to http://localhost:12434/engines/llama.cpp/v1
     # use http://model-runner.docker.internal/engines/v1 if you run cagent from a container
+
+  # try this model for faster inference if you have enough memory
+  qwen_speculative:
+    provider: dmr
+    model: ai/qwen3
+    # The draft model should be a smaller, faster variant of the main model with low latency
+    provider_opts:
+      speculative_draft_model: ai/qwen3:0.6B-Q4_K_M
+      speculative_num_tokens: 16 # (this is the llama.cpp default if omitted)
+      speculative_acceptance_rate: 0.8 # (this is the llama.cpp default if omitted)
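To try it out, run the example with `cagent run examples/dmr.yaml` (assuming the usual `cagent run <config>` invocation) and switch the root agent's `model` from `qwen` to `qwen_speculative` to enable the draft model.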

pkg/model/provider/dmr/client.go

Lines changed: 98 additions & 10 deletions
@@ -98,14 +98,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
 	clientOptions = append(clientOptions, option.WithBaseURL(baseURL), option.WithAPIKey("")) // DMR doesn't need auth
 
 	// Build runtime flags from ModelConfig and engine
-	contextSize, providerRuntimeFlags := parseDMRProviderOpts(cfg)
+	contextSize, providerRuntimeFlags, specOpts := parseDMRProviderOpts(cfg)
 	configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg)
 	finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags)
 	for _, w := range warnings {
 		slog.Warn(w)
 	}
-	slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "engine", engine)
-	if err := configureDockerModel(ctx, cfg.Model, contextSize, finalFlags); err != nil {
+	slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "speculative_opts", specOpts, "engine", engine)
+	if err := configureDockerModel(ctx, cfg.Model, contextSize, finalFlags, specOpts); err != nil {
 		slog.Debug("docker model configure skipped or failed", "error", err)
 	}
 
@@ -533,14 +533,22 @@ func ConvertParametersToSchema(params any) (any, error) {
 	return m, nil
 }
 
-func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlags []string) {
+type speculativeDecodingOpts struct {
+	draftModel     string
+	numTokens      int
+	acceptanceRate float64
+}
+
+func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) {
 	if cfg == nil {
-		return 0, nil
+		return 0, nil, nil
 	}
 
 	// Context length is now sourced from the standard max_tokens field
 	contextSize = cfg.MaxTokens
 
+	slog.Debug("DMR provider opts", "provider_opts", cfg.ProviderOpts)
+
 	if len(cfg.ProviderOpts) > 0 {
 		if v, ok := cfg.ProviderOpts["runtime_flags"]; ok {
 			switch t := v.(type) {
@@ -555,9 +563,78 @@ func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlags []string) {
 				runtimeFlags = append(runtimeFlags, parts...)
 			}
 		}
+
+		// Parse speculative decoding options
+		var hasDraftModel, hasNumTokens, hasAcceptanceRate bool
+		var draftModel string
+		var numTokens int
+		var acceptanceRate float64
+
+		if v, ok := cfg.ProviderOpts["speculative_draft_model"]; ok {
+			if s, ok := v.(string); ok && s != "" {
+				draftModel = s
+				hasDraftModel = true
+			}
+		}
+
+		if v, ok := cfg.ProviderOpts["speculative_num_tokens"]; ok {
+			switch t := v.(type) {
+			case int: // YAML decoders typically produce int for integer literals
+				numTokens = t
+				hasNumTokens = true
+			case float64:
+				numTokens = int(t)
+				hasNumTokens = true
+			case uint64:
+				numTokens = int(t)
+				hasNumTokens = true
+			case string:
+				s := strings.TrimSpace(t)
+				if s != "" {
+					if n, err := strconv.Atoi(s); err == nil {
+						numTokens = n
+						hasNumTokens = true
+					} else if f, err := strconv.ParseFloat(s, 64); err == nil {
+						numTokens = int(f)
+						hasNumTokens = true
+					}
+				}
+			}
+		}
+
+		if v, ok := cfg.ProviderOpts["speculative_acceptance_rate"]; ok {
+			switch t := v.(type) {
+			case int:
+				acceptanceRate = float64(t)
+				hasAcceptanceRate = true
+			case float64:
+				acceptanceRate = t
+				hasAcceptanceRate = true
+			case uint64:
+				acceptanceRate = float64(t)
+				hasAcceptanceRate = true
+			case string:
+				s := strings.TrimSpace(t)
+				if s != "" {
+					if f, err := strconv.ParseFloat(s, 64); err == nil {
+						acceptanceRate = f
+						hasAcceptanceRate = true
+					}
+				}
+			}
+		}
+
+		// Only create specOpts if at least one field is set
+		if hasDraftModel || hasNumTokens || hasAcceptanceRate {
+			specOpts = &speculativeDecodingOpts{
+				draftModel:     draftModel,
+				numTokens:      numTokens,
+				acceptanceRate: acceptanceRate,
+			}
+		}
 	}
 
-	return contextSize, runtimeFlags
+	return contextSize, runtimeFlags, specOpts
 }
 
 func pullDockerModelIfNeeded(ctx context.Context, model string) error {
@@ -615,8 +692,8 @@ func modelExists(ctx context.Context, model string) bool {
 	return true
 }
 
-func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string) error {
-	args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags)
+func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) error {
+	args := buildDockerModelConfigureArgs(model, contextSize, runtimeFlags, specOpts)
 
 	cmd := exec.CommandContext(ctx, "docker", args...)
 	slog.Debug("Running docker model configure", "model", model, "args", args)
@@ -631,12 +708,23 @@ func configureDockerModel(ctx context.Context, model string, contextSize int, runtimeFlags []string) error {
 }
 
 // buildDockerModelConfigureArgs returns the argument vector passed to `docker` for model configuration.
-// It formats context size and runtime flags consistently with the CLI contract.
-func buildDockerModelConfigureArgs(model string, contextSize int, runtimeFlags []string) []string {
+// It formats context size, speculative decoding options, and runtime flags consistently with the CLI contract.
+func buildDockerModelConfigureArgs(model string, contextSize int, runtimeFlags []string, specOpts *speculativeDecodingOpts) []string {
 	args := []string{"model", "configure"}
 	if contextSize > 0 {
 		args = append(args, "--context-size="+strconv.Itoa(contextSize))
 	}
+	if specOpts != nil {
+		if specOpts.draftModel != "" {
+			args = append(args, "--speculative-draft-model="+specOpts.draftModel)
+		}
+		if specOpts.numTokens > 0 {
+			args = append(args, "--speculative-num-tokens="+strconv.Itoa(specOpts.numTokens))
+		}
+		if specOpts.acceptanceRate > 0 {
+			args = append(args, "--speculative-acceptance-rate="+strconv.FormatFloat(specOpts.acceptanceRate, 'f', -1, 64))
+		}
+	}
 	args = append(args, model)
 	if len(runtimeFlags) > 0 {
 		args = append(args, "--")
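The parser above accepts several scalar encodings for `speculative_num_tokens`, since YAML and JSON decoders surface numbers differently. A self-contained sketch of that normalization (illustrative only; the real logic lives in `parseDMRProviderOpts`):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// normalizeNumTokens mirrors the speculative_num_tokens handling above:
// ints, floats, and numeric strings are all accepted.
func normalizeNumTokens(v any) (int, bool) {
	switch t := v.(type) {
	case int:
		return t, true
	case float64:
		return int(t), true
	case uint64:
		return int(t), true
	case string:
		s := strings.TrimSpace(t)
		if s == "" {
			return 0, false
		}
		if n, err := strconv.Atoi(s); err == nil {
			return n, true
		}
		if f, err := strconv.ParseFloat(s, 64); err == nil {
			return int(f), true
		}
	}
	return 0, false
}

func main() {
	for _, v := range []any{5, 5.0, "5", " 16 ", "oops"} {
		n, ok := normalizeNumTokens(v)
		fmt.Printf("%v -> %d (%v)\n", v, n, ok)
	}
}
```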

pkg/model/provider/dmr/client_test.go

Lines changed: 74 additions & 2 deletions
@@ -32,7 +32,7 @@ func TestNewClientWithWrongType(t *testing.T) {
 }
 
 func TestBuildDockerConfigureArgs(t *testing.T) {
-	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7", "--top-p", "0.9"})
+	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7", "--top-p", "0.9"}, nil)
 
 	assert.Equal(t, []string{"model", "configure", "--context-size=8192", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.7", "--top-p", "0.9"}, args)
 }
@@ -62,7 +62,7 @@ func TestIntegrateFlagsWithProviderOptsOrder(t *testing.T) {
 	// provider opts should be appended after derived flags so they can override by order
 	merged := append(derived, []string{"--threads", "6"}...)
 
-	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", cfg.MaxTokens, merged)
+	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", cfg.MaxTokens, merged, nil)
 	assert.Equal(t, []string{"model", "configure", "--context-size=4096", "ai/qwen3:14B-Q6_K", "--", "--temp", "0.6", "--top-p", "0.9", "--threads", "6"}, args)
 }
 
@@ -83,3 +83,75 @@ func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) {
 func floatPtr(f float64) *float64 {
 	return &f
 }
+
+func TestBuildDockerConfigureArgsWithSpeculativeDecoding(t *testing.T) {
+	specOpts := &speculativeDecodingOpts{
+		draftModel:     "ai/qwen3:1B",
+		numTokens:      5,
+		acceptanceRate: 0.8,
+	}
+	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 8192, []string{"--temp", "0.7"}, specOpts)
+
+	assert.Equal(t, []string{
+		"model", "configure",
+		"--context-size=8192",
+		"--speculative-draft-model=ai/qwen3:1B",
+		"--speculative-num-tokens=5",
+		"--speculative-acceptance-rate=0.8",
+		"ai/qwen3:14B-Q6_K",
+		"--",
+		"--temp", "0.7",
+	}, args)
+}
+
+func TestBuildDockerConfigureArgsWithPartialSpeculativeDecoding(t *testing.T) {
+	specOpts := &speculativeDecodingOpts{
+		draftModel: "ai/qwen3:1B",
+		numTokens:  5,
+		// acceptanceRate not set (0 value)
+	}
+	args := buildDockerModelConfigureArgs("ai/qwen3:14B-Q6_K", 0, nil, specOpts)
+
+	assert.Equal(t, []string{
+		"model", "configure",
+		"--speculative-draft-model=ai/qwen3:1B",
+		"--speculative-num-tokens=5",
+		"ai/qwen3:14B-Q6_K",
+	}, args)
+}
+
+func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) {
+	cfg := &latest.ModelConfig{
+		MaxTokens: 4096,
+		ProviderOpts: map[string]any{
+			"speculative_draft_model":     "ai/qwen3:1B",
+			"speculative_num_tokens":      5,
+			"speculative_acceptance_rate": 0.75,
+			"runtime_flags":               []string{"--threads", "8"},
+		},
+	}
+
+	contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg)
+
+	assert.Equal(t, 4096, contextSize)
+	assert.Equal(t, []string{"--threads", "8"}, runtimeFlags)
+	require.NotNil(t, specOpts)
+	assert.Equal(t, "ai/qwen3:1B", specOpts.draftModel)
+	assert.Equal(t, 5, specOpts.numTokens)
+	assert.InEpsilon(t, 0.75, specOpts.acceptanceRate, 0.001)
+}
+
+func TestParseDMRProviderOptsWithoutSpeculativeDecoding(t *testing.T) {
+	cfg := &latest.ModelConfig{
+		MaxTokens: 4096,
+		ProviderOpts: map[string]any{
+			"runtime_flags": []string{"--threads", "8"},
+		},
+	}
+
+	contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg)
+
+	assert.Equal(t, 4096, contextSize)
+	assert.Equal(t, []string{"--threads", "8"}, runtimeFlags)
+	assert.Nil(t, specOpts)
+}
