@@ -98,14 +98,14 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt
9898 clientOptions = append (clientOptions , option .WithBaseURL (baseURL ), option .WithAPIKey ("" )) // DMR doesn't need auth
9999
100100 // Build runtime flags from ModelConfig and engine
101- contextSize , providerRuntimeFlags := parseDMRProviderOpts (cfg )
101+ contextSize , providerRuntimeFlags , specOpts := parseDMRProviderOpts (cfg )
102102 configFlags := buildRuntimeFlagsFromModelConfig (engine , cfg )
103103 finalFlags , warnings := mergeRuntimeFlagsPreferUser (configFlags , providerRuntimeFlags )
104104 for _ , w := range warnings {
105105 slog .Warn (w )
106106 }
107- slog .Debug ("DMR provider_opts parsed" , "model" , cfg .Model , "context_size" , contextSize , "runtime_flags" , finalFlags , "engine" , engine )
108- if err := configureDockerModel (ctx , cfg .Model , contextSize , finalFlags ); err != nil {
107+ slog .Debug ("DMR provider_opts parsed" , "model" , cfg .Model , "context_size" , contextSize , "runtime_flags" , finalFlags , "speculative_opts" , specOpts , "engine" , engine )
108+ if err := configureDockerModel (ctx , cfg .Model , contextSize , finalFlags , specOpts ); err != nil {
109109 slog .Debug ("docker model configure skipped or failed" , "error" , err )
110110 }
111111
@@ -533,14 +533,22 @@ func ConvertParametersToSchema(params any) (any, error) {
533533 return m , nil
534534}
535535
536- func parseDMRProviderOpts (cfg * latest.ModelConfig ) (contextSize int , runtimeFlags []string ) {
536+ type speculativeDecodingOpts struct { // speculative-decoding settings read from cfg.ProviderOpts by parseDMRProviderOpts and turned into `docker model configure` flags by buildDockerModelConfigureArgs
537+ draftModel string // from provider_opts "speculative_draft_model"; emitted as --speculative-draft-model= only when non-empty
538+ numTokens int // from provider_opts "speculative_num_tokens" (float values are truncated to int); emitted as --speculative-num-tokens= only when > 0
539+ acceptanceRate float64 // from provider_opts "speculative_acceptance_rate"; emitted as --speculative-min-acceptance-rate= only when > 0
540+ }
541+
542+ func parseDMRProviderOpts (cfg * latest.ModelConfig ) (contextSize int , runtimeFlags []string , specOpts * speculativeDecodingOpts ) {
537543 if cfg == nil {
538- return 0 , nil
544+ return 0 , nil , nil
539545 }
540546
541547 // Context length is now sourced from the standard max_tokens field
542548 contextSize = cfg .MaxTokens
543549
550+ slog .Debug ("DMR provider opts" , "provider_opts" , cfg .ProviderOpts )
551+
544552 if len (cfg .ProviderOpts ) > 0 {
545553 if v , ok := cfg .ProviderOpts ["runtime_flags" ]; ok {
546554 switch t := v .(type ) {
@@ -555,9 +563,72 @@ func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize int, runtimeFlag
555563 runtimeFlags = append (runtimeFlags , parts ... )
556564 }
557565 }
566+
567+ // Parse speculative decoding options
568+ var hasDraftModel , hasNumTokens , hasAcceptanceRate bool
569+ var draftModel string
570+ var numTokens int
571+ var acceptanceRate float64
572+
573+ if v , ok := cfg .ProviderOpts ["speculative_draft_model" ]; ok {
574+ if s , ok := v .(string ); ok && s != "" {
575+ draftModel = s
576+ hasDraftModel = true
577+ }
578+ }
579+
580+ if v , ok := cfg .ProviderOpts ["speculative_num_tokens" ]; ok {
581+ switch t := v .(type ) {
582+ case float64 :
583+ numTokens = int (t )
584+ hasNumTokens = true
585+ case uint64 :
586+ numTokens = int (t )
587+ hasNumTokens = true
588+ case string :
589+ s := strings .TrimSpace (t )
590+ if s != "" {
591+ if n , err := strconv .Atoi (s ); err == nil {
592+ numTokens = n
593+ hasNumTokens = true
594+ } else if f , err := strconv .ParseFloat (s , 64 ); err == nil {
595+ numTokens = int (f )
596+ hasNumTokens = true
597+ }
598+ }
599+ }
600+ }
601+
602+ if v , ok := cfg .ProviderOpts ["speculative_acceptance_rate" ]; ok {
603+ switch t := v .(type ) {
604+ case float64 :
605+ acceptanceRate = t
606+ hasAcceptanceRate = true
607+ case uint64 :
608+ acceptanceRate = float64 (t )
609+ hasAcceptanceRate = true
610+ case string :
611+ s := strings .TrimSpace (t )
612+ if s != "" {
613+ if f , err := strconv .ParseFloat (s , 64 ); err == nil {
614+ acceptanceRate = f
615+ hasAcceptanceRate = true
616+ }
617+ }
618+ }
619+ }
620+
621+ // Only create specOpts if at least one field is set
622+ if hasDraftModel || hasNumTokens || hasAcceptanceRate {
623+ specOpts = & speculativeDecodingOpts {
624+ draftModel : draftModel ,
625+ numTokens : numTokens ,
626+ acceptanceRate : acceptanceRate ,
627+ }
628+ }
558629 }
559630
560- return contextSize , runtimeFlags
631+ return contextSize , runtimeFlags , specOpts
561632}
562633
563634func pullDockerModelIfNeeded (ctx context.Context , model string ) error {
@@ -615,8 +686,8 @@ func modelExists(ctx context.Context, model string) bool {
615686 return true
616687}
617688
618- func configureDockerModel (ctx context.Context , model string , contextSize int , runtimeFlags []string ) error {
619- args := buildDockerModelConfigureArgs (model , contextSize , runtimeFlags )
689+ func configureDockerModel (ctx context.Context , model string , contextSize int , runtimeFlags []string , specOpts * speculativeDecodingOpts ) error {
690+ args := buildDockerModelConfigureArgs (model , contextSize , runtimeFlags , specOpts )
620691
621692 cmd := exec .CommandContext (ctx , "docker" , args ... )
622693 slog .Debug ("Running docker model configure" , "model" , model , "args" , args )
@@ -631,12 +702,23 @@ func configureDockerModel(ctx context.Context, model string, contextSize int, ru
631702}
632703
633704// buildDockerModelConfigureArgs returns the argument vector passed to `docker` for model configuration.
634- // It formats context size and runtime flags consistently with the CLI contract.
635- func buildDockerModelConfigureArgs (model string , contextSize int , runtimeFlags []string ) []string {
705+ // It formats context size, speculative decoding options, and runtime flags consistently with the CLI contract.
706+ func buildDockerModelConfigureArgs (model string , contextSize int , runtimeFlags []string , specOpts * speculativeDecodingOpts ) []string {
636707 args := []string {"model" , "configure" }
637708 if contextSize > 0 {
638709 args = append (args , "--context-size=" + strconv .Itoa (contextSize ))
639710 }
711+ if specOpts != nil {
712+ if specOpts .draftModel != "" {
713+ args = append (args , "--speculative-draft-model=" + specOpts .draftModel )
714+ }
715+ if specOpts .numTokens > 0 {
716+ args = append (args , "--speculative-num-tokens=" + strconv .Itoa (specOpts .numTokens ))
717+ }
718+ if specOpts .acceptanceRate > 0 {
719+ args = append (args , "--speculative-min-acceptance-rate=" + strconv .FormatFloat (specOpts .acceptanceRate , 'f' , - 1 , 64 ))
720+ }
721+ }
640722 args = append (args , model )
641723 if len (runtimeFlags ) > 0 {
642724 args = append (args , "--" )
0 commit comments