From 94c4c68531c3eb29de726407ba9da6b8bab52b1a Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Wed, 29 Oct 2025 22:30:52 -0700 Subject: [PATCH 1/2] Pyrtl floating point library --- pyrtl/rtllib/pyrtlfloat/__init__.py | 12 + pyrtl/rtllib/pyrtlfloat/_add_sub.py | 342 ++++++++++++++++++ pyrtl/rtllib/pyrtlfloat/_float_utills.py | 104 ++++++ pyrtl/rtllib/pyrtlfloat/_multiplication.py | 161 +++++++++ pyrtl/rtllib/pyrtlfloat/_types.py | 30 ++ pyrtl/rtllib/pyrtlfloat/floatoperations.py | 33 ++ pyrtl/rtllib/pyrtlfloat/floatwirevector.py | 46 +++ tests/rtllib/pyrtlfloat/test_add_sub.py | 30 ++ .../rtllib/pyrtlfloat/test_multiplication.py | 31 ++ 9 files changed, 789 insertions(+) create mode 100644 pyrtl/rtllib/pyrtlfloat/__init__.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_add_sub.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_float_utills.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_multiplication.py create mode 100644 pyrtl/rtllib/pyrtlfloat/_types.py create mode 100644 pyrtl/rtllib/pyrtlfloat/floatoperations.py create mode 100644 pyrtl/rtllib/pyrtlfloat/floatwirevector.py create mode 100644 tests/rtllib/pyrtlfloat/test_add_sub.py create mode 100644 tests/rtllib/pyrtlfloat/test_multiplication.py diff --git a/pyrtl/rtllib/pyrtlfloat/__init__.py b/pyrtl/rtllib/pyrtlfloat/__init__.py new file mode 100644 index 00000000..df407d93 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/__init__.py @@ -0,0 +1,12 @@ +from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode +from .floatoperations import FloatOperations +from .floatwirevector import Float16WireVector + +__all__ = [ + "FloatingPointType", + "FPTypeProperties", + "PyrtlFloatConfig", + "RoundingMode", + "FloatOperations", + "Float16WireVector", +] diff --git a/pyrtl/rtllib/pyrtlfloat/_add_sub.py b/pyrtl/rtllib/pyrtlfloat/_add_sub.py new file mode 100644 index 00000000..419d647f --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_add_sub.py @@ -0,0 +1,342 @@ +import pyrtl + +from ._float_utills import FloatUtils +from ._types import PyrtlFloatConfig, RoundingMode + + +class AddSubHelper: + @staticmethod + def add( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + total_bits = num_exp_bits + num_mant_bits + 1 + + operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) + operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) + + # operand_smaller is the operand with the smaller absolute value and + # operand_larger is the operand with the larger absolute value + operand_smaller = pyrtl.WireVector(bitwidth=total_bits) + operand_larger = pyrtl.WireVector(bitwidth=total_bits) + + with pyrtl.conditional_assignment: + exponent_and_mantissa_len = num_mant_bits + num_exp_bits + with ( + operand_a_daz[:exponent_and_mantissa_len] + < operand_b_daz[:exponent_and_mantissa_len] + ): + operand_smaller |= operand_a_daz + operand_larger |= operand_b_daz + with pyrtl.otherwise: + operand_smaller |= operand_b_daz + operand_larger |= operand_a_daz + + smaller_operand_sign = FloatUtils.get_sign(fp_type_props, operand_smaller) + larger_operand_sign = FloatUtils.get_sign(fp_type_props, operand_larger) + smaller_operand_exponent = FloatUtils.get_exponent( + fp_type_props, operand_smaller + ) + larger_operand_exponent = FloatUtils.get_exponent(fp_type_props, operand_larger) + smaller_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_smaller) + ) + larger_operand_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_larger) + ) + + exponent_diff = larger_operand_exponent - smaller_operand_exponent + smaller_mantissa_shifted = pyrtl.shift_right_logical( + smaller_operand_mantissa, exponent_diff + ) + grs = pyrtl.WireVector(bitwidth=3) # guard, round, sticky bits for rounding + with pyrtl.conditional_assignment: + with exponent_diff >= 2: + guard_and_round = pyrtl.shift_right_logical( + smaller_operand_mantissa, exponent_diff - 2 + )[:2] + mask = ( + pyrtl.shift_left_logical( + pyrtl.Const(1, bitwidth=num_mant_bits), exponent_diff - 2 + ) + - 1 + ) + sticky = (smaller_operand_mantissa & mask) != 0 + grs |= pyrtl.concat(guard_and_round, sticky) + with exponent_diff == 1: + grs |= pyrtl.concat( + smaller_operand_mantissa[0], pyrtl.Const(0, bitwidth=2) + ) + with pyrtl.otherwise: + grs |= 0 + smaller_mantissa_shifted_grs = pyrtl.concat(smaller_mantissa_shifted, grs) + larger_mantissa_extended = pyrtl.concat( + larger_operand_mantissa, pyrtl.Const(0, bitwidth=3) + ) + + sum_exponent, sum_mantissa, sum_grs, sum_carry = AddSubHelper._add_operands( + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + + sub_exponent, sub_mantissa, sub_grs, num_leading_zeros = ( + AddSubHelper._sub_operands( + num_mant_bits, + larger_operand_exponent, + smaller_mantissa_shifted_grs, + larger_mantissa_extended, + ) + ) + + # WireVectors for the raw addition or subtraction result, before handling + # special cases + raw_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + raw_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + if rounding_mode == RoundingMode.RNE: + raw_result_grs = pyrtl.WireVector(bitwidth=3) + + with pyrtl.conditional_assignment: + with smaller_operand_sign == larger_operand_sign: # add + raw_result_exponent |= sum_exponent + raw_result_mantissa |= sum_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sum_grs + with pyrtl.otherwise: # sub + raw_result_exponent |= sub_exponent + raw_result_mantissa |= sub_mantissa + if rounding_mode == RoundingMode.RNE: + raw_result_grs |= sub_grs + + if rounding_mode == RoundingMode.RNE: + ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) = AddSubHelper._round( + num_mant_bits, + num_exp_bits, + raw_result_exponent, + raw_result_mantissa, + raw_result_grs, + ) + + smaller_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_smaller) + larger_operand_nan = FloatUtils.is_NaN(fp_type_props, operand_larger) + smaller_operand_inf = FloatUtils.is_inf(fp_type_props, operand_smaller) + larger_operand_inf = FloatUtils.is_inf(fp_type_props, operand_larger) + smaller_operand_zero = FloatUtils.is_zero(fp_type_props, operand_smaller) + larger_operand_zero = FloatUtils.is_zero(fp_type_props, operand_larger) + + # WireVectors for the final result after handling special cases + final_result_sign = pyrtl.WireVector(bitwidth=1) + final_result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + final_result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + + # handle special cases + with pyrtl.conditional_assignment: + # if either operand is NaN or both operands are infinity of opposite signs, + # the result is NaN + with ( + smaller_operand_nan + | larger_operand_nan + | ( + smaller_operand_inf + & larger_operand_inf + & (larger_operand_sign != smaller_operand_sign) + ) + ): + final_result_sign |= larger_operand_sign + FloatUtils.make_output_NaN( + fp_type_props, final_result_exponent, final_result_mantissa + ) + # infinities + with smaller_operand_inf: + final_result_sign |= larger_operand_sign + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + with larger_operand_inf: + final_result_sign |= larger_operand_sign + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + # +num + -num = +0 + with ( + (smaller_operand_mantissa == larger_operand_mantissa) + & (smaller_operand_exponent == larger_operand_exponent) + & (larger_operand_sign != smaller_operand_sign) + ): + final_result_sign |= 0 + FloatUtils.make_output_zero( + final_result_exponent, final_result_mantissa + ) + with smaller_operand_zero: + final_result_sign |= larger_operand_sign + final_result_mantissa |= larger_operand_mantissa + final_result_exponent |= larger_operand_exponent + with larger_operand_zero: + final_result_sign |= smaller_operand_sign + final_result_mantissa |= smaller_operand_mantissa + final_result_exponent |= smaller_operand_exponent + # overflow and underflow + initial_larger_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2) + if rounding_mode == RoundingMode.RNE: + larger_exponent_max_value = ( + initial_larger_exponent_max_value + - sum_carry + - rounding_exponent_incremented + ) + else: + larger_exponent_max_value = ( + initial_larger_exponent_max_value - sum_carry + ) + initial_larger_exponent_min_value = pyrtl.Const(1) + if rounding_mode == RoundingMode.RNE: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + + num_leading_zeros + - rounding_exponent_incremented + ) + else: + larger_exponent_min_value = ( + initial_larger_exponent_min_value + num_leading_zeros + ) + with (smaller_operand_sign == larger_operand_sign) & ( + larger_operand_exponent > larger_exponent_max_value + ): # detect overflow on addition + final_result_sign |= larger_operand_sign + if rounding_mode == RoundingMode.RNE: + FloatUtils.make_output_inf( + fp_type_props, final_result_exponent, final_result_mantissa + ) + else: + FloatUtils.make_output_largest_finite_number( + fp_type_props, final_result_exponent, final_result_mantissa + ) + with (smaller_operand_sign != larger_operand_sign) & ( + larger_operand_exponent < larger_exponent_min_value + ): # detect underflow on subtraction + final_result_sign |= larger_operand_sign + FloatUtils.make_output_zero( + final_result_exponent, final_result_mantissa + ) + with pyrtl.otherwise: + final_result_sign |= larger_operand_sign + if rounding_mode == RoundingMode.RNE: + final_result_exponent |= raw_result_rounded_exponent + final_result_mantissa |= raw_result_rounded_mantissa + else: + final_result_exponent |= raw_result_exponent + final_result_mantissa |= raw_result_mantissa + + return pyrtl.concat( + final_result_sign, final_result_exponent, final_result_mantissa + ) + + @staticmethod + def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + num_exp_bits = config.fp_type_properties.num_exponent_bits + num_mant_bits = config.fp_type_properties.num_mantissa_bits + operand_b_negated = operand_b ^ pyrtl.concat( + pyrtl.Const(1, bitwidth=1), + pyrtl.Const(0, bitwidth=num_exp_bits + num_mant_bits), + ) + return AddSubHelper.add(config, operand_a, operand_b_negated) + + @staticmethod + def _add_operands( + larger_operand_exponent: pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + sum_mantissa_grs = pyrtl.WireVector() + sum_mantissa_grs <<= larger_mantissa_extended + smaller_mantissa_shifted_grs + sum_carry = sum_mantissa_grs[-1] + sum_mantissa = pyrtl.select( + sum_carry, sum_mantissa_grs[4:], sum_mantissa_grs[3:-1] + ) + sum_grs = pyrtl.select( + sum_carry, + pyrtl.concat(sum_mantissa_grs[2:4], sum_mantissa_grs[:2] != 0), + sum_mantissa_grs[:3], + ) + sum_exponent = pyrtl.select( + sum_carry, larger_operand_exponent + 1, larger_operand_exponent + ) + return sum_exponent, sum_mantissa, sum_grs, sum_carry + + @staticmethod + def _sub_operands( + num_mant_bits: int, + larger_operand_exponent: pyrtl.WireVector, + smaller_mantissa_shifted_grs: pyrtl.WireVector, + larger_mantissa_extended: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector, pyrtl.WireVector]: + def leading_zero_priority_encoder(wire: pyrtl.WireVector, length: int): + out = pyrtl.WireVector( + bitwidth=pyrtl.infer_val_and_bitwidth(length - 1).bitwidth + ) + with pyrtl.conditional_assignment: + for i in range(wire.bitwidth - 1, wire.bitwidth - length - 1, -1): + with wire[i]: + out |= wire.bitwidth - i - 1 + return out + + sub_mantissa_grs = pyrtl.WireVector(bitwidth=num_mant_bits + 4) + sub_mantissa_grs <<= larger_mantissa_extended - smaller_mantissa_shifted_grs + num_leading_zeros = leading_zero_priority_encoder( + sub_mantissa_grs, num_mant_bits + 1 + ) + sub_mantissa_grs_shifted = pyrtl.shift_left_logical( + sub_mantissa_grs, num_leading_zeros + ) + sub_mantissa = sub_mantissa_grs_shifted[3:] + sub_grs = sub_mantissa_grs_shifted[:3] + sub_exponent = larger_operand_exponent - num_leading_zeros + return sub_exponent, sub_mantissa, sub_grs, num_leading_zeros + + @staticmethod + def _round( + num_mant_bits: int, + num_exp_bits: int, + raw_result_exponent: pyrtl.WireVector, + raw_result_mantissa: pyrtl.WireVector, + raw_result_grs: pyrtl.WireVector, + ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: + last = raw_result_mantissa[0] + guard = raw_result_grs[2] + round = raw_result_grs[1] + sticky = raw_result_grs[0] + round_up = guard & (last | round | sticky) + raw_result_rounded_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + raw_result_rounded_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + rounding_exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with round_up: + with raw_result_mantissa == (1 << num_mant_bits) - 1: + raw_result_rounded_mantissa |= 0 + raw_result_rounded_exponent |= raw_result_exponent + 1 + rounding_exponent_incremented |= 1 + with pyrtl.otherwise: + raw_result_rounded_mantissa |= raw_result_mantissa + 1 + raw_result_rounded_exponent |= raw_result_exponent + rounding_exponent_incremented |= 0 + with pyrtl.otherwise: + raw_result_rounded_mantissa |= raw_result_mantissa + raw_result_rounded_exponent |= raw_result_exponent + rounding_exponent_incremented |= 0 + return ( + raw_result_rounded_exponent, + raw_result_rounded_mantissa, + rounding_exponent_incremented, + ) diff --git a/pyrtl/rtllib/pyrtlfloat/_float_utills.py b/pyrtl/rtllib/pyrtlfloat/_float_utills.py new file mode 100644 index 00000000..0ae58329 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_float_utills.py @@ -0,0 +1,104 @@ +import pyrtl + +from ._types import FPTypeProperties + + +class FloatUtils: + @staticmethod + def get_sign(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return wire[fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits] + + @staticmethod + def get_exponent( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return wire[ + fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits + + fp_prop.num_exponent_bits + ] + + @staticmethod + def get_mantissa( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return wire[: fp_prop.num_mantissa_bits] + + @staticmethod + def is_zero(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( + FloatUtils.get_exponent(fp_prop, wire) == 0 + ) + + @staticmethod + def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) == 0) & ( + FloatUtils.get_exponent(fp_prop, wire) + == (1 << fp_prop.num_exponent_bits) - 1 + ) + + @staticmethod + def is_denormalized( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( + FloatUtils.get_exponent(fp_prop, wire) == 0 + ) + + @staticmethod + def is_NaN(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: + return (FloatUtils.get_mantissa(fp_prop, wire) != 0) & ( + FloatUtils.get_exponent(fp_prop, wire) + == (1 << fp_prop.num_exponent_bits) - 1 + ) + + @staticmethod + def make_denormals_zero( + fp_prop: FPTypeProperties, wire: pyrtl.WireVector + ) -> pyrtl.WireVector: + out = pyrtl.WireVector( + bitwidth=fp_prop.num_mantissa_bits + fp_prop.num_exponent_bits + 1 + ) + with pyrtl.conditional_assignment: + with FloatUtils.get_exponent(fp_prop, wire) == 0: + out |= pyrtl.concat( + FloatUtils.get_sign(fp_prop, wire), + FloatUtils.get_exponent(fp_prop, wire), + pyrtl.Const(0, bitwidth=fp_prop.num_mantissa_bits), + ) + with pyrtl.otherwise: + out |= wire + return out + + @staticmethod + def make_output_inf( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 0 + + @staticmethod + def make_output_NaN( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 1 + mantissa |= 1 << (fp_prop.num_mantissa_bits - 1) + + @staticmethod + def make_output_zero( + exponent: pyrtl.WireVector, mantissa: pyrtl.WireVector + ) -> None: + exponent |= 0 + mantissa |= 0 + + @staticmethod + def make_output_largest_finite_number( + fp_prop: FPTypeProperties, + exponent: pyrtl.WireVector, + mantissa: pyrtl.WireVector, + ) -> None: + exponent |= (1 << fp_prop.num_exponent_bits) - 2 + mantissa |= (1 << fp_prop.num_mantissa_bits) - 1 diff --git a/pyrtl/rtllib/pyrtlfloat/_multiplication.py b/pyrtl/rtllib/pyrtlfloat/_multiplication.py new file mode 100644 index 00000000..a0de7d25 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_multiplication.py @@ -0,0 +1,161 @@ +import pyrtl + +from ._float_utills import FloatUtils +from ._types import PyrtlFloatConfig, RoundingMode + + +class MultiplicationHelper: + @staticmethod + def multiply( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + fp_type_props = config.fp_type_properties + rounding_mode = config.rounding_mode + num_exp_bits = fp_type_props.num_exponent_bits + num_mant_bits = fp_type_props.num_mantissa_bits + + operand_a_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_a) + operand_b_daz = FloatUtils.make_denormals_zero(fp_type_props, operand_b) + a_sign = FloatUtils.get_sign(fp_type_props, operand_a_daz) + b_sign = FloatUtils.get_sign(fp_type_props, operand_b_daz) + a_exponent = FloatUtils.get_exponent(fp_type_props, operand_a_daz) + b_exponent = FloatUtils.get_exponent(fp_type_props, operand_b_daz) + + exponent_bias = 2 ** (fp_type_props.num_exponent_bits - 1) - 1 + + result_sign = a_sign ^ b_sign + operand_exponent_sums = a_exponent + b_exponent + product_exponent = operand_exponent_sums - pyrtl.Const(exponent_bias) + + a_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_a_daz) + ) + b_mantissa = pyrtl.concat( + pyrtl.Const(1), FloatUtils.get_mantissa(fp_type_props, operand_b_daz) + ) + product_mantissa = a_mantissa * b_mantissa + + normalized_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + normalized_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + need_to_normalize = product_mantissa[-1] + + if rounding_mode == RoundingMode.RNE: + guard = pyrtl.WireVector(bitwidth=1) + sticky = pyrtl.WireVector(bitwidth=1) + last = pyrtl.WireVector(bitwidth=1) + + with pyrtl.conditional_assignment: + with need_to_normalize: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 1 :] + normalized_product_exponent |= product_exponent + 1 + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 2] + sticky |= product_mantissa[: -num_mant_bits - 2] != 0 + last |= product_mantissa[-num_mant_bits - 1] + with pyrtl.otherwise: + normalized_product_mantissa |= product_mantissa[-num_mant_bits - 2 : -1] + normalized_product_exponent |= product_exponent + if rounding_mode == RoundingMode.RNE: + guard |= product_mantissa[-num_mant_bits - 3] + sticky |= product_mantissa[: -num_mant_bits - 3] != 0 + last |= product_mantissa[-num_mant_bits - 2] + + if rounding_mode == RoundingMode.RNE: + rounded_product_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + rounded_product_exponent = pyrtl.WireVector(bitwidth=num_exp_bits + 1) + exponent_incremented = pyrtl.WireVector(bitwidth=1) + with pyrtl.conditional_assignment: + with guard & (last | sticky): + with normalized_product_mantissa == (1 << num_mant_bits) - 1: + rounded_product_mantissa |= 0 + rounded_product_exponent |= normalized_product_exponent + 1 + exponent_incremented |= 1 + with pyrtl.otherwise: + rounded_product_mantissa |= normalized_product_mantissa + 1 + rounded_product_exponent |= normalized_product_exponent + exponent_incremented |= 0 + with pyrtl.otherwise: + rounded_product_mantissa |= normalized_product_mantissa + rounded_product_exponent |= normalized_product_exponent + exponent_incremented |= 0 + + result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) + result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) + + operand_a_nan = FloatUtils.is_NaN(fp_type_props, operand_a_daz) + operand_b_nan = FloatUtils.is_NaN(fp_type_props, operand_b_daz) + operand_a_inf = FloatUtils.is_inf(fp_type_props, operand_a_daz) + operand_b_inf = FloatUtils.is_inf(fp_type_props, operand_b_daz) + operand_a_zero = FloatUtils.is_zero(fp_type_props, operand_a_daz) + operand_b_zero = FloatUtils.is_zero(fp_type_props, operand_b_daz) + operand_a_denormalized = FloatUtils.is_denormalized( + fp_type_props, operand_a_daz + ) + operand_b_denormalized = FloatUtils.is_denormalized( + fp_type_props, operand_b_daz + ) + + # Overflow and underflow checks (only for normal cases) + sum_exponent_max_value = pyrtl.Const(2**num_exp_bits - 2 + exponent_bias) + sum_exponent_min_value = pyrtl.Const(1 + exponent_bias) + if rounding_mode == RoundingMode.RNE: + exponent_max_value = ( + sum_exponent_max_value - need_to_normalize - exponent_incremented + ) + exponent_min_value = ( + sum_exponent_min_value - need_to_normalize - exponent_incremented + ) + else: + exponent_max_value = sum_exponent_max_value - need_to_normalize + exponent_min_value = sum_exponent_min_value - need_to_normalize + + if rounding_mode == RoundingMode.RNE: + raw_result_exponent = rounded_product_exponent[0:num_exp_bits] + raw_result_mantissa = rounded_product_mantissa + else: + raw_result_exponent = normalized_product_exponent[0:num_exp_bits] + raw_result_mantissa = normalized_product_mantissa + + with pyrtl.conditional_assignment: + # nan + with ( + operand_a_nan + | operand_b_nan + | (operand_a_inf & operand_b_zero) + | (operand_a_zero & operand_b_inf) + ): + FloatUtils.make_output_NaN( + fp_type_props, result_exponent, result_mantissa + ) + # infinity + with operand_a_inf | operand_b_inf: + FloatUtils.make_output_inf( + fp_type_props, result_exponent, result_mantissa + ) + # overflow + with operand_exponent_sums > exponent_max_value: + if rounding_mode == RoundingMode.RNE: + FloatUtils.make_output_inf( + fp_type_props, result_exponent, result_mantissa + ) + else: + FloatUtils.make_output_largest_finite_number( + fp_type_props, result_exponent, result_mantissa + ) + # zero or underflow + with ( + operand_a_zero + | operand_b_zero + | (operand_exponent_sums < exponent_min_value) + | operand_a_denormalized + | operand_b_denormalized + ): + FloatUtils.make_output_zero(result_exponent, result_mantissa) + with pyrtl.otherwise: + result_exponent |= raw_result_exponent + result_mantissa |= raw_result_mantissa + + return pyrtl.concat(result_sign, result_exponent, result_mantissa) diff --git a/pyrtl/rtllib/pyrtlfloat/_types.py b/pyrtl/rtllib/pyrtlfloat/_types.py new file mode 100644 index 00000000..15a1c811 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/_types.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from enum import Enum + + +class RoundingMode(Enum): + RTZ = 1 # round towards zero (truncate) + RNE = 2 # round to nearest, ties to even (default mode) + + +@dataclass(frozen=True) +class FPTypeProperties: + num_exponent_bits: int + num_mantissa_bits: int + + +class FloatingPointType(Enum): + BFLOAT16 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=7) + FLOAT16 = FPTypeProperties(num_exponent_bits=5, num_mantissa_bits=10) + FLOAT32 = FPTypeProperties(num_exponent_bits=8, num_mantissa_bits=23) + FLOAT64 = FPTypeProperties(num_exponent_bits=11, num_mantissa_bits=52) + + +@dataclass(frozen=True) +class PyrtlFloatConfig: + fp_type_properties: FPTypeProperties + rounding_mode: RoundingMode + + +class PyrtlFloatException(Exception): + pass diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py new file mode 100644 index 00000000..e4d10769 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/floatoperations.py @@ -0,0 +1,33 @@ +import pyrtl + +from ._add_sub import AddSubHelper +from ._multiplication import MultiplicationHelper +from ._types import PyrtlFloatConfig, RoundingMode + + +class FloatOperations: + default_rounding_mode = RoundingMode.RNE + + @staticmethod + def multiply( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return MultiplicationHelper.multiply(config, operand_a, operand_b) + + @staticmethod + def add( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return AddSubHelper.add(config, operand_a, operand_b) + + @staticmethod + def sub( + config: PyrtlFloatConfig, + operand_a: pyrtl.WireVector, + operand_b: pyrtl.WireVector, + ) -> pyrtl.WireVector: + return AddSubHelper.sub(config, operand_a, operand_b) diff --git a/pyrtl/rtllib/pyrtlfloat/floatwirevector.py b/pyrtl/rtllib/pyrtlfloat/floatwirevector.py new file mode 100644 index 00000000..0004adb7 --- /dev/null +++ b/pyrtl/rtllib/pyrtlfloat/floatwirevector.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import pyrtl + +from ._types import FloatingPointType, PyrtlFloatConfig, PyrtlFloatException +from .floatoperations import FloatOperations + + +class Float16WireVector(pyrtl.WireVector): + def __init__(self): + super().__init__() + self.bitwidth = 16 + + def __ilshift__(self, other): + if isinstance(other, (pyrtl.WireVector, Float16WireVector)): + super().__ilshift__(other) + else: + msg = ( + "FloatWireVector16 can only be driven by a FloatWireVector16 " + "or a PyRTL WireVector." + ) + raise PyrtlFloatException(msg) + return self + + def _get_config(self) -> PyrtlFloatConfig | None: + return PyrtlFloatConfig( + FloatingPointType.FLOAT16.value, FloatOperations.default_rounding_mode + ) + + def __add__(self, other: Float16WireVector) -> Float16WireVector: + ret = Float16WireVector() + ret <<= FloatOperations.add(self._get_config(), self, other) + return ret + + def __sub__(self, other: Float16WireVector) -> Float16WireVector: + ret = Float16WireVector() + ret <<= FloatOperations.sub(self._get_config(), self, other) + return ret + + def __mul__(self, other: Float16WireVector) -> Float16WireVector: + ret = Float16WireVector() + ret <<= FloatOperations.multiply(self._get_config(), self, other) + return ret + + +# will create BFloat16WireVector, Float32WireVector, and Float64WireVector the same way diff --git a/tests/rtllib/pyrtlfloat/test_add_sub.py b/tests/rtllib/pyrtlfloat/test_add_sub.py new file mode 100644 index 00000000..f20177a9 --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_add_sub.py @@ -0,0 +1,30 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode + + +class TestMultiplication(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + a = pyrtl.Input(bitwidth=16, name="a") + b = pyrtl.Input(bitwidth=16, name="b") + a_floatwv = Float16WireVector() + a_floatwv <<= a + b_floatwv = Float16WireVector() + b_floatwv <<= b + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_add = pyrtl.Output(name="result_add") + result_add <<= a_floatwv + b_floatwv + result_sub = pyrtl.Output(name="result_sub") + result_sub <<= a_floatwv - b_floatwv + self.sim = pyrtl.Simulation() + + def test_multiplication_simple(self): + self.sim.step({"a": 0b0100001000000000, "b": 0b0100010100000000}) + self.assertEqual(self.sim.inspect("result_add"), 0b0100100000000000) + self.assertEqual(self.sim.inspect("result_sub"), 0b1100000000000000) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rtllib/pyrtlfloat/test_multiplication.py b/tests/rtllib/pyrtlfloat/test_multiplication.py new file mode 100644 index 00000000..812b0ea5 --- /dev/null +++ b/tests/rtllib/pyrtlfloat/test_multiplication.py @@ -0,0 +1,31 @@ +import unittest + +import pyrtl +from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode + + +class TestMultiplication(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + a = pyrtl.Input(bitwidth=16, name="a") + b = pyrtl.Input(bitwidth=16, name="b") + a_floatwv = Float16WireVector() + a_floatwv <<= a + b_floatwv = Float16WireVector() + b_floatwv <<= b + FloatOperations.default_rounding_mode = RoundingMode.RNE + result_rne = pyrtl.Output(name="result_rne") + result_rne <<= a_floatwv * b_floatwv + FloatOperations.default_rounding_mode = RoundingMode.RTZ + result_rtz = pyrtl.Output(name="result_rtz") + result_rtz <<= a_floatwv * b_floatwv + self.sim = pyrtl.Simulation() + + def test_multiplication_simple(self): + self.sim.step({"a": 0b0011111000000000, "b": 0b0011110000000001}) + self.assertEqual(self.sim.inspect("result_rne"), 0b0011111000000010) + self.assertEqual(self.sim.inspect("result_rtz"), 0b0011111000000001) + + +if __name__ == "__main__": + unittest.main() From 15ad9965cd7dcecd2d43ebdb9d60e6fd2f830046 Mon Sep 17 00:00:00 2001 From: Gabor Szita Date: Thu, 30 Oct 2025 15:49:44 -0700 Subject: [PATCH 2/2] Remove FloatWireVector --- pyrtl/rtllib/pyrtlfloat/__init__.py | 14 ++++-- pyrtl/rtllib/pyrtlfloat/floatoperations.py | 48 ++++++++++++++++++- pyrtl/rtllib/pyrtlfloat/floatwirevector.py | 46 ------------------ tests/rtllib/pyrtlfloat/test_add_sub.py | 10 ++-- .../rtllib/pyrtlfloat/test_multiplication.py | 10 ++-- 5 files changed, 63 insertions(+), 65 deletions(-) delete mode 100644 pyrtl/rtllib/pyrtlfloat/floatwirevector.py diff --git a/pyrtl/rtllib/pyrtlfloat/__init__.py b/pyrtl/rtllib/pyrtlfloat/__init__.py index df407d93..d9b64710 100644 --- a/pyrtl/rtllib/pyrtlfloat/__init__.py +++ b/pyrtl/rtllib/pyrtlfloat/__init__.py @@ -1,6 +1,11 @@ from ._types import FloatingPointType, FPTypeProperties, PyrtlFloatConfig, RoundingMode -from .floatoperations import FloatOperations -from .floatwirevector import Float16WireVector +from .floatoperations import ( + BFloat16Operations, + Float16Operations, + Float32Operations, + Float64Operations, + FloatOperations, +) __all__ = [ "FloatingPointType", @@ -8,5 +13,8 @@ "PyrtlFloatConfig", "RoundingMode", "FloatOperations", - "Float16WireVector", + "BFloat16Operations", + "Float16Operations", + "Float32Operations", + "Float64Operations", ] diff --git a/pyrtl/rtllib/pyrtlfloat/floatoperations.py b/pyrtl/rtllib/pyrtlfloat/floatoperations.py index e4d10769..ef081b0a 100644 --- a/pyrtl/rtllib/pyrtlfloat/floatoperations.py +++ b/pyrtl/rtllib/pyrtlfloat/floatoperations.py @@ -2,14 +2,14 @@ from ._add_sub import AddSubHelper from ._multiplication import MultiplicationHelper -from ._types import PyrtlFloatConfig, RoundingMode +from ._types import FloatingPointType, PyrtlFloatConfig, RoundingMode class FloatOperations: default_rounding_mode = RoundingMode.RNE @staticmethod - def multiply( + def mul( config: PyrtlFloatConfig, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector, @@ -31,3 +31,47 @@ def sub( operand_b: pyrtl.WireVector, ) -> pyrtl.WireVector: return AddSubHelper.sub(config, operand_a, operand_b) + + +class _BaseTypedFloatOperations: + _fp_type: FloatingPointType = None + + @classmethod + def mul( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.mul(cls._get_config(), operand_a, operand_b) + + @classmethod + def add( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.add(cls._get_config(), operand_a, operand_b) + + @classmethod + def sub( + cls, operand_a: pyrtl.WireVector, operand_b: pyrtl.WireVector + ) -> pyrtl.WireVector: + return FloatOperations.sub(cls._get_config(), operand_a, operand_b) + + @classmethod + def _get_config(cls) -> PyrtlFloatConfig: + return PyrtlFloatConfig( + cls._fp_type.value, FloatOperations.default_rounding_mode + ) + + +class BFloat16Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.BFLOAT16 + + +class Float16Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT16 + + +class Float32Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT32 + + +class Float64Operations(_BaseTypedFloatOperations): + _fp_type = FloatingPointType.FLOAT64 diff --git a/pyrtl/rtllib/pyrtlfloat/floatwirevector.py b/pyrtl/rtllib/pyrtlfloat/floatwirevector.py deleted file mode 100644 index 0004adb7..00000000 --- a/pyrtl/rtllib/pyrtlfloat/floatwirevector.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import annotations - -import pyrtl - -from ._types import FloatingPointType, PyrtlFloatConfig, PyrtlFloatException -from .floatoperations import FloatOperations - - -class Float16WireVector(pyrtl.WireVector): - def __init__(self): - super().__init__() - self.bitwidth = 16 - - def __ilshift__(self, other): - if isinstance(other, (pyrtl.WireVector, Float16WireVector)): - super().__ilshift__(other) - else: - msg = ( - "FloatWireVector16 can only be driven by a FloatWireVector16 " - "or a PyRTL WireVector." - ) - raise PyrtlFloatException(msg) - return self - - def _get_config(self) -> PyrtlFloatConfig | None: - return PyrtlFloatConfig( - FloatingPointType.FLOAT16.value, FloatOperations.default_rounding_mode - ) - - def __add__(self, other: Float16WireVector) -> Float16WireVector: - ret = Float16WireVector() - ret <<= FloatOperations.add(self._get_config(), self, other) - return ret - - def __sub__(self, other: Float16WireVector) -> Float16WireVector: - ret = Float16WireVector() - ret <<= FloatOperations.sub(self._get_config(), self, other) - return ret - - def __mul__(self, other: Float16WireVector) -> Float16WireVector: - ret = Float16WireVector() - ret <<= FloatOperations.multiply(self._get_config(), self, other) - return ret - - -# will create BFloat16WireVector, Float32WireVector, and Float64WireVector the same way diff --git a/tests/rtllib/pyrtlfloat/test_add_sub.py b/tests/rtllib/pyrtlfloat/test_add_sub.py index f20177a9..3f006da2 100644 --- a/tests/rtllib/pyrtlfloat/test_add_sub.py +++ b/tests/rtllib/pyrtlfloat/test_add_sub.py @@ -1,7 +1,7 @@ import unittest import pyrtl -from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode class TestMultiplication(unittest.TestCase): @@ -9,15 +9,11 @@ def setUp(self): pyrtl.reset_working_block() a = pyrtl.Input(bitwidth=16, name="a") b = pyrtl.Input(bitwidth=16, name="b") - a_floatwv = Float16WireVector() - a_floatwv <<= a - b_floatwv = Float16WireVector() - b_floatwv <<= b FloatOperations.default_rounding_mode = RoundingMode.RNE result_add = pyrtl.Output(name="result_add") - result_add <<= a_floatwv + b_floatwv + result_add <<= Float16Operations.add(a, b) result_sub = pyrtl.Output(name="result_sub") - result_sub <<= a_floatwv - b_floatwv + result_sub <<= Float16Operations.sub(a, b) self.sim = pyrtl.Simulation() def test_multiplication_simple(self): diff --git a/tests/rtllib/pyrtlfloat/test_multiplication.py b/tests/rtllib/pyrtlfloat/test_multiplication.py index 812b0ea5..439efabb 100644 --- a/tests/rtllib/pyrtlfloat/test_multiplication.py +++ b/tests/rtllib/pyrtlfloat/test_multiplication.py @@ -1,7 +1,7 @@ import unittest import pyrtl -from pyrtl.rtllib.pyrtlfloat import Float16WireVector, FloatOperations, RoundingMode +from pyrtl.rtllib.pyrtlfloat import Float16Operations, FloatOperations, RoundingMode class TestMultiplication(unittest.TestCase): @@ -9,16 +9,12 @@ def setUp(self): pyrtl.reset_working_block() a = pyrtl.Input(bitwidth=16, name="a") b = pyrtl.Input(bitwidth=16, name="b") - a_floatwv = Float16WireVector() - a_floatwv <<= a - b_floatwv = Float16WireVector() - b_floatwv <<= b FloatOperations.default_rounding_mode = RoundingMode.RNE result_rne = pyrtl.Output(name="result_rne") - result_rne <<= a_floatwv * b_floatwv + result_rne <<= Float16Operations.mul(a, b) FloatOperations.default_rounding_mode = RoundingMode.RTZ result_rtz = pyrtl.Output(name="result_rtz") - result_rtz <<= a_floatwv * b_floatwv + result_rtz <<= Float16Operations.mul(a, b) self.sim = pyrtl.Simulation() def test_multiplication_simple(self):