-
Notifications
You must be signed in to change notification settings - Fork 88
Pyrtl floating point library #475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: development
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## development #475 +/- ##
=============================================
+ Coverage 91.0% 91.4% +0.4%
=============================================
Files 25 31 +6
Lines 7091 7473 +382
=============================================
+ Hits 6450 6827 +377
- Misses 641 646 +5 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
fdxmw
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this contribution! Here are some initial comments
| from ._types import PyrtlFloatConfig, RoundingMode | ||
|
|
||
|
|
||
| class AddSubHelper: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally, classes with only @staticmethods probably shouldn't be classes :) It looks like this class can be removed, and all the static methods can be ordinary functions? Same comment for FloatUtils and MultiplicationHelper
| raw_result_grs: pyrtl.WireVector, | ||
| ) -> tuple[pyrtl.WireVector, pyrtl.WireVector]: | ||
| last = raw_result_mantissa[0] | ||
| guard = raw_result_grs[2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using wire_struct to simplify these kinds of concat/slice patterns. With a wire_struct, you wouldn't need this line, and later lines would instead refer to raw_result.guard.
Also see the wire_struct example.
| fp_prop: FPTypeProperties, wire: pyrtl.WireVector | ||
| ) -> pyrtl.WireVector: | ||
| return wire[ | ||
| fp_prop.num_mantissa_bits : fp_prop.num_mantissa_bits |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, wire_struct can remove the need for these tricky bit offset calculations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was trying to do this, but the problem is that the number of mantissa and exponent bits differs based on fp_prop, so we would need to set the bitwidths of the slices in the wire_struct dynamically at runtime. Is it possible to do this with wire_struct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is possible with some trickery, but it's probably overkill in this case, since you only need to select between a few well-known bit layouts (the user can't define an arbitrary custom floating point format). I think you could define a set of wire_structs that share the same interface, something like:
@pyrtl.wire_struct
class Float16
sign: 1
exponent: 5
fraction: 10
@pyrtl.wire_struct
class Float32
sign: 1
exponent: 8
fraction: 23So the idea is that these objects share the same interface, so it's easier to write code that works with any of these types. For example, if you want to compare sign bits, you'd write a.sign == b.sign, and that would work regardless of whether a or b are Float16 or Float32. If you care about types (I encourage you to! :), these are Unions like Float16 | Float32.
In case it's useful, here's the trickery, you can define a function that returns a dynamically-defined class:
import pyrtl
def define_my_struct(a_bits: int, b_bits: int):
@pyrtl.wire_struct
class MyInternalClass:
a: a_bits
b: b_bits
return MyInternalClass
MyStruct = define_my_struct(a_bits=4, b_bits=8)
my_struct_instance = MyStruct(a=1, b=2)
print("a bitwidth", my_struct_instance.a.bitwidth)
print("b bitwidth", my_struct_instance.b.bitwidth)
print("total bitwidth", my_struct_instance.bitwidth)$ uv run ...
a bitwidth 4
b bitwidth 8
total bitwidth 12|
|
||
| class AddSubHelper: | ||
| @staticmethod | ||
| def add( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need docstrings for all user-facing classes, methods, and functions. Docstrings for internal stuff would also be great; for example as someone new to this code it would really help to have an idea of what make_denormals_zero does conceptually.
| ) | ||
|
|
||
| @staticmethod | ||
| def is_inf(fp_prop: FPTypeProperties, wire: pyrtl.WireVector) -> pyrtl.WireVector: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would really help to have some comments with brief definitions and references for important floating point concepts (NaN, inf, denormalized (how are these different?), guard, round, sticky, ...).
I'm not expecting you to teach me everything in comments, but it would really help to know what I should read if I want to understand how this works :)
| exponent_min_value = sum_exponent_min_value - need_to_normalize | ||
|
|
||
| if rounding_mode == RoundingMode.RNE: | ||
| raw_result_exponent = rounded_product_exponent[0:num_exp_bits] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can omit the 0 here, also on line 119 below. Doing so would be more consistent with your code above, line 63 for example.
| result_exponent = pyrtl.WireVector(bitwidth=num_exp_bits) | ||
| result_mantissa = pyrtl.WireVector(bitwidth=num_mant_bits) | ||
|
|
||
| operand_a_nan = FloatUtils.is_NaN(fp_type_props, operand_a_daz) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Capitalization of NaN is inconsistent here, it's nan on the left hand side and NaN on the right hand side. I'd rename is_NaN to is_nan since we pretty consistently use lowercase with underscores for method and function names.
| self.sim = pyrtl.Simulation() | ||
|
|
||
| def test_multiplication_simple(self): | ||
| self.sim.step({"a": 0b0100001000000000, "b": 0b0100010100000000}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'll need a better story for how this is tested :) A few things to consider:
Every PyRTL developer will need to run these tests for the indefinite future, so we really don't want the tests to run for too long. I'd say all continuously-run floating point tests should complete in a couple seconds or less. These tests can't be comprehensive, but we should still try to get as much value as we can out of them.
Try to think about the most interesting inputs (zeroes, inf, nan, ...), and which combinations of inputs and operations are worth testing. I would avoid randomly generated tests because they are unlikely to cover all the interesting combinations with a reasonable number of test cases.
Think about how someone besides yourself would debug one of these test failures. It would help a lot to know what cases a particular test exercises (# Test addition with inf, etc). It also helps to know if the unexpected output is slightly wrong, or complete nonsense, which is pretty hard to tell from a raw bit pattern (see next point :)
Consider adding some helper functions that can convert float to and from these bit patterns. That would make it easier for someone to understand that the test checks that 1.0 + 2.0 == 3.0, for example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, testing needs more work, I'll add more rigorous tests and test edge cases. That's why I marked it as a draft PR because it is not finished yet.
|
|
||
| @staticmethod | ||
| def multiply( | ||
| def mul( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem very useful, since it just passes all its arguments through to MultiplicationHelper.multiply? Seems better to point users directly to MultiplicationHelper.multiply. Same comments for the other methods below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason I have this FloatOperations class and separate MultiplicationHelper and AddSubHelper classes is because the user is supposed to use FloatOperations, MultiplicationHelper and AddSubHelper are internal helpers and the user is not supposed to use them. I created these helpers so I can separate the multiplication and addition/subtraction logic into separate files.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that makes sense, but I'd still try to refactor, it would be nice to remove these functions that are just pass-through wrappers for another function. Maybe:
AddSubHelper.addcould be an ordinary function instead of a@staticmethod_add_sub.pycould be renamed toadd_sub.py- Non-public functions in
add_sub.pycould have an underscore prefix __init__.pycould importaddfromadd_sub
I think that would still make it clear which parts of the interface are meant for public use, while removing a layer of indirection?
This kind of indirection tends to be annoying when debugging problems, because you have to jump through another hoop to find the code you're looking for. "Our princess is in another castle!"
| operand_a: pyrtl.WireVector, | ||
| operand_b: pyrtl.WireVector, | ||
| ) -> pyrtl.WireVector: | ||
| fp_type_props = config.fp_type_properties |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All user-facing methods/functions should check (somewhere, not necessarily here) that the operand bitwidths are consistent with config
No description provided.