
Conversation

@toilaluan

What does this PR do?

Adds the TaylorSeer caching method, requested in #12569, to accelerate inference.

Author's codebase: https://github.com/Shenyi-Z/TaylorSeer

This PR's structure heavily mimics the FasterCache implementation (https://github.com/huggingface/diffusers/pull/10163/files).
I prioritize making it work on image model pipelines (Flux, Qwen Image) for ease of evaluation.
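For reference, the intended usage mirrors the benchmark code later in this thread (the config values are the ones used there and may still change during review):

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16).to("cuda")

# Config values taken from the benchmark below; they may change as the PR evolves.
config = TaylorSeerCacheConfig(
    predict_steps=5,
    max_order=1,
    warmup_steps=3,
    taylor_factors_dtype=torch.float16,
    architecture="flux",
)
apply_taylorseer_cache(pipe.transformer, config)

image = pipe("A cat holding a sign that says hello world", num_inference_steps=50).images[0]
```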

Expected Output

A 4-5x speedup with these settings, while keeping the output images acceptable in quality.

image

State Design

The core of this algorithm is predicting the features at step t from real features computed at previous steps, using a Taylor expansion approximation.
We design a State class that includes predict & update methods and taylor_factors: Tensor to maintain iteration information. Each feature tensor is bound to a state instance (in the double-stream attention class in Flux & Qwen Image, this module outputs image_features & txt_features, so we create 2 state instances for them). A minimal sketch follows the list below.

  • The update method is called on real-compute timesteps and updates taylor_factors using the formula from the original implementation
  • The predict method is called to predict the feature from the current taylor_factors using the formula from the original implementation
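Here is a minimal sketch of that State design (my own illustration, not the PR's exact code; the finite-difference update and the Taylor expansion follow the original TaylorSeer formulation):

```python
import math
import torch


class TaylorSeerState:
    """Minimal sketch of the per-feature state; not the PR's exact implementation."""

    def __init__(self, max_order: int = 1, dtype: torch.dtype = torch.float16):
        self.max_order = max_order
        self.dtype = dtype
        self.taylor_factors = []      # taylor_factors[i] ~ i-th finite difference of the feature
        self.last_update_step = None  # step index of the last real compute

    def update(self, features: torch.Tensor, step: int) -> None:
        # Called on real compute timesteps: refresh the finite-difference factors.
        features = features.to(self.dtype)
        new_factors = [features]
        if self.last_update_step is not None:
            dt = step - self.last_update_step
            for i in range(min(self.max_order, len(self.taylor_factors))):
                new_factors.append((new_factors[i] - self.taylor_factors[i]) / dt)
        self.taylor_factors = new_factors
        self.last_update_step = step

    def predict(self, step: int) -> torch.Tensor:
        # Called on cached timesteps: Taylor-expand around the last real compute step.
        k = step - self.last_update_step
        out = torch.zeros_like(self.taylor_factors[0])
        for i, factor in enumerate(self.taylor_factors):
            out = out + factor * (k ** i) / math.factorial(i)
        return out
```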

@seed93

seed93 commented Nov 14, 2025

Will you adapt this great PR for Flux Kontext ControlNet or Flux ControlNet? It would be nice to have it implemented, and I am very eager to try it out.

@toilaluan
Author

@seed93 yes, I am prioritizing the Flux series and Qwen Image.

@toilaluan
Author

Here is an analysis of TaylorSeer for Flux.
Compared with the baseline, the output image is different, although the PAB method gives a fairly close result.
This result matches the author's implementation.

| model_id | cache_method | compute_dtype | compile | time | model_memory | model_max_memory_reserved | inference_memory | inference_max_memory_reserved |
|---|---|---|---|---|---|---|---|---|
| flux | none | fp16 | False | 22.318 | 33.313 | 33.322 | 33.322 | 34.305 |
| flux | pyramid_attention_broadcast | fp16 | False | 18.394 | 33.313 | 33.322 | 33.322 | 35.789 |
| flux | taylorseer_cache | fp16 | False | 6.457 | 33.313 | 33.322 | 33.322 | 38.18 |

Flux visual results

Baseline

image

Pyramid Attention Broadcast

image

TaylorSeer Cache (this implementation)

image

TaylorSeer Original (https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-Diffusers/taylorseer_flux/diffusers_taylorseer_flux.py)

image

Benchmark code is based on #10163

import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    MochiPipeline,
)
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/root/diffusers")
branch = repo.active_branch

from diffusers import (
    apply_taylorseer_cache, 
    TaylorSeerCacheConfig, 
    apply_faster_cache, 
    FasterCacheConfig, 
    apply_pyramid_attention_broadcast, 
    PyramidAttentionBroadcastConfig,
)

def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def prepare_flux_config(cache_method: str, pipe: FluxPipeline):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            current_timestep_callback=lambda: pipe.current_timestep,
        )
    elif cache_method == "taylorseer_cache":
        return TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float16, architecture="flux")
    elif cache_method == "fastercache":
        return FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 681),
        low_frequency_weight_update_timestep_range=(99, 641),
        high_frequency_weight_update_timestep_range=(-1, 301),
        spatial_attention_block_identifiers=["transformer_blocks"],
        attention_weight_callback=lambda _: 0.3,
        tensor_format="BFCHW",
    )
    elif cache_method == "none":
        return None


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    print(f"Generator: {generator}")
    print(f"Generation kwargs: {generation_kwargs}")
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method, pipe)
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe.transformer, config)
        elif cache_method == "fastercache":
            apply_faster_cache(pipe.transformer, config)
        elif cache_method == "taylorseer_cache":
            apply_taylorseer_cache(pipe.transformer, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "taylorseer_cache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype)
    

@toilaluan
Author

More comparisons between this implementation, the baseline, and the author's implementation

image

@toilaluan
Author

I think the current implementation is unified for every model that has attention modules, but to achieve full optimization we have to configure, via regex, which layers to cache and which to skip computing.
For example, in a sequence Linear1, Act1, Linear2, Act2: we need to add hooks so that Linear1, Act1, and Linear2 do nothing (return an empty tensor) while the output of Act2 is cached. A hypothetical config is sketched below.
I already fixed the template for Flux, but for other models users have to write their own and pass it to the config init.
@sayakpaul what do you think about this mechanism? I need some advice here.
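Something along these lines, as an illustration (the regexes are made up for the Linear1/Act1/Linear2/Act2 example above; the field names follow the skip_identifiers / cache_identifiers config that appears later in this thread):

```python
# Hypothetical illustration for a Linear1 -> Act1 -> Linear2 -> Act2 stack; the regexes are made up.
config = TaylorSeerCacheConfig(
    skip_identifiers=[r"blocks\.\d+\.(linear1|act1|linear2)$"],  # hooked to return a dummy output on cached steps
    cache_identifiers=[r"blocks\.\d+\.act2$"],                   # this layer's output is cached and predicted
    predict_steps=5,
    max_order=1,
    warmup_steps=3,
    taylor_factors_dtype=torch.float16,
)
apply_taylorseer_cache(model, config)
```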

@sayakpaul sayakpaul requested a review from DN6 November 14, 2025 17:41
@toilaluan
Author

toilaluan commented Nov 15, 2025

Tuning cache config really helps!

TaylorSeer cache configuration comparison

In the original code, they use 3 warmup steps and no cooldown. The output image differs significantly from the baseline, as shown in the report above.

As suggested in Shenyi-Z/TaylorSeer#12, increasing the warmup steps to 10 helps narrow the gap, but the cached output still has noticeable artifacts. This naturally suggested adding a cooldown phase (running the last steps without caching).

All runs below use the same prompt and 50 inference steps.

Visual comparison

Baseline vs. 3 warmup / 0 cooldown

| Baseline (no cache) | 3 warmup steps, 0 cooldown (cache) |
|---|---|
| Baseline output | 3 warmup, 0 cooldown output |

With only 3 warmup steps and 0 cooldown steps, the image content is not very close to the baseline.

10 warmup / 0 cooldown vs. 10 warmup / 5 cooldown

| 10 warmup steps, 0 cooldown (cache) | 10 warmup steps, 5 cooldown (cache) |
|---|---|
| 10 warmup, 0 cooldown output | 10 warmup, 5 cooldown output |

With 10 warmup steps, the content is closer to the baseline, but there are still many artifacts and noise.
By running the last 5 steps without caching (cooldown), most of these issues are resolved.
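For reference, the 10 warmup / 5 cooldown run above corresponds to the following configuration, taken from the reproduce code below (with 50 inference steps, stop_predicts=45 leaves the last 5 steps as full compute):

```python
config = TaylorSeerCacheConfig(
    predict_steps=5,
    max_order=1,
    warmup_steps=10,              # the first 10 steps run the real forward
    stop_predicts=45,             # stop predicting at step 45 -> 5 cooldown steps on a 50-step run
    taylor_factors_dtype=torch.float16,
    architecture="flux",
)
apply_taylorseer_cache(pipe.transformer, config)
```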


Hardware usage comparison

The table below shows the hardware usage comparison:

| cache_method | predict_steps | max_order | warmup_steps | stop_predicts | time (s) | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype |
|---|---|---|---|---|---|---|---|---|---|
| none | - | - | - | - | 22.781 | 33.313 | 33.321 | 37.943 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | - | 7.099 | 55.492 | 55.492 | 70.283 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | 45.0 | 9.024 | 55.490 | 55.490 | 70.283 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | - | 9.451 | 55.492 | 55.492 | 70.283 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | 45.0 | 11.000 | 55.490 | 55.490 | 70.283 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | - | 6.701 | 55.492 | 55.492 | 70.285 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | 45.0 | 8.651 | 55.490 | 55.490 | 70.285 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | - | 9.053 | 55.492 | 55.492 | 70.283 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | 45.0 | 11.001 | 55.490 | 55.490 | 70.283 | fp16 |
image

Code

import gc
import pathlib
import pandas as pd
import torch
from itertools import product

from diffusers import FluxPipeline
from diffusers.utils.logging import set_verbosity_info

from diffusers import apply_taylorseer_cache, TaylorSeerCacheConfig

def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
    generation_kwargs = {
        "prompt": prompt,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="pil", **generation_kwargs).images[0]
    torch.cuda.synchronize()
    return output

def main(output_dir: str):
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    compute_dtype = torch.float16
    taylor_factors_dtype = torch.float16

    param_grid = {
        'predict_steps': [5, 6],
        'max_order': [1],
        'warmup_steps': [3, 10],
        'stop_predicts': [None, 45]
    }
    combinations = list(product(*param_grid.values()))
    param_keys = list(param_grid.keys())

    results = []

    # Reset before each run
    def reset_cuda():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # Baseline (no cache)
    print("Running baseline...")
    reset_cuda()
    pipe, generation_kwargs = prepare_flux(compute_dtype)
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
    inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    image_filename = output_dir / "baseline.png"
    image.save(image_filename)
    print(f"Baseline image saved to {image_filename}")

    info = {
        'cache_method': 'none',
        'predict_steps': None,
        'max_order': None,
        'warmup_steps': None,
        'stop_predicts': None,
        'time': time,
        'model_memory_gb': model_memory,
        'inference_memory_gb': inference_memory,
        'max_memory_reserved_gb': max_memory_reserved,
        'compute_dtype': 'fp16'
    }
    results.append(info)

    # TaylorSeer cache configurations
    for combo in combinations:
        ps, mo, ws, sp = combo
        sp_str = 'None' if sp is None else str(sp)
        print(f"Running TaylorSeer with predict_steps={ps}, max_order={mo}, warmup_steps={ws}, stop_predicts={sp}...")
        reset_cuda()
        pipe, generation_kwargs = prepare_flux(compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        config = TaylorSeerCacheConfig(
            predict_steps=ps,
            max_order=mo,
            warmup_steps=ws,
            stop_predicts=sp,
            taylor_factors_dtype=taylor_factors_dtype,
            architecture="flux"
        )
        apply_taylorseer_cache(pipe.transformer, config)
        time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
        image_filename = output_dir / f"taylorseer_p{ps}_o{mo}_w{ws}_s{sp_str}.jpg"
        image.save(image_filename)
        print(f"TaylorSeer image saved to {image_filename}")

        info = {
            'cache_method': 'taylorseer_cache',
            'predict_steps': ps,
            'max_order': mo,
            'warmup_steps': ws,
            'stop_predicts': sp,
            'time': time,
            'model_memory_gb': model_memory,
            'inference_memory_gb': inference_memory,
            'max_memory_reserved_gb': max_memory_reserved,
            'compute_dtype': 'fp16'
        }
        results.append(info)

    # Save CSV
    df = pd.DataFrame(results)
    csv_path = output_dir / 'benchmark_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # Plot latency
    import matplotlib.pyplot as plt
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(20, 8))

    baseline_row = df[df['cache_method'] == 'none'].iloc[0]
    baseline_time = baseline_row['time']

    labels = ['baseline']
    times = [baseline_time]

    taylor_df = df[df['cache_method'] == 'taylorseer_cache']
    for _, row in taylor_df.iterrows():
        sp_str = 'None' if pd.isna(row['stop_predicts']) else str(int(row['stop_predicts']))
        label = f"p{row['predict_steps']}-o{row['max_order']}-w{row['warmup_steps']}-s{sp_str}"
        labels.append(label)
        times.append(row['time'])

    bars = ax.bar(labels, times)
    ax.set_xlabel('Configuration')
    ax.set_ylabel('Latency (s)')
    ax.set_title('Inference Latency: Baseline vs TaylorSeer Cache Configurations')
    ax.tick_params(axis='x', rotation=90)
    plt.tight_layout()

    plot_path = output_dir / 'latency_comparison.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save CSV, plot, and images.")
    args = parser.parse_args()

    set_verbosity_info()
    main(args.output_dir)

@toilaluan
Author

Similar behavior with Qwen Image

| cache_method | predict_steps | max_order | warmup_steps | stop_predicts | time | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype |
|---|---|---|---|---|---|---|---|---|---|
| none | | | | | 23.01 | 53.791 | 53.807 | 64.359 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | | 13.457 | 53.807 | 53.813 | 67.303 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | 45.0 | 14.562 | 53.813 | 53.819 | 67.303 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | | 14.775 | 53.819 | 53.825 | 67.303 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | 45.0 | 15.628 | 53.825 | 53.832 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | | 13.214 | 53.832 | 53.839 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | 45.0 | 14.349 | 53.838 | 53.845 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | | 14.595 | 53.844 | 53.851 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | 45.0 | 15.707 | 53.851 | 53.858 | 67.342 | fp16 |

@toilaluan
Author

toilaluan commented Nov 16, 2025

Flux Kontext – Cache vs Baseline Comparison

@seed93
I tested Flux Kontext using several cache configurations, and the results look promising. Below is a comparison of the baseline output and the cached versions.


Original Image

original image

Image Comparison (Side-by-Side)

| Baseline — “add a hat to the cat” | predict_steps=7, O=1, warmup=10, cooldown=5 | predict_steps=8 |
|---|---|---|

Processing Times (CSV → Markdown Table)

| cache_method     | predict_steps | max_order | warmup_steps | stop_predicts | time    | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype |
|------------------|---------------|-----------|--------------|---------------|---------|------------------|----------------------|--------------------------|---------------|
| none             |               |           |              |               | 48.391  | 31.438           | 31.446               | 36.209                   | fp16          |
| taylorseer_cache | 7.0           | 1.0       | 10.0         | 45.0          | 21.468  | 31.447           | 31.447               | 44.625                   | fp16          |
| taylorseer_cache | 8.0           | 1.0       | 10.0         | 45.0          | 20.633  | 31.447           | 31.447               | 44.625                   | fp16          |

Reproduce Code

import gc
import pathlib
import pandas as pd
import torch
from itertools import product

from diffusers import DiffusionPipeline
from diffusers.utils.logging import set_verbosity_info

from diffusers import apply_taylorseer_cache, TaylorSeerCacheConfig

def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    from diffusers.utils import load_image
    model_id = "black-forest-labs/FLUX.1-Kontext-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang, Ultra HD, 4K, cinematic composition."
    edit_prompt = "Add a hat to the cat"
    input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
    generation_kwargs = {
        "prompt": edit_prompt,
        "num_inference_steps": 50,
        "guidance_scale": 2.5,
        "image": input_image,
    }

    return pipe, generation_kwargs

def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="pil", **generation_kwargs).images[0]
    torch.cuda.synchronize()
    return output

def main(output_dir: str):
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    compute_dtype = torch.bfloat16
    taylor_factors_dtype = torch.bfloat16

    param_grid = {
        'predict_steps': [7, 8],
        'max_order': [1],
        'warmup_steps': [10],
        'stop_predicts': [45]
    }
    combinations = list(product(*param_grid.values()))
    param_keys = list(param_grid.keys())

    results = []

    # Reset before each run
    def reset_cuda():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # Baseline (no cache)
    print("Running baseline...")
    reset_cuda()
    pipe, generation_kwargs = prepare_flux(compute_dtype)
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
    inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    image_filename = output_dir / "baseline.png"
    image.save(image_filename)
    print(f"Baseline image saved to {image_filename}")

    info = {
        'cache_method': 'none',
        'predict_steps': None,
        'max_order': None,
        'warmup_steps': None,
        'stop_predicts': None,
        'time': time,
        'model_memory_gb': model_memory,
        'inference_memory_gb': inference_memory,
        'max_memory_reserved_gb': max_memory_reserved,
        'compute_dtype': 'fp16'
    }
    results.append(info)

    # TaylorSeer cache configurations
    for combo in combinations:
        ps, mo, ws, sp = combo
        sp_str = 'None' if sp is None else str(sp)
        print(f"Running TaylorSeer with predict_steps={ps}, max_order={mo}, warmup_steps={ws}, stop_predicts={sp}...")
        del pipe
        reset_cuda()
        pipe, generation_kwargs = prepare_flux(compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        config = TaylorSeerCacheConfig(
            predict_steps=ps,
            max_order=mo,
            warmup_steps=ws,
            stop_predicts=sp,
            taylor_factors_dtype=taylor_factors_dtype,
            architecture="flux"
        )
        apply_taylorseer_cache(pipe.transformer, config)
        time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
        image_filename = output_dir / f"taylorseer_p{ps}_o{mo}_w{ws}_s{sp_str}.jpg"
        image.save(image_filename)
        print(f"TaylorSeer image saved to {image_filename}")

        info = {
            'cache_method': 'taylorseer_cache',
            'predict_steps': ps,
            'max_order': mo,
            'warmup_steps': ws,
            'stop_predicts': sp,
            'time': time,
            'model_memory_gb': model_memory,
            'inference_memory_gb': inference_memory,
            'max_memory_reserved_gb': max_memory_reserved,
            'compute_dtype': 'fp16'
        }
        results.append(info)

    # Save CSV
    df = pd.DataFrame(results)
    csv_path = output_dir / 'benchmark_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # Plot latency
    import matplotlib.pyplot as plt
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(20, 8))

    baseline_row = df[df['cache_method'] == 'none'].iloc[0]
    baseline_time = baseline_row['time']

    labels = ['baseline']
    times = [baseline_time]

    taylor_df = df[df['cache_method'] == 'taylorseer_cache']
    for _, row in taylor_df.iterrows():
        sp_str = 'None' if pd.isna(row['stop_predicts']) else str(int(row['stop_predicts']))
        label = f"p{row['predict_steps']}-o{row['max_order']}-w{row['warmup_steps']}-s{sp_str}"
        labels.append(label)
        times.append(row['time'])

    bars = ax.bar(labels, times)
    ax.set_xlabel('Configuration')
    ax.set_ylabel('Latency (s)')
    ax.set_title('Inference Latency: Baseline vs TaylorSeer Cache Configurations')
    ax.tick_params(axis='x', rotation=90)
    plt.tight_layout()

    plot_path = output_dir / 'latency_comparison.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save CSV, plot, and images.")
    args = parser.parse_args()

    # set_verbosity_info()
    main(args.output_dir)

@toilaluan toilaluan marked this pull request as ready for review November 17, 2025 06:21
@seed93

seed93 commented Nov 17, 2025

Flux Kontext – Cache vs Baseline Comparison

@seed93 I tested Flux Kontext using several cache configurations, and the results look promising. […]

This is amazing!

@seed93

seed93 commented Nov 19, 2025

I am not sure why it uses so much GPU memory? I only have 24 GB of GPU memory.

@seed93

seed93 commented Nov 19, 2025

Could you please try using TaylorSeer-lite as an option? Refer to Shenyi-Z/TaylorSeer#5

@toilaluan
Author

@seed93 yeah, it seems not too complicated; I will try it and post a report here.

@toilaluan
Author

TaylorSeer-lite

@seed93, you can use the lite version with minimal extra memory by following this script, but it works for the Hunyuan model, not Flux.
The Flux TS-lite output is pure noise.

  • Hunyuan Output
image
  • Flux Output
image
import torch
from diffusers import FluxPipeline, HunyuanImagePipeline
from diffusers import TaylorSeerCacheConfig


model = "hunyuanimage"  # or "flux"
if model == "flux":
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
elif model == "hunyuanimage":
    pipeline = HunyuanImagePipeline.from_pretrained(
        "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
print(pipeline)

cache_config = TaylorSeerCacheConfig(
    skip_identifiers=[r"^(?!proj_out$)[^.]+\.[^.]+$"],
    cache_identifiers=[r"proj_out"],
    predict_steps=5,
    max_order=2,
    warmup_steps=10,
    stop_predicts=48,
    taylor_factors_dtype=torch.bfloat16,
)

pipeline.transformer.enable_cache(cache_config)

prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]

image.save("teddy_bear.jpg")

@toilaluan
Author

@DN6 This feature is ready for review, could you take a look 🙇

@toilaluan
Author

@DN6 thanks for the suggestions.
I committed new code, mostly about naming (which I think is more informative) and refactoring.
The new_forward hook should be more understandable; state.predict and state.update now handle both active and inactive modules in the cache.
I also commented above, explaining more about the skip and cache identifiers, which are now renamed to inactive_identifiers and active_identifiers, but I am still waiting for your thoughts on naming them. A rough sketch of the hook's dispatch logic is below.
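A hypothetical sketch of that dispatch (helper names such as get_state, is_compute_step, and current_step are placeholders of mine, not the PR's actual API):

```python
# Hypothetical sketch only: the helper names are placeholders, not the PR's exact code.
def new_forward(module, *args, **kwargs):
    state = get_state(module)                      # per-module TaylorSeer state
    if is_compute_step(state):                     # warmup, cooldown, or periodic refresh step
        output = module.original_forward(*args, **kwargs)
        state.update(output, current_step(state))  # refresh taylor_factors from the real output
        return output
    # Cached step: inactive modules skip their computation entirely,
    # while active modules return a Taylor-expansion prediction of their output.
    if state.is_inactive:
        return state.placeholder_output()
    return state.predict(current_step(state))
```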

Member

@sayakpaul sayakpaul left a comment

Very cool work thus far!

Given the popularity of TaylorSeer, do we want to add a mixin (like this) and add it to a few popular models like Qwen, Flux, Flux2, etc.?

@sayakpaul
Member

@bot /style

@github-actions
Contributor

github-actions bot commented Nov 28, 2025

Style bot fixed some files and pushed the changes.

@toilaluan
Author

@sayakpaul Yes, I added a simple one.
I noticed that Flux2Pipeline.__call__ is missing set_context:

noise_pred = self.transformer(

It will raise an error when getting a state since no context is set:

raise ValueError("No context is set. Please set a context before retrieving the state.")

I can hack around it by manually setting the context to a default, but I think this issue should be resolved in the pipeline.

@sayakpaul
Member

@toilaluan thanks! What do you think about removing the Flux2 test for now and we tackle it in a different PR?

@sayakpaul
Member

@bot /style

@github-actions
Contributor

github-actions bot commented Nov 30, 2025

Style bot fixed some files and pushed the changes.

@toilaluan
Author

toilaluan commented Dec 1, 2025

@sayakpaul I think the same problem exists for the pipeline variants (i2i, Kontext, inpainting), which also forget to implement context setting; I raised an issue:

#12760

Collaborator

@DN6 DN6 left a comment

Just minor comments @toilaluan. Should be good to merge once addressed.

@kodinkod

kodinkod commented Dec 3, 2025

Hi, doesn't this code already include TaylorSeer?


from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    # Basic DBCache w/ FnBn configurations
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,  # steps do not cache
        max_cached_steps=-1, # -1 means no limit
        Fn_compute_blocks=8, # Fn, F8, etc.
        Bn_compute_blocks=8, # Bn, B8, etc.
        residual_diff_threshold=0.12,
    ),
    # Then, you can use the TaylorSeer Calibrator to approximate 
    # the values in cached steps, taylorseer_order default is 1.
    calibrator_config=TaylorSeerCalibratorConfig(
        taylorseer_order=1,
    ),
)

@toilaluan
Author

@kodinkod yeah, but we're integrating it into the diffusers codebase.

@kodinkod

kodinkod commented Dec 3, 2025

@kodinkod yeah, but we're integrating it into the diffusers codebase.

it's already integrated. Do I understand correctly that the TaylorSeer Cache simply has a different block caching strategy than DBCache?

@toilaluan
Author

@kodinkod yeah, but we're integrating it into the diffusers codebase.

it's already integrated. Do I understand correctly that the TaylorSeer Cache simply has a different block caching strategy than DBCache?

I'm not sure where it is integrated, including where DBCache is in the diffusers codebase.
TaylorSeer looks similar to DBCache but uses a different approximation method (Taylor expansion).

@DN6
Collaborator

DN6 commented Dec 3, 2025

@toilaluan Just one more thing to consider is whether the cache is compatible with torch.compile. Conditional checks usually don't play well with torch.compile. We've disabled it for certain checks in previous caching approaches:

@torch.compiler.disable
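For reference, the pattern looks roughly like this (a minimal sketch; the predicate name and arguments are illustrative, not the exact code in diffusers):

```python
import torch

# Keep data-dependent cache bookkeeping out of the compiled graph.
@torch.compiler.disable
def should_predict(step: int, warmup_steps: int, predict_steps: int) -> bool:
    if step < warmup_steps:
        return False  # warmup: always run the real forward
    return (step - warmup_steps) % predict_steps != 0  # predict between periodic full computes
```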

@toilaluan
Author

@DN6, @sayakpaul I added a similar compiler disable to the TaylorSeer cache, but I have an observation that applies to both FBCache and TaylorSeer:

  • The graph will break at torch.compiler.disable, and then every block requires recompiling since the hook is applied at the block level.
    By default the recompile limit is 8, while the transformer has many more blocks (57 in Flux), so we have to set this limit higher than the number of blocks to achieve the best performance and graphs similar to regional compiling. Check the code below:
import torch
from diffusers import FluxPipeline, HunyuanImagePipeline
from diffusers import TaylorSeerCacheConfig, FirstBlockCacheConfig

# torch._logging.set_logs(graph_code=True)

import torch._dynamo as dynamo
dynamo.config.recompile_limit = 100

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
print(pipeline)

cache_config = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=50, # assume we will run full compute to see compile effect
    disable_cache_after_step=48,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True
)
fbconfig = FirstBlockCacheConfig(
    threshold=1e-6, # set this value to very small so cache will not be applied to see compile effect
)

pipeline.transformer.enable_cache(fbconfig) # or cache_config

pipeline.transformer.compile(fullgraph=False, dynamic=True)

prompt = "A laptop on top of a teddy bear, realistic, high quality, 4k"
# warmup
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]
# monitor this call
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]
image.save("teddy_bear.jpg")

@toilaluan
Author

Comparison: Baseline, Baseline-25steps, FBCache, TaylorSeer Cache

Memory & Speed Metrics

| Variant | Load Memory (GB) | Peak Memory (GB) | Processing Time (s) |
|---|---|---|---|
| Baseline | 31.437537 | 33.828313 | 22.241594 |
| Baseline-25steps | 31.447058 | 33.829167 | 11.025676 |
| FBCache | 31.447058 | 33.905339 | 14.023782 |
| TaylorSeer Cache | 31.447058 | 33.829656 | 8.929811 |

Visual Outputs

| Variant | Image |
|---|---|
| Baseline | Baseline |
| Baseline-25steps | Baseline 25 steps |
| FBCache | FBCache |
| TaylorSeer Cache | TaylorCache |

Code

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig, FirstBlockCacheConfig
import time
import os
import matplotlib.pyplot as plt
import pandas as pd
import gc  # Added for explicit garbage collection

# Set dynamo config
import torch._dynamo as dynamo
dynamo.config.recompile_limit = 100

prompts = [
    "A laptop on top of a teddy bear, realistic, high quality, 4k",
]

# Create output folder
os.makedirs("outputs", exist_ok=True)

# Define cache configs
fbconfig = FirstBlockCacheConfig(
    threshold=0.05
)

tsconfig = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=10,
    disable_cache_after_step=48,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True
)

# Collect results
results = []

for i, prompt in enumerate(prompts):
    images = {}
    for variant in ['baseline', 'baseline_reduce', 'firstblock', 'taylor']:
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()

        # Load pipeline
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")

        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)  # GB

        # Enable cache if applicable
        if variant == 'firstblock':
            pipeline.transformer.enable_cache(fbconfig)
        elif variant == 'taylor':
            pipeline.transformer.enable_cache(tsconfig)
        # No cache for baseline and baseline_reduce

        # Compile (uncomment if needed)
        # pipeline.transformer.compile(fullgraph=False, dynamic=True)

        # Warmup with 10 steps (uncomment if needed)
        # gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        # _ = pipeline(
        #     prompt=prompt,
        #     width=1024,
        #     height=1024,
        #     num_inference_steps=25,
        #     guidance_scale=3.0,
        #     generator=gen_warmup
        # ).images[0]

        # Main run
        steps = 25 if variant == 'baseline_reduce' else 50

        gen_main = torch.Generator(device="cuda").manual_seed(181201)

        torch.cuda.reset_peak_memory_stats()
        start_time = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_time = time.time()

        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB
        processing_time = end_time - start_time

        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image

        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Load Memory (GB)': load_mem_gb,
            'Peak Memory (GB)': peak_mem_gb,
            'Processing Time (s)': processing_time
        })

        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()  # Force garbage collection
        torch.cuda.empty_cache()  # Empty CUDA cache after GC
        dynamo.reset()  # Reset Dynamo cache (harmless even if not compiling)

    # Plot image comparison for this prompt
    fig, axs = plt.subplots(1, 4, figsize=(40, 10))
    variants_order = ['baseline', 'baseline_reduce', 'firstblock', 'taylor']
    for j, var in enumerate(variants_order):
        axs[j].imshow(images[var])
        axs[j].set_title(var)
        axs[j].axis('off')
    plt.tight_layout()
    plt.savefig(f"outputs/comparison_prompt{i}.png")
    plt.close()

# Print speed and memory comparison as a table
df = pd.DataFrame(results)
print("Speed and Memory Comparison:")
print(df.to_string(index=False))

# Optionally, plot bar charts for averages
avg_df = df.groupby('Variant').mean().reset_index()
fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.bar(avg_df['Variant'], avg_df['Processing Time (s)'], color='b', label='Time (s)')
ax1.set_ylabel('Processing Time (s)')
ax2 = ax1.twinx()
ax2.plot(avg_df['Variant'], avg_df['Peak Memory (GB)'], color='r', marker='o', label='Peak Memory (GB)')
ax2.set_ylabel('Peak Memory (GB)')
fig.suptitle('Average Speed and Memory Comparison')
fig.legend()
plt.savefig("outputs/metrics_comparison.png")
plt.close()
```

@sayakpaul
Member

The graph will break at torch.compiler.disable, and then every block requires recompiling since the hook is applied at the block level.
By default the recompile limit is 8, while the transformer has many more blocks (57 in Flux), so we have to set this limit higher than the number of blocks to achieve the best performance and graphs similar to regional compiling. Check the code below:

Yes, increasing the compile limit is fine here.

Some questions / notes:

  • In the code snippet provided in [Feat] TaylorSeer Cache #12648 (comment), why do we need dynamic=True?
  • Could we also add the compilation timing in here to see if that helps at all (especially with the recompilations)?
  • Let's try to add this comparison (just a link to your comment is fine) to the docs? I think this is golden information!

@sayakpaul
Member

@bot /style

@github-actions
Contributor

github-actions bot commented Dec 4, 2025

Style bot fixed some files and pushed the changes.
