Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions examples/python/CuTeDSL/cute/print_latex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import cutlass
import cutlass.cute as cute
from cutlass.utils import print_latex, print_latex_tv
from cutlass import for_generate, yield_out

"""
A Latex Printing Example using CuTe DSL.

This example prints latex for a given layout or thread value layout.

The primary goal for this example is to demonstrate how to dump latex, which can then be
turned into an image in your favorite latex compiler.

To run this example:

.. code-block:: bash

python examples/python/CuteDSL/cute/print_latex.py
python examples/python/CuteDSL/cute/print_latex.py --tv_layout

To compile, pipe the output to a file and use a tool like pdflatex:

.. code-block:: bash
python examples/python/CuTeDSL/cute/print_latex.py > latex.tex
pdflatex latex.tex
"""


@cute.jit
def main(print_tv_layout: cutlass.Constexpr[bool]):
# Note: only support compile time printing layouts
if cutlass.const_expr(print_tv_layout):
thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0))
val_layout = cute.make_ordered_layout((4, 1), order=(1, 0))
tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout)
print_latex_tv(tv_layout, tiler_mn)
else:
layout = cute.make_layout((10, 10))
print_latex(layout)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="example of print latex and print latex tv"
)
parser.add_argument("--tv_layout", action="store_true")

args = parser.parse_args()

main(args.tv_layout)
4 changes: 4 additions & 0 deletions python/CuTeDSL/cutlass/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
sm_wise_inter_gpu_multimem_barrier,
)

from .print_latex import print_latex, print_latex_tv

__all__ = [
"get_smem_capacity_in_bytes",
"SmemAllocator",
Expand All @@ -90,4 +92,6 @@
"create_initial_search_state",
"GroupedGemmTileSchedulerHelper",
"HardwareInfo",
"print_latex",
"print_latex_tv",
]
169 changes: 169 additions & 0 deletions python/CuTeDSL/cutlass/utils/print_latex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from typing import Callable, Union

from ..cute import (
Layout,
ComposedLayout,
append,
is_static,
make_layout,
size,
product_each,
rank,
)
from ..cute.typing import IntTuple

__all__ = ["print_latex", "print_latex_tv"]


def tikz_color_bwx8(idx: int):
color_map = [
"black!00",
"black!40",
"black!20",
"black!60",
"black!10",
"black!50",
"black!30",
"black!70",
]
return color_map[idx % 8]


def tikz_color_white(idx: int):
return "white"


def tikz_color_tv(tid: int, vid: int):
color_map = [
"{rgb,255:red,175;green,175;blue,255}",
"{rgb,255:red,175;green,255;blue,175}",
"{rgb,255:red,255;green,255;blue,175}",
"{rgb,255:red,255;green,175;blue,175}",
"{rgb,255:red,210;green,210;blue,255}",
"{rgb,255:red,210;green,255;blue,210}",
"{rgb,255:red,255;green,255;blue,210}",
"{rgb,255:red,255;green,210;blue,210}",
]
return color_map[tid % 8]


def print_latex(x: Union[Layout, ComposedLayout], *, color: Callable = tikz_color_bwx8):
"""
Prints a layout.

:param x: A layout
:type x: Union[Layout, ComposedLayout]
:param color: A function that returns TiKZ colors
:type color: Callable

"""

if not is_static(x):
raise ValueError("Requires static input")
if rank(x) > 2:
raise ValueError("Requires rank <= 2 to print")

if rank(x) == 1:
layout = append(x, make_layout(1, stride=0))
else:
layout = x

print("%% Layout: {}", layout)
print("\\documentclass[convert]{standalone}")
print("\\usepackage{tikz}")
print("\\begin{document}")
print(
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]"
)

M, N = product_each(x.shape)

for m in range(M):
for n in range(N):
idx = layout((m, n))
print("\\node[fill=")
print(color(idx))
print("] at (%d,%d) {%d};\n" % (m, n, idx))
print(
"\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N)
)
for m in range(M):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
for n in range(N):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))

## Footer
print("\\end{tikzpicture}")
print("\\end{document}")


def print_latex_tv(
layout_tv: Union[Layout, ComposedLayout],
tile_mn: Union[IntTuple, Layout],
*,
color: Callable = tikz_color_tv,
):
"""
Prints a tv layout for a tile M N. Everything must be static.

:param layout_tv: A static thread value layout
:type layout_tv: Union[Layout, ComposedLayout]
:param tile_mn: A static M N tile
:type tile_mn: Union[IntTuple, Layout]
:param color: A function that returns TiKZ colors
:type color: Callable

"""
if not is_static(layout_tv) or not is_static(tile_mn):
raise ValueError("Layout tv and tile_mn must be static")
if rank(layout_tv) != 2:
raise ValueError("Require layout_tv to be rank 2")

print("%% Layout TV: {}", layout_tv)
print("\\documentclass[convert]{standalone}")
print("\\usepackage{tikz}")
print("\\begin{document}")
print(
"\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n"
)

if not isinstance(tile_mn, Layout):
tile_mn = make_layout(tile_mn)

M, N = product_each(tile_mn.shape)
filled = [[False for n in range(N)] for m in range(M)]

for tid in range(size(layout_tv, mode=[0])):
for vid in range(size(layout_tv, mode=[1])):
idx = layout_tv((tid, vid))
m = (idx // tile_mn.stride[0]) % tile_mn.shape[0]
n = (idx // tile_mn.stride[1]) % tile_mn.shape[1]
if not filled[m][n]:
filled[m][n] = True
print(
"\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n"
% (color(tid, vid), m, n, tid, vid)
)

print(
"\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n" % (M, N)
)
for m in range(M):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (m, -1, m))
for n in range(N):
print("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n" % (-1, n, n))

## Footer
print("\\end{tikzpicture}")
print("\\end{document}")