Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions pyrtl/rtllib/pyrtlfloat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode
from .floatoperations import FloatOperations
from .floatwirevector import Float16WireVector
from .floatoperations import (
BFloat16Operations,
Float16Operations,
Float32Operations,
Float64Operations,
FloatOperations,
)

__all__ = [
"FloatingPointType",
"FPTypeProperties",
"PyrtlFloatConfig",
"RoundingMode",
"FloatOperations",
"Float16WireVector",
"BFloat16Operations",
"Float16Operations",
"Float32Operations",
"Float64Operations",
]
48 changes: 46 additions & 2 deletions pyrtl/rtllib/pyrtlfloat/floatoperations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from ._add_sub import AddSubHelper
from ._multiplication import MultiplicationHelper
from ._types import PyrtlFloatConfig, RoundingMode
from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode


class FloatOperations:
default_rounding_mode = RoundingMode.RNE

@staticmethod
def multiply(
def mul(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem very useful, since it just passes all its arguments through to MultiplicationHelper.multiply? Seems better to point users directly to MultiplicationHelper.multiply. Same comments for the other methods below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I have this FloatOperations class and separate MultiplicationHelper and AddSubHelper classes is because the user is supposed to use FloatOperations, MultiplicationHelper and AddSubHelper are internal helpers and the user is not supposed to use them. I created these helpers so I can separate the multiplication and addition/subtraction logic into separate files.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes sense, but I'd still try to refactor, it would be nice to remove these functions that are just pass-through wrappers for another function. Maybe:

  1. AddSubHelper.add could be an ordinary function instead of a @staticmethod
  2. _add_sub.py could be renamed to add_sub.py
  3. Non-public functions in add_sub.py could have an underscore prefix
  4. __init__.py could import add from add_sub

I think that would still make it clear which parts of the interface are meant for public use, while removing a layer of indirection?

This kind of indirection tends to be annoying when debugging problems, because you have to jump through another hoop to find the code you're looking for. "Our princess is in another castle!"

config: PyrtlFloatConfig,
operand_a: pyrtl.WireVector,
operand_b: pyrtl.WireVector,
Expand All @@ -31,3 +31,47 @@ def sub(
operand_b: pyrtl.WireVector,
) -> pyrtl.WireVector:
return AddSubHelper.sub(config, operand_a, operand_b)


class _BaseTypedFloatOperations:
_fp_type: FloatingPointType = None

@classmethod
def mul(
cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector
) -> pyrtl.WireVector:
return FloatOperations.mul(cls._get_config(), operand_a, operand_b)

@classmethod
def add(
cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector
) -> pyrtl.WireVector:
return FloatOperations.add(cls._get_config(), operand_a, operand_b)

@classmethod
def sub(
cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector
) -> pyrtl.WireVector:
return FloatOperations.sub(cls._get_config(), operand_a, operand_b)

@classmethod
def _get_config(cls) -> PyrtlFloatConfig:
return PyrtlFloatConfig(
cls._fp_type.value, FloatOperations.default_rounding_mode
)


class BFloat16Operations(_BaseTypedFloatOperations):
_fp_type = FloatingPointType.BFLOAT16


class Float16Operations(_BaseTypedFloatOperations):
_fp_type = FloatingPointType.FLOAT16


class Float32Operations(_BaseTypedFloatOperations):
_fp_type = FloatingPointType.FLOAT32


class Float64Operations(_BaseTypedFloatOperations):
_fp_type = FloatingPointType.FLOAT64
46 changes: 0 additions & 46 deletions pyrtl/rtllib/pyrtlfloat/floatwirevector.py

This file was deleted.

10 changes: 3 additions & 7 deletions tests/rtllib/pyrtlfloat/test_add_sub.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
import unittest

import pyrtl
from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode
from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode


class TestMultiplication(unittest.TestCase):
def setUp(self):
pyrtl.reset_working_block()
a = pyrtl.Input(bitwidth=16, name="a")
b = pyrtl.Input(bitwidth=16, name="b")
a_floatwv = Float16WireVector()
a_floatwv <<= a
b_floatwv = Float16WireVector()
b_floatwv <<= b
FloatOperations.default_rounding_mode = RoundingMode.RNE
result_add = pyrtl.Output(name="result_add")
result_add <<= a_floatwv + b_floatwv
result_add <<= Float16Operations.add(a, b)
result_sub = pyrtl.Output(name="result_sub")
result_sub <<= a_floatwv - b_floatwv
result_sub <<= Float16Operations.sub(a, b)
self.sim = pyrtl.Simulation()

def test_multiplication_simple(self):
Expand Down
10 changes: 3 additions & 7 deletions tests/rtllib/pyrtlfloat/test_multiplication.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
import unittest

import pyrtl
from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode
from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode


class TestMultiplication(unittest.TestCase):
def setUp(self):
pyrtl.reset_working_block()
a = pyrtl.Input(bitwidth=16, name="a")
b = pyrtl.Input(bitwidth=16, name="b")
a_floatwv = Float16WireVector()
a_floatwv <<= a
b_floatwv = Float16WireVector()
b_floatwv <<= b
FloatOperations.default_rounding_mode = RoundingMode.RNE
result_rne = pyrtl.Output(name="result_rne")
result_rne <<= a_floatwv * b_floatwv
result_rne <<= Float16Operations.mul(a, b)
FloatOperations.default_rounding_mode = RoundingMode.RTZ
result_rtz = pyrtl.Output(name="result_rtz")
result_rtz <<= a_floatwv * b_floatwv
result_rtz <<= Float16Operations.mul(a, b)
self.sim = pyrtl.Simulation()

def test_multiplication_simple(self):
Expand Down