diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index fdf6b9990..6c21ff39d 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -54,7 +54,8 @@ jobs: git clone https://github.com/Chia-Network/clvm_tools.git --branch=main --single-branch python -m pip install ./clvm_tools python -m pip install colorama - maturin develop --release + git clone https://github.com/Chia-Network/mpir_gc_x64.git --depth 1 + maturin develop --release --cargo-extra-args="--features=mpir" - name: Run benchmarks (Windows) if: startsWith(matrix.os, 'windows') @@ -62,6 +63,13 @@ jobs: . .\venv\Scripts\Activate.ps1 python benchmark/run-benchmark.py + - name: Install GMP + if: ${{ startsWith(matrix.os, 'ubuntu') }} + run: | + sudo apt install libgmp3-dev + ls -la /usr/lib64 + ls -la /usr/lib + - name: Build if: ${{ !startsWith(matrix.os, 'windows') }} env: @@ -108,6 +116,9 @@ jobs: run: | python -m pip install maturin rustup target add x86_64-unknown-linux-musl + sudo apt install libgmp3-dev + ls -la /usr/lib64 + ls -la /usr/lib - name: Build env: diff --git a/.github/workflows/build-arm64-wheels.yml b/.github/workflows/build-arm64-wheels.yml index 36884792c..964308204 100644 --- a/.github/workflows/build-arm64-wheels.yml +++ b/.github/workflows/build-arm64-wheels.yml @@ -44,6 +44,8 @@ jobs: curl -L https://sh.rustup.rs > rustup-init.sh && \ sh rustup-init.sh -y && \ yum -y install openssl-devel && \ + yum -y install gmp && \ + pkg-config --libs gmp && \ source $HOME/.cargo/env && \ rustup target add aarch64-unknown-linux-musl && \ rm -rf venv && \ @@ -54,7 +56,7 @@ jobs: if [ ! -f "activate" ]; then ln -s venv/bin/activate; fi && \ . ./activate && \ pip install maturin && \ - CC=gcc maturin build --no-sdist --release --strip --manylinux 2014 --cargo-extra-args="--features=openssl" \ + CC=gcc maturin build --no-sdist --release --strip --manylinux 2014 --cargo-extra-args="--features=openssl,libgmp10" \ ' - name: Upload artifacts diff --git a/.github/workflows/build-crate-and-npm.yml b/.github/workflows/build-crate-and-npm.yml index 54ab1330e..aae0354a4 100644 --- a/.github/workflows/build-crate-and-npm.yml +++ b/.github/workflows/build-crate-and-npm.yml @@ -39,6 +39,8 @@ jobs: run: cargo +stable fmt -- --files-with-diff --check - name: clippy (stable) run: cargo +stable clippy + - name: install GMP + run: sudo apt install libgmp3-dev - name: tests run: cargo test && cargo test --release - name: build @@ -56,7 +58,7 @@ jobs: run: cargo install wasm-pack - name: wasm-pack build and pack - run: wasm-pack build && wasm-pack pack + run: wasm-pack build --features=num-bigint && wasm-pack pack - name: Upload npm pkg artifacts uses: actions/upload-artifact@v2 diff --git a/.github/workflows/build-m1-wheel.yml b/.github/workflows/build-m1-wheel.yml index 941d538ee..5bf40c918 100644 --- a/.github/workflows/build-m1-wheel.yml +++ b/.github/workflows/build-m1-wheel.yml @@ -37,6 +37,16 @@ jobs: curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs > rust.sh arch -arm64 sh rust.sh -y + - name: install GMP + run: | + curl -L https://gmplib.org/download/gmp/gmp-6.2.1.tar.lz | tar x + cd gmp-6.2.1 + ./configure --enable-fat --with-pic + make -j 6 + sudo make install + cd .. + rm -rf gmp-6.2.1 + - name: Build m1 wheels run: | arch -arm64 python3 -m venv venv diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 30a832e05..8a0caeea5 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -47,6 +47,7 @@ jobs: - name: Build MacOs with maturin on Python ${{ matrix.python }} if: startsWith(matrix.os, 'macos') run: | + brew install gmp python -m venv venv ln -s venv/bin/activate . ./activate @@ -70,6 +71,8 @@ jobs: curl -L https://sh.rustup.rs > rustup-init.sh && \ sh rustup-init.sh -y && \ yum -y install openssl-devel && \ + yum -y install gmp && \ + pkg-config --libs gmp && \ source $HOME/.cargo/env && \ rustup target add x86_64-unknown-linux-musl && \ rm -rf venv && \ @@ -86,7 +89,7 @@ jobs: . ./activate && \ pip install --upgrade pip && \ pip install maturin && \ - CC=gcc maturin build --no-sdist --release --strip --manylinux 2010 --cargo-extra-args="--features=openssl" \ + CC=gcc maturin build --no-sdist --release --strip --manylinux 2010 --cargo-extra-args="--features=openssl,libgmp" \ ' - name: Build Windows with maturin on Python ${{ matrix.python }} @@ -95,12 +98,13 @@ jobs: python -m venv venv . .\venv\Scripts\Activate.ps1 ln -s venv\Scripts\Activate.ps1 activate - maturin build --no-sdist -i python --release --strip + git clone https://github.com/Chia-Network/mpir_gc_x64.git --depth 1 + maturin build --no-sdist -i python --release --strip --cargo-extra-args="--features=mpir" # this will install into the venv # it'd be better to use the wheel, but I can't figure out how to do that # TODO: figure this out # this does NOT work: pip install target/wheels/clvm_rs-*.whl - maturin develop --release + maturin develop --release --cargo-extra-args="--features=mpir" # the line above also doesn't seem to work - name: Install clvm_rs wheel @@ -227,6 +231,8 @@ jobs: uses: actions-rs/toolchain@v1 with: toolchain: nightly + - name: install GMP + run: sudo apt install libgmp-dev - name: cargo-fuzz run: cargo +nightly install cargo-fuzz - name: build @@ -244,5 +250,7 @@ jobs: with: toolchain: stable components: rustfmt, clippy + - name: install GMP + run: sudo apt install libgmp-dev - name: cargo test run: cargo test diff --git a/Cargo.lock b/Cargo.lock index 6b1d5d4f3..1475810e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -256,9 +256,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.0" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d047c1062aa51e256408c560894e5251f08925980e53cf1aa5bd00eec6512" +checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" dependencies = [ "autocfg", "num-integer", diff --git a/Cargo.toml b/Cargo.toml index a918a222a..4166eaa57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,12 +18,16 @@ lto = true [features] extension-module = ["pyo3/extension-module"] -default = ["extension-module"] +libgmp = [] +libgmp3 = [] +libgmp10 = [] +mpir = [] +default = ["extension-module", "libgmp"] [dependencies] hex = "=0.4.3" lazy_static = "=1.4.0" -num-bigint = "=0.4.0" +num-bigint = { version = "0.4.0", optional = true } num-traits = "=0.2.14" num-integer = "=0.1.44" bls12_381 = "=0.5.0" diff --git a/README.md b/README.md index c40ede7f9..c0238e6a1 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,12 @@ Use `maturin` to build the python interface. First, install into current virtual $ pip install maturin ``` +As we need `MPIR` for MSVC builds, prepare this dependency with + +``` +$ git clone https://github.com/Chia-Network/mpir_gc_x64.git --depth 1 +``` + Build `clvm_rs` directly into the current virtualenv with ``` diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..0d9ee860b --- /dev/null +++ b/build.rs @@ -0,0 +1,24 @@ +use std::env; + +fn main() { + if env::var_os("CARGO_FEATURE_MPIR").is_some() { + println!("cargo:rustc-link-lib=mpir"); + println!("cargo:rustc-link-search=mpir_gc_x64"); + } else if env::var_os("CARGO_FEATURE_LIBGMP3").is_some() { + println!("cargo:rustc-link-lib=libgmp.so.3"); + } else if env::var_os("CARGO_FEATURE_LIBGMP10").is_some() { + println!("cargo:rustc-link-lib=libgmp.so.10"); + } else if env::var_os("CARGO_FEATURE_LIBGMP").is_some() { + println!("cargo:rustc-link-lib=gmp"); + } + + #[cfg(target_os = "linux")] + { + println!("cargo:rustc-link-search=/usr/lib64"); + println!("cargo:rustc-link-search=/usr/lib"); + } + #[cfg(target_os = "macos")] + { + println!("cargo:rustc-link-search=/opt/homebrew/lib"); + } +} diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 31867deaa..a87a87d1e 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -14,6 +14,7 @@ libfuzzer-sys = "0.4" [dependencies.clvm_rs] path = ".." default-features = false +features = ["libgmp", "openssl"] # Prevent this from interfering with workspaces [workspace] diff --git a/src/gen/conditions.rs b/src/gen/conditions.rs index 4745dde0d..30c1a7681 100644 --- a/src/gen/conditions.rs +++ b/src/gen/conditions.rs @@ -596,8 +596,6 @@ use crate::serialize::node_to_bytes; #[cfg(test)] use hex::FromHex; #[cfg(test)] -use num_traits::Num; -#[cfg(test)] use std::collections::HashMap; #[cfg(test)] @@ -694,6 +692,9 @@ fn test_coin_id(parent_id: &[u8], puzzle_hash: &[u8], amount: u64) -> [u8; 32] { // (1 (2 (3 ) means: (1 . (2 . (3 . ()))) // and: +#[cfg(test)] +use crate::number_traits::TestNumberTraits; + #[cfg(test)] fn parse_list_impl( a: &mut Allocator, @@ -730,7 +731,7 @@ fn parse_list_impl( (a.new_atom(&buf).unwrap(), v.len() + 1) } else if input.starts_with("-") || "0123456789".contains(input.get(0..1).unwrap()) { let v = input.split_once(" ").unwrap().0; - let num = Number::from_str_radix(v, 10).unwrap(); + let num = Number::from_str_radix(v, 10); (ptr_from_number(a, &num).unwrap(), v.len() + 1) } else { panic!("atom not supported \"{}\"", input); diff --git a/src/gmp_ffi.rs b/src/gmp_ffi.rs new file mode 100644 index 000000000..8f6736d96 --- /dev/null +++ b/src/gmp_ffi.rs @@ -0,0 +1,118 @@ +#![allow(non_camel_case_types, non_snake_case)] + +use core::ptr::NonNull; +use std::ffi::c_void; + +#[repr(C)] +#[derive(Clone, Copy, Debug)] +pub struct mpz_t { + pub alloc: c_int, + pub size: c_int, + pub d: NonNull, +} + +type c_int = i32; +type c_long = i64; +type c_ulong = u64; +type c_ulonglong = u64; +type mpz_srcptr = *const mpz_t; +type mpz_ptr = *mut mpz_t; +type bitcnt_t = c_ulong; + +extern "C" { + #[link_name = "__gmpz_init"] + pub fn mpz_init(x: mpz_ptr); + #[link_name = "__gmpz_import"] + pub fn mpz_import( + rop: mpz_ptr, + count: usize, + order: c_int, + size: usize, + endian: c_int, + nails: usize, + op: *const c_void, + ); + #[link_name = "__gmpz_add_ui"] + pub fn mpz_add_ui(rop: mpz_ptr, op1: mpz_srcptr, op2: c_ulong); + #[link_name = "__gmpz_set"] + pub fn mpz_set(rop: mpz_ptr, op: mpz_srcptr); + #[link_name = "__gmpz_export"] + pub fn mpz_export( + rop: *mut c_void, + countp: *mut usize, + order: c_int, + size: usize, + endian: c_int, + nails: usize, + op: mpz_srcptr, + ) -> *mut c_void; + #[link_name = "__gmpz_sizeinbase"] + pub fn mpz_sizeinbase(arg1: mpz_srcptr, arg2: c_int) -> usize; + #[link_name = "__gmpz_fdiv_qr"] + pub fn mpz_fdiv_qr(q: mpz_ptr, r: mpz_ptr, n: mpz_srcptr, d: mpz_srcptr); + #[link_name = "__gmpz_fdiv_q"] + pub fn mpz_fdiv_q(q: mpz_ptr, n: mpz_srcptr, d: mpz_srcptr); + #[link_name = "__gmpz_fdiv_r"] + pub fn mpz_fdiv_r(r: mpz_ptr, n: mpz_srcptr, d: mpz_srcptr); + #[link_name = "__gmpz_fdiv_q_2exp"] + pub fn mpz_fdiv_q_2exp(q: mpz_ptr, n: mpz_srcptr, b: bitcnt_t); + #[link_name = "__gmpz_init_set_ui"] + pub fn mpz_init_set_ui(rop: mpz_ptr, op: c_ulong); + #[link_name = "__gmpz_init_set_si"] + pub fn mpz_init_set_si(rop: mpz_ptr, op: c_long); + #[link_name = "__gmpz_clear"] + pub fn mpz_clear(x: mpz_ptr); + #[link_name = "__gmpz_add"] + pub fn mpz_add(rop: mpz_ptr, op1: mpz_srcptr, op2: mpz_srcptr); + #[link_name = "__gmpz_sub"] + pub fn mpz_sub(rop: mpz_ptr, op1: mpz_srcptr, op2: mpz_srcptr); + #[link_name = "__gmpz_mul"] + pub fn mpz_mul(rop: mpz_ptr, op1: mpz_srcptr, op2: mpz_srcptr); + #[link_name = "__gmpz_mul_2exp"] + pub fn mpz_mul_2exp(rop: mpz_ptr, op1: mpz_srcptr, op2: bitcnt_t); + #[link_name = "__gmpz_get_si"] + pub fn mpz_get_si(op: mpz_srcptr) -> c_long; + #[link_name = "__gmpz_and"] + pub fn mpz_and(rop: mpz_ptr, op1: mpz_srcptr, op2: mpz_srcptr); + #[link_name = "__gmpz_ior"] + pub fn mpz_ior(rop: mpz_ptr, op1: mpz_srcptr, op2: mpz_srcptr); + #[link_name = "__gmpz_xor"] + pub fn mpz_xor(rop: mpz_ptr, op1: mpz_srcptr, op2: mpz_srcptr); + #[link_name = "__gmpz_com"] + pub fn mpz_com(rop: mpz_ptr, op: mpz_srcptr); + #[link_name = "__gmpz_cmp"] + pub fn mpz_cmp(op1: mpz_srcptr, op2: mpz_srcptr) -> c_int; + #[link_name = "__gmpz_cmp_si"] + pub fn mpz_cmp_si(op1: mpz_srcptr, op2: c_long) -> c_int; + #[link_name = "__gmpz_cmp_ui"] + pub fn mpz_cmp_ui(op1: mpz_srcptr, op2: c_ulong) -> c_int; +} + +#[cfg(test)] +type c_char = i8; + +#[cfg(test)] +extern "C" { + #[link_name = "__gmpz_init_set_str"] + pub fn mpz_init_set_str(rop: mpz_ptr, str: *const c_char, base: c_int) -> c_int; + #[link_name = "__gmpz_get_str"] + pub fn mpz_get_str(str: *mut c_char, base: c_int, op: mpz_srcptr) -> *mut c_char; +} + +#[inline] +pub unsafe extern "C" fn mpz_neg(rop: mpz_ptr, op: mpz_srcptr) { + if rop as mpz_srcptr != op { + mpz_set(rop, op); + } + (*rop).size = -(*rop).size; +} + +#[inline] +pub unsafe extern "C" fn mpz_get_ui(op: mpz_srcptr) -> c_ulong { + if { (*op).size } != 0 { + let p = (*op).d.as_ptr(); + (*p) as c_ulong + } else { + 0 + } +} diff --git a/src/lib.rs b/src/lib.rs index 41b9bef16..bb54177da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,15 @@ mod err_utils; mod gen; pub mod more_ops; pub mod node; + +#[cfg(not(feature = "num-bigint"))] +mod gmp_ffi; +#[cfg(not(feature = "num-bigint"))] +mod number_gmp; + mod number; +mod number_traits; + mod op_utils; #[cfg(not(any(test, target_family = "wasm")))] mod py; diff --git a/src/more_ops.rs b/src/more_ops.rs index 6d40a7fff..23a4dde32 100644 --- a/src/more_ops.rs +++ b/src/more_ops.rs @@ -1,7 +1,4 @@ use bls12_381::{G1Affine, G1Projective, Scalar}; -use num_bigint::{BigUint, Sign}; -use num_integer::Integer; -use std::convert::TryFrom; use std::ops::BitAndAssign; use std::ops::BitOrAssign; use std::ops::BitXorAssign; @@ -12,7 +9,7 @@ use crate::allocator::{Allocator, NodePtr, SExp}; use crate::cost::{check_cost, Cost}; use crate::err_utils::err; use crate::node::Node; -use crate::number::{number_from_u8, ptr_from_number, Number}; +use crate::number::{ptr_from_number, Number, Sign}; use crate::op_utils::{ arg_count, atom, check_arg_count, i32_atom, int_atom, two_ints, u32_from_u8, }; @@ -20,6 +17,8 @@ use crate::reduction::{Reduction, Response}; use crate::serialize::node_to_bytes; use crate::sha2::Sha256; +use crate::number_traits::NumberTraits; + // We ascribe some additional cost per byte for operations that allocate new atoms const MALLOC_COST_PER_BYTE: Cost = 10; @@ -354,7 +353,7 @@ pub fn op_sha256(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Response pub fn op_add(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Response { let mut cost = ARITH_BASE_COST; let mut byte_count: usize = 0; - let mut total: Number = 0.into(); + let mut total = Number::zero(); for arg in Node::new(a, input) { cost += ARITH_COST_PER_ARG; check_cost( @@ -363,9 +362,9 @@ pub fn op_add(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Response { max_cost, )?; let blob = int_atom(&arg, "+")?; - let v: Number = number_from_u8(blob); + let v: Number = Number::from_u8(blob); byte_count += blob.len(); - total += v; + total += &v; } let total = ptr_from_number(a, &total)?; cost += byte_count as Cost * ARITH_COST_PER_BYTE; @@ -375,18 +374,18 @@ pub fn op_add(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Response { pub fn op_subtract(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Response { let mut cost = ARITH_BASE_COST; let mut byte_count: usize = 0; - let mut total: Number = 0.into(); + let mut total = Number::zero(); let mut is_first = true; for arg in Node::new(a, input) { cost += ARITH_COST_PER_ARG; check_cost(a, cost + byte_count as Cost * ARITH_COST_PER_BYTE, max_cost)?; let blob = int_atom(&arg, "-")?; - let v: Number = number_from_u8(blob); + let v: Number = Number::from_u8(blob); byte_count += blob.len(); if is_first { - total += v; + total += &v; } else { - total -= v; + total -= &v; }; is_first = false; } @@ -405,13 +404,13 @@ pub fn op_multiply(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Respons let blob = int_atom(&arg, "*")?; if first_iter { l0 = blob.len(); - total = number_from_u8(blob); + total = Number::from_u8(blob); first_iter = false; continue; } let l1 = blob.len(); - total *= number_from_u8(blob); + total *= Number::from_u8(blob); cost += MUL_COST_PER_OP; cost += (l0 + l1) as Cost * MUL_LINEAR_COST_PER_BYTE; @@ -437,7 +436,7 @@ pub fn op_div_impl(a: &mut Allocator, input: NodePtr, mempool: bool) -> Response // this is to preserve a buggy behavior from the initial implementation // of this operator. - if q == (-1).into() && r != 0.into() { + if q.equal(-1) && r.not_equal(0) { q += 1; } let q1 = ptr_from_number(a, &q)?; @@ -480,7 +479,7 @@ pub fn op_gr(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response { let cost = GR_BASE_COST + (v0.len() + v1.len()) as Cost * GR_COST_PER_BYTE; Ok(Reduction( cost, - if number_from_u8(v0) > number_from_u8(v1) { + if Number::from_u8(v0) > Number::from_u8(v1) { a.one() } else { a.null() @@ -570,7 +569,7 @@ pub fn op_ash(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response { check_arg_count(&args, 2, "ash")?; let a0 = args.first()?; let b0 = int_atom(&a0, "ash")?; - let i0 = number_from_u8(b0); + let i0 = Number::from_u8(b0); let l0 = b0.len(); let rest = args.rest()?; let a1 = i32_atom(&rest.first()?, "ash")?; @@ -647,7 +646,7 @@ pub fn op_lsh(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response { check_arg_count(&args, 2, "lsh")?; let a0 = args.first()?; let b0 = int_atom(&a0, "lsh")?; - let i0 = BigUint::from_bytes_be(b0); + let i0 = Number::from_unsigned_bytes_be(b0); let l0 = b0.len(); let rest = args.rest()?; let a1 = i32_atom(&rest.first()?, "lsh")?; @@ -655,8 +654,6 @@ pub fn op_lsh(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response { return args.rest()?.first()?.err("shift too large"); } - let i0: Number = i0.into(); - let v: Number = if a1 > 0 { i0 << a1 } else { i0 >> -a1 }; let l1 = limbs_for_int(&v); @@ -720,7 +717,7 @@ fn binop_reduction( let mut cost = LOG_BASE_COST; for arg in Node::new(a, input) { let blob = int_atom(&arg, op_name)?; - let n0 = number_from_u8(blob); + let n0 = Number::from_u8(blob); op_f(&mut total, &n0); arg_size += blob.len(); cost += LOG_COST_PER_ARG; @@ -745,7 +742,7 @@ fn logior_op(a: &mut Number, b: &Number) { } pub fn op_logior(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Response { - let v: Number = (0).into(); + let v = Number::zero(); binop_reduction("logior", a, v, input, max_cost, logior_op) } @@ -754,7 +751,7 @@ fn logxor_op(a: &mut Number, b: &Number) { } pub fn op_logxor(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Response { - let v: Number = (0).into(); + let v = Number::zero(); binop_reduction("logxor", a, v, input, max_cost, logxor_op) } @@ -763,7 +760,7 @@ pub fn op_lognot(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response check_arg_count(&args, 1, "lognot")?; let a0 = args.first()?; let v0 = int_atom(&a0, "lognot")?; - let mut n: Number = number_from_u8(v0); + let mut n: Number = Number::from_u8(v0); n = !n; let cost = LOGNOT_BASE_COST + ((v0.len() as Cost) * LOGNOT_COST_PER_BYTE); let r = ptr_from_number(a, &n)?; @@ -808,12 +805,12 @@ pub fn op_softfork(a: &mut Allocator, input: NodePtr, max_cost: Cost) -> Respons let args = Node::new(a, input); match args.pair() { Some((p1, _)) => { - let n: Number = number_from_u8(int_atom(&p1, "softfork")?); + let n: Number = Number::from_u8(int_atom(&p1, "softfork")?); if n.sign() == Sign::Plus { - if n > Number::from(max_cost) { + if n.greater_than(max_cost) { return err(a.null(), "cost exceeded"); } - let cost: Cost = TryFrom::try_from(&n).unwrap(); + let cost: Cost = n.to_u64(); Ok(Reduction(cost, args.null().node)) } else { args.err("cost must be > 0") @@ -830,14 +827,13 @@ lazy_static! { 0xd8, 0x05, 0x53, 0xbd, 0xa4, 0x02, 0xff, 0xfe, 0x5b, 0xfe, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, ]; - let n = BigUint::from_bytes_be(order_as_bytes); - n.into() + Number::from_unsigned_bytes_be(order_as_bytes) }; } fn mod_group_order(n: Number) -> Number { - let order = GROUP_ORDER.clone(); - let mut remainder = n.mod_floor(&order); + let order: &Number = &GROUP_ORDER; + let mut remainder = n.mod_floor(order); if remainder.sign() == Sign::Minus { remainder += order; } @@ -862,7 +858,7 @@ pub fn op_pubkey_for_exp(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> let a0 = args.first()?; let v0 = int_atom(&a0, "pubkey_for_exp")?; - let exp: Number = mod_group_order(number_from_u8(v0)); + let exp: Number = mod_group_order(Number::from_u8(v0)); let cost = PUBKEY_BASE_COST + (v0.len() as Cost) * PUBKEY_COST_PER_BYTE; let exp: Scalar = number_to_scalar(exp); let point: G1Projective = G1Affine::generator() * exp; diff --git a/src/number.rs b/src/number.rs index b66a5df16..bdd338864 100644 --- a/src/number.rs +++ b/src/number.rs @@ -1,368 +1,144 @@ +#[cfg(not(feature = "num-bigint"))] +pub use crate::number_gmp::{Number, Sign}; + +#[cfg(feature = "num-bigint")] +pub use num_bigint::BigInt as Number; +#[cfg(feature = "num-bigint")] +pub use num_bigint::Sign; + use crate::allocator::{Allocator, NodePtr}; -use crate::node::Node; +use crate::number_traits::NumberTraits; use crate::reduction::EvalErr; -use num_bigint::BigInt; -pub type Number = BigInt; - pub fn ptr_from_number(allocator: &mut Allocator, item: &Number) -> Result { - let bytes: Vec = item.to_signed_bytes_be(); - let mut slice = bytes.as_slice(); + let bytes: Vec = item.to_signed_bytes(); + allocator.new_atom(bytes.as_slice()) +} - // make number minimal by removing leading zeros - while (!slice.is_empty()) && (slice[0] == 0) { - if slice.len() > 1 && (slice[1] & 0x80 == 0x80) { - break; - } - slice = &slice[1..]; +#[cfg(test)] +#[cfg(feature = "num-bigint")] +impl crate::number_traits::TestNumberTraits for Number { + fn from_str_radix(s: &str, radix: i32) -> Number { + num_traits::Num::from_str_radix(s, radix as u32).unwrap() } - allocator.new_atom(slice) } -impl From<&Node<'_>> for Option { - fn from(item: &Node) -> Self { - let v: &[u8] = item.atom()?; - Some(number_from_u8(v)) +#[cfg(feature = "num-bigint")] +impl crate::number_traits::NumberTraits for Number { + fn from_unsigned_bytes_be(v: &[u8]) -> Number { + let i = num_bigint::BigUint::from_bytes_be(v); + i.into() + } + + fn to_signed_bytes(&self) -> Vec { + let mut ret = self.to_signed_bytes_be(); + + // make number minimal by removing leading zeros + while (!ret.is_empty()) && (ret[0] == 0) { + if ret.len() > 1 && (ret[1] & 0x80 == 0x80) { + break; + } + ret.remove(0); + } + ret + } + + fn zero() -> Number { + ::zero() + } + + fn from_u8(v: &[u8]) -> Number { + let len = v.len(); + if len == 0 { + Number::zero() + } else { + Number::from_signed_bytes_be(v) + } } -} -pub fn number_from_u8(v: &[u8]) -> Number { - let len = v.len(); - if len == 0 { - 0.into() - } else { - Number::from_signed_bytes_be(v) + fn to_u64(&self) -> u64 { + use std::convert::TryFrom; + TryFrom::try_from(self).unwrap() + } + + fn div_mod_floor(&self, denominator: &Number) -> (Number, Number) { + num_integer::Integer::div_mod_floor(self, denominator) + } + + fn mod_floor(&self, denominator: &Number) -> Number { + num_integer::Integer::mod_floor(&self, denominator) + } + + fn equal(&self, other: i64) -> bool { + self == &Number::from(other) + } + + fn not_equal(&self, other: i64) -> bool { + self != &Number::from(other) + } + + fn greater_than(&self, other: u64) -> bool { + self > &Number::from(other) } } #[test] fn test_ptr_from_number() { + use crate::number_traits::NumberTraits; let mut a = Allocator::new(); // 0 is encoded as an empty string - let num = number_from_u8(&[0]); + let num = Number::from_u8(&[0]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "0"); assert_eq!(a.atom(ptr).len(), 0); - let num = number_from_u8(&[1]); + let num = Number::from_u8(&[1]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "1"); assert_eq!(&[1], &a.atom(ptr)); // leading zeroes are redundant - let num = number_from_u8(&[0, 0, 0, 1]); + let num = Number::from_u8(&[0, 0, 0, 1]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "1"); assert_eq!(&[1], &a.atom(ptr)); - let num = number_from_u8(&[0x00, 0x00, 0x80]); + let num = Number::from_u8(&[0x00, 0x00, 0x80]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "128"); assert_eq!(&[0x00, 0x80], &a.atom(ptr)); // A leading zero is necessary to encode a positive number with the // penultimate byte's most significant bit set - let num = number_from_u8(&[0x00, 0xff]); + let num = Number::from_u8(&[0x00, 0xff]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "255"); assert_eq!(&[0x00, 0xff], &a.atom(ptr)); - let num = number_from_u8(&[0x7f, 0xff]); + let num = Number::from_u8(&[0x7f, 0xff]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "32767"); assert_eq!(&[0x7f, 0xff], &a.atom(ptr)); // the first byte is redundant, it's still -1 - let num = number_from_u8(&[0xff, 0xff]); + let num = Number::from_u8(&[0xff, 0xff]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "-1"); assert_eq!(&[0xff], &a.atom(ptr)); - let num = number_from_u8(&[0xff]); + let num = Number::from_u8(&[0xff]); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(format!("{}", num), "-1"); assert_eq!(&[0xff], &a.atom(ptr)); - let num = number_from_u8(&[0x00, 0x80, 0x00]); + let num = Number::from_u8(&[0x00, 0x80, 0x00]); assert_eq!(format!("{}", num), "32768"); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(&[0x00, 0x80, 0x00], &a.atom(ptr)); - let num = number_from_u8(&[0x00, 0x40, 0x00]); + let num = Number::from_u8(&[0x00, 0x40, 0x00]); assert_eq!(format!("{}", num), "16384"); let ptr = ptr_from_number(&mut a, &num).unwrap(); assert_eq!(&[0x40, 0x00], &a.atom(ptr)); } - -#[cfg(test)] -use num_bigint::{BigUint, Sign}; - -#[cfg(test)] -use std::convert::TryFrom; - -#[cfg(test)] -fn roundtrip_bytes(b: &[u8]) { - let negative = b.len() > 0 && (b[0] & 0x80) != 0; - let zero = b.len() == 0 || (b.len() == 1 && b[0] == 0); - - { - let num = Number::from_signed_bytes_be(b); - - if negative { - assert!(num.sign() == Sign::Minus); - } else if zero { - assert!(num.sign() == Sign::NoSign); - } else { - assert!(num.sign() == Sign::Plus); - } - - let round_trip = num.to_signed_bytes_be(); - // num-bigin produces a single 0 byte for the value 0. We expect an - // empty array - let round_trip = if round_trip == &[0] { - &round_trip[1..] - } else { - &round_trip - }; - - assert_eq!(round_trip, b); - - // test to_bytes_le() - let (sign, mut buf_le) = num.to_bytes_le(); - - // there's a special case for empty input buffers, which will result in - // a single 0 byte here - if b == &[] { - assert_eq!(buf_le, &[0]); - buf_le.remove(0); - } - assert!(sign == num.sign()); - - // the buffer we get from to_bytes_le() is unsigned (since the sign is - // returned separately). This means it doesn't ever need to prepend a 0 - // byte when the MSB is set. When we're comparing this against the input - // buffer, we need to add such 0 byte to buf_le to make them compare - // equal. - // the 0 prefix has to be added to the end though, since it's little - // endian - if buf_le.len() > 0 && (buf_le.last().unwrap() & 0x80) != 0 { - buf_le.push(0); - } - - if sign != Sign::Minus { - assert!(buf_le.iter().eq(b.iter().rev())); - } else { - let negated = -num; - let magnitude = negated.to_signed_bytes_be(); - assert!(buf_le.iter().eq(magnitude.iter().rev())); - } - } - - // test parsing unsigned bytes - { - let unsigned_num: Number = BigUint::from_bytes_be(b).into(); - assert!(unsigned_num.sign() != Sign::Minus); - let unsigned_round_trip = unsigned_num.to_signed_bytes_be(); - let unsigned_round_trip = if unsigned_round_trip == &[0] { - &unsigned_round_trip[1..] - } else { - &unsigned_round_trip - }; - if b.len() > 0 && (b[0] & 0x80) != 0 { - // we expect a new leading zero here, to keep the value positive - assert!(unsigned_round_trip[0] == 0); - assert_eq!(&unsigned_round_trip[1..], b); - } else { - assert_eq!(unsigned_round_trip, b); - } - } -} - -#[test] -fn test_number_round_trip_bytes() { - roundtrip_bytes(&[]); - - for i in 1..=255 { - roundtrip_bytes(&[i]); - } - - for i in 0..=127 { - roundtrip_bytes(&[0xff, i]); - } - - for i in 128..=255 { - roundtrip_bytes(&[0, i]); - } - - for i in 0..=127 { - roundtrip_bytes(&[ - 0xff, i, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - ]); - } - - for i in 128..=255 { - roundtrip_bytes(&[ - 0, i, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - ]); - } - - for i in 0..=127 { - roundtrip_bytes(&[0xff, i, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - } - - for i in 128..=255 { - roundtrip_bytes(&[0, i, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - } -} - -#[cfg(test)] -fn roundtrip_u64(v: u64) { - let num: Number = v.into(); - assert!(num.sign() != Sign::Minus); - - assert!(num.bits() <= 64); - - let round_trip: u64 = TryFrom::try_from(num).unwrap(); - assert_eq!(round_trip, v); -} - -#[test] -fn test_round_trip_u64() { - for v in 0..=0x100 { - roundtrip_u64(v); - } - - for v in 0x7ffe..=0x8001 { - roundtrip_u64(v); - } - - for v in 0xfffe..=0x10000 { - roundtrip_u64(v); - } - - for v in 0x7ffffffe..=0x80000001 { - roundtrip_u64(v); - } - for v in 0xfffffffe..=0x100000000 { - roundtrip_u64(v); - } - - for v in 0x7ffffffffffffffe..=0x8000000000000001 { - roundtrip_u64(v); - } - - for v in 0xfffffffffffffffe..=0xffffffffffffffff { - roundtrip_u64(v); - } -} - -#[cfg(test)] -fn roundtrip_i64(v: i64) { - let num: Number = v.into(); - if v == 0 { - assert!(num.sign() == Sign::NoSign); - } else if v < 0 { - assert!(num.sign() == Sign::Minus); - } else if v > 0 { - assert!(num.sign() == Sign::Plus); - } - - assert!(num.bits() <= 64); - - let round_trip: i64 = TryFrom::try_from(num).unwrap(); - assert_eq!(round_trip, v); -} - -#[test] -fn test_round_trip_i64() { - for v in -0x100..=0x100 { - roundtrip_i64(v); - } - - for v in 0x7ffe..=0x8001 { - roundtrip_i64(v); - } - - for v in -0x8001..-0x7ffe { - roundtrip_i64(v); - } - - for v in 0xfffe..=0x10000 { - roundtrip_i64(v); - } - - for v in -0x10000..-0xfffe { - roundtrip_i64(v); - } - - for v in 0x7ffffffe..=0x80000001 { - roundtrip_i64(v); - } - - for v in -0x80000001..-0x7ffffffe { - roundtrip_i64(v); - } - - for v in 0xfffffffe..=0x100000000 { - roundtrip_i64(v); - } - - for v in -0x100000000..-0xfffffffe { - roundtrip_i64(v); - } - - for v in 0x7ffffffffffffffe..=0x7fffffffffffffff { - roundtrip_i64(v); - } - - for v in -0x8000000000000000..-0x7ffffffffffffffe { - roundtrip_i64(v); - } -} - -#[cfg(test)] -fn bits(b: &[u8]) -> u64 { - Number::from_signed_bytes_be(b).bits() -} - -#[test] -fn test_bits() { - assert_eq!(bits(&[]), 0); - assert_eq!(bits(&[0]), 0); - assert_eq!(bits(&[0b01111111]), 7); - assert_eq!(bits(&[0b00111111]), 6); - assert_eq!(bits(&[0b00011111]), 5); - assert_eq!(bits(&[0b00001111]), 4); - assert_eq!(bits(&[0b00000111]), 3); - assert_eq!(bits(&[0b00000011]), 2); - assert_eq!(bits(&[0b00000001]), 1); - assert_eq!(bits(&[0b00000000]), 0); - - assert_eq!(bits(&[0b01111111, 0xff]), 15); - assert_eq!(bits(&[0b00111111, 0xff]), 14); - assert_eq!(bits(&[0b00011111, 0xff]), 13); - assert_eq!(bits(&[0b00001111, 0xff]), 12); - assert_eq!(bits(&[0b00000111, 0xff]), 11); - assert_eq!(bits(&[0b00000011, 0xff]), 10); - assert_eq!(bits(&[0b00000001, 0xff]), 9); - assert_eq!(bits(&[0b00000000, 0xff]), 8); - - assert_eq!(bits(&[0b11111111]), 1); - assert_eq!(bits(&[0b11111110]), 2); - assert_eq!(bits(&[0b11111100]), 3); - assert_eq!(bits(&[0b11111000]), 4); - assert_eq!(bits(&[0b11110000]), 5); - assert_eq!(bits(&[0b11100000]), 6); - assert_eq!(bits(&[0b11000000]), 7); - assert_eq!(bits(&[0b10000000]), 8); - - assert_eq!(bits(&[0b11111111, 0]), 9); - assert_eq!(bits(&[0b11111110, 0]), 10); - assert_eq!(bits(&[0b11111100, 0]), 11); - assert_eq!(bits(&[0b11111000, 0]), 12); - assert_eq!(bits(&[0b11110000, 0]), 13); - assert_eq!(bits(&[0b11100000, 0]), 14); - assert_eq!(bits(&[0b11000000, 0]), 15); - assert_eq!(bits(&[0b10000000, 0]), 16); -} diff --git a/src/number_gmp.rs b/src/number_gmp.rs new file mode 100644 index 000000000..47159ecfa --- /dev/null +++ b/src/number_gmp.rs @@ -0,0 +1,769 @@ +use crate::gmp_ffi as gmp; +use crate::node::Node; +use crate::number_traits::NumberTraits; +#[cfg(test)] +use crate::number_traits::TestNumberTraits; +use core::mem::MaybeUninit; +use std::cmp::Ordering; +use std::cmp::PartialOrd; +use std::ffi::c_void; +use std::ops::Drop; +use std::ops::{ + AddAssign, BitAndAssign, BitOrAssign, BitXorAssign, MulAssign, Not, Shl, Shr, SubAssign, +}; + +#[allow(clippy::enum_variant_names)] +#[derive(PartialEq)] +pub enum Sign { + Minus, + NoSign, + Plus, +} + +pub struct Number { + v: gmp::mpz_t, +} + +#[cfg(test)] +impl TestNumberTraits for Number { + fn from_str_radix(mut s: &str, radix: i32) -> Number { + let negative = s.get(0..1).unwrap() == "-"; + if negative { + s = s.get(1..).unwrap(); + } + let input = CString::new(s).unwrap(); + let mut v = MaybeUninit::::uninit(); + let result = unsafe { gmp::mpz_init_set_str(v.as_mut_ptr(), input.as_ptr(), radix) }; + // v will be initialized even if an error occurs, so we will need to + // capture it in a Number regardless + let mut ret = Number { + v: unsafe { v.assume_init() }, + }; + if negative { + unsafe { + gmp::mpz_neg(&mut ret.v, &ret.v); + } + } + assert!(result == 0); + ret + } +} + +impl NumberTraits for Number { + fn from_unsigned_bytes_be(v: &[u8]) -> Number { + let mut ret = Number::zero(); + if !v.is_empty() { + unsafe { + gmp::mpz_import(&mut ret.v, v.len(), 1, 1, 0, 0, v.as_ptr() as *const c_void); + } + } + ret + } + + fn to_signed_bytes(&self) -> Vec { + let size = (self.bits() + 7) / 8; + let mut ret: Vec = Vec::new(); + if size == 0 { + return ret; + } + ret.resize(size + 1, 0); + let sign = self.sign(); + let mut out_size: usize = size; + unsafe { + gmp::mpz_export( + ret.as_mut_slice()[1..].as_mut_ptr() as *mut c_void, + &mut out_size, + 1, + 1, + 0, + 0, + &self.v, + ); + } + // apparently mpz_export prints 0 bytes to the buffer if the value is 0 + // hence the special case in the assert below. + assert!(out_size == ret.len() - 1); + if sign == Sign::Minus { + // If the value is negative, we need to convert it to two's + // complement. We can't do that in-place. + let mut carry = true; + for digit in &mut ret.iter_mut().rev() { + let res = (!*digit).overflowing_add(carry as u8); + *digit = res.0; + carry = res.1; + } + assert!(!carry); + assert!(ret[0] & 0x80 != 0); + if (ret[1] & 0x80) != 0 { + ret.remove(0); + } + } else if ret[1] & 0x80 == 0 { + ret.remove(0); + } + ret + } + + fn zero() -> Number { + let mut v = MaybeUninit::::uninit(); + unsafe { + gmp::mpz_init(v.as_mut_ptr()); + } + Number { + v: unsafe { v.assume_init() }, + } + } + + fn from_u8(v: &[u8]) -> Number { + Number::from_signed_bytes_be(v) + } + + fn to_u64(&self) -> u64 { + u64::from(self) + } + + // returns the quotient and remained, from dividing self with denominator + fn div_mod_floor(&self, denominator: &Number) -> (Number, Number) { + let mut q = Number::zero(); + let mut r = Number::zero(); + unsafe { + gmp::mpz_fdiv_qr(&mut q.v, &mut r.v, &self.v, &denominator.v); + } + (q, r) + } + + fn mod_floor(&self, denominator: &Number) -> Number { + let mut r = Number::zero(); + unsafe { + gmp::mpz_fdiv_r(&mut r.v, &self.v, &denominator.v); + } + r + } + + fn equal(&self, other: i64) -> bool { + self == &other + } + + fn not_equal(&self, other: i64) -> bool { + self != &other + } + + fn greater_than(&self, other: u64) -> bool { + self > &other + } +} + +impl Number { + pub fn from_signed_bytes_be(v: &[u8]) -> Number { + let mut ret = Number::zero(); + if v.is_empty() { + return ret; + } + // mpz_import() only reads unsigned values + let negative = (v[0] & 0x80) != 0; + + if negative { + // since the bytes we read are two's complement + // if the most significant bit was set, we need to + // convert the value to a negative one. We do this by flipping + // all bits, adding one and then negating it. + let mut v = v.to_vec(); + for digit in &mut v { + *digit = !*digit; + } + unsafe { + gmp::mpz_import(&mut ret.v, v.len(), 1, 1, 0, 0, v.as_ptr() as *const c_void); + gmp::mpz_add_ui(&mut ret.v, &ret.v, 1); + gmp::mpz_neg(&mut ret.v, &ret.v); + } + } else { + unsafe { + gmp::mpz_import(&mut ret.v, v.len(), 1, 1, 0, 0, v.as_ptr() as *const c_void); + } + } + ret + } + + pub fn to_bytes_le(&self) -> (Sign, Vec) { + let sgn = self.sign(); + + let size = (self.bits() + 7) / 8; + let mut ret: Vec = Vec::new(); + if size == 0 { + return (Sign::NoSign, ret); + } + ret.resize(size, 0); + + let mut out_size: usize = size; + unsafe { + gmp::mpz_export( + ret.as_mut_ptr() as *mut c_void, + &mut out_size, + -1, + 1, + 0, + 0, + &self.v, + ); + } + assert_eq!(out_size, ret.len()); + (sgn, ret) + } + + pub fn bits(&self) -> usize { + // GnuMP says that any integer needs at least 1 bit to be represented. + // but we say 0 requires 0 bits + if self.sign() == Sign::NoSign { + 0 + } else { + unsafe { gmp::mpz_sizeinbase(&self.v, 2) } + } + } + + pub fn sign(&self) -> Sign { + match unsafe { gmp::mpz_cmp_si(&self.v, 0) } { + d if d < 0 => Sign::Minus, + d if d > 0 => Sign::Plus, + _ => Sign::NoSign, + } + } + + pub fn div_floor(&self, denominator: &Number) -> Number { + let mut ret = Number::zero(); + unsafe { + gmp::mpz_fdiv_q(&mut ret.v, &self.v, &denominator.v); + } + ret + } +} + +impl Drop for Number { + fn drop(&mut self) { + unsafe { + gmp::mpz_clear(&mut self.v); + } + } +} + +// Addition + +impl AddAssign<&Number> for Number { + fn add_assign(&mut self, other: &Self) { + unsafe { + gmp::mpz_add(&mut self.v, &self.v, &other.v); + } + } +} + +// This is only here for op_div() +impl AddAssign for Number { + fn add_assign(&mut self, other: u64) { + unsafe { + gmp::mpz_add_ui(&mut self.v, &self.v, other); + } + } +} + +// Subtraction + +impl SubAssign<&Number> for Number { + fn sub_assign(&mut self, other: &Self) { + unsafe { + gmp::mpz_sub(&mut self.v, &self.v, &other.v); + } + } +} + +// Multiplication + +impl MulAssign for Number { + fn mul_assign(&mut self, other: Self) { + unsafe { + gmp::mpz_mul(&mut self.v, &self.v, &other.v); + } + } +} + +// Shift + +impl Shl for Number { + type Output = Self; + fn shl(mut self, n: i32) -> Self { + assert!(n >= 0); + unsafe { + gmp::mpz_mul_2exp(&mut self.v, &self.v, n as u64); + } + self + } +} + +impl Shr for Number { + type Output = Self; + fn shr(mut self, n: i32) -> Self { + assert!(n >= 0); + unsafe { + gmp::mpz_fdiv_q_2exp(&mut self.v, &self.v, n as u64); + } + self + } +} + +// Conversion + +impl From for Number { + fn from(other: i64) -> Self { + let mut v = MaybeUninit::::uninit(); + unsafe { + gmp::mpz_init_set_si(v.as_mut_ptr(), other); + } + Number { + v: unsafe { v.assume_init() }, + } + } +} + +impl From for Number { + fn from(other: i32) -> Self { + let mut v = MaybeUninit::::uninit(); + unsafe { + gmp::mpz_init_set_si(v.as_mut_ptr(), other as i64); + } + Number { + v: unsafe { v.assume_init() }, + } + } +} + +impl From for Number { + fn from(other: u64) -> Self { + let mut v = MaybeUninit::::uninit(); + unsafe { + gmp::mpz_init_set_ui(v.as_mut_ptr(), other); + } + Number { + v: unsafe { v.assume_init() }, + } + } +} + +impl From for Number { + fn from(other: usize) -> Self { + let mut v = MaybeUninit::::uninit(); + unsafe { + gmp::mpz_init_set_ui(v.as_mut_ptr(), other as u64); + } + Number { + v: unsafe { v.assume_init() }, + } + } +} + +impl From<&Number> for u64 { + fn from(n: &Number) -> u64 { + unsafe { + assert!(gmp::mpz_sizeinbase(&n.v, 2) <= 64); + assert!(gmp::mpz_cmp_si(&n.v, 0) >= 0); + gmp::mpz_get_ui(&n.v) + } + } +} + +impl From<&Number> for i64 { + fn from(n: &Number) -> i64 { + unsafe { + assert!(gmp::mpz_sizeinbase(&n.v, 2) <= 64); + gmp::mpz_get_si(&n.v) + } + } +} + +// Bit operations + +impl BitXorAssign<&Number> for Number { + fn bitxor_assign(&mut self, other: &Self) { + unsafe { + gmp::mpz_xor(&mut self.v, &self.v, &other.v); + } + } +} + +impl BitOrAssign<&Number> for Number { + fn bitor_assign(&mut self, other: &Self) { + unsafe { + gmp::mpz_ior(&mut self.v, &self.v, &other.v); + } + } +} + +impl BitAndAssign<&Number> for Number { + fn bitand_assign(&mut self, other: &Self) { + unsafe { + gmp::mpz_and(&mut self.v, &self.v, &other.v); + } + } +} + +impl Not for Number { + type Output = Self; + fn not(self) -> Self { + let mut ret = Number::zero(); + unsafe { + gmp::mpz_com(&mut ret.v, &self.v); + } + ret + } +} + +// Comparisons + +impl PartialEq for Number { + fn eq(&self, other: &Self) -> bool { + unsafe { gmp::mpz_cmp(&self.v, &other.v) == 0 } + } +} + +impl PartialEq for Number { + fn eq(&self, other: &u64) -> bool { + unsafe { gmp::mpz_cmp_ui(&self.v, *other) == 0 } + } +} + +impl PartialEq for Number { + fn eq(&self, other: &i64) -> bool { + unsafe { gmp::mpz_cmp_si(&self.v, *other) == 0 } + } +} + +impl PartialEq for Number { + fn eq(&self, other: &i32) -> bool { + unsafe { gmp::mpz_cmp_si(&self.v, *other as i64) == 0 } + } +} + +fn ord_helper(r: i32) -> Option { + match r { + d if d < 0 => Some(Ordering::Less), + d if d > 0 => Some(Ordering::Greater), + _ => Some(Ordering::Equal), + } +} + +impl PartialOrd for Number { + fn partial_cmp(&self, other: &Number) -> Option { + ord_helper(unsafe { gmp::mpz_cmp(&self.v, &other.v) }) + } +} + +impl PartialOrd for Number { + fn partial_cmp(&self, other: &u64) -> Option { + ord_helper(unsafe { gmp::mpz_cmp_ui(&self.v, *other) }) + } +} + +unsafe impl Sync for Number {} + +impl From<&Node<'_>> for Option { + fn from(item: &Node) -> Self { + let v: &[u8] = item.atom()?; + Some(Number::from_u8(v)) + } +} + +// TODO: move all tests to number.rs so we can test both the GMP and num-bigint +// versions + +// ==== TESTS ==== + +#[cfg(test)] +use std::ffi::{CStr, CString}; +#[cfg(test)] +use std::fmt; + +#[cfg(test)] +impl fmt::Display for Number { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let len = unsafe { gmp::mpz_sizeinbase(&self.v, 10) } + 2; + let mut storage = Vec::::with_capacity(len); + let c_str = unsafe { gmp::mpz_get_str(storage.as_mut_ptr(), 10, &self.v) }; + unsafe { f.write_str(CStr::from_ptr(c_str).to_str().unwrap()) } + } +} + +#[cfg(test)] +fn roundtrip_bytes(b: &[u8]) { + let negative = b.len() > 0 && (b[0] & 0x80) != 0; + let zero = b.len() == 0 || (b.len() == 1 && b[0] == 0); + + { + let num = Number::from_signed_bytes_be(b); + + if negative { + assert!(num.sign() == Sign::Minus); + } else if zero { + assert!(num.sign() == Sign::NoSign); + } else { + assert!(num.sign() == Sign::Plus); + } + + let round_trip = num.to_signed_bytes(); + + assert_eq!(round_trip, b); + + // test to_bytes_le() + let (sign, mut buf_le) = num.to_bytes_le(); + + assert!(sign == num.sign()); + + // the buffer we get from to_bytes_le() is unsigned (since the sign is + // returned separately). This means it doesn't ever need to prepend a 0 + // byte when the MSB is set. When we're comparing this against the input + // buffer, we need to add such 0 byte to buf_le to make them compare + // equal. + // the 0 prefix has to be added to the end though, since it's little + // endian + if buf_le.len() > 0 && (buf_le.last().unwrap() & 0x80) != 0 { + buf_le.push(0); + } + + if sign != Sign::Minus { + assert!(buf_le.iter().eq(b.iter().rev())); + } else { + let mut negated = Number::zero(); + unsafe { + gmp::mpz_neg(&mut negated.v, &num.v); + } + let magnitude = negated.to_signed_bytes(); + assert!(buf_le.iter().eq(magnitude.iter().rev())); + } + } + + // test parsing unsigned bytes + { + let unsigned_num = Number::from_unsigned_bytes_be(b); + assert!(unsigned_num.sign() != Sign::Minus); + let unsigned_round_trip = unsigned_num.to_signed_bytes(); + let unsigned_round_trip = if unsigned_round_trip == &[0] { + &unsigned_round_trip[1..] + } else { + &unsigned_round_trip + }; + if b.len() > 0 && (b[0] & 0x80) != 0 { + // we expect a new leading zero here, to keep the value positive + assert!(unsigned_round_trip[0] == 0); + assert_eq!(&unsigned_round_trip[1..], b); + } else { + assert_eq!(unsigned_round_trip, b); + } + } +} + +#[test] +fn test_number_round_trip_bytes() { + roundtrip_bytes(&[]); + + // 0 doesn't round-trip, since we represent that by an empty buffer + for i in 1..=255 { + roundtrip_bytes(&[i]); + } + + for i in 0..=127 { + roundtrip_bytes(&[0xff, i]); + } + + for i in 128..=255 { + roundtrip_bytes(&[0, i]); + } + + for i in 0..=127 { + roundtrip_bytes(&[ + 0xff, i, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ]); + } + + for i in 128..=255 { + roundtrip_bytes(&[ + 0, i, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ]); + } + + for i in 0..=127 { + roundtrip_bytes(&[0xff, i, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + } + + for i in 128..=255 { + roundtrip_bytes(&[0, i, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + } +} + +#[cfg(test)] +fn roundtrip_u64(v: u64) { + let num: Number = v.into(); + assert!(num.sign() != Sign::Minus); + + assert!(num.bits() <= 64); + assert!(!(num < v)); + assert!(!(num > v)); + assert!(!(num != v)); + assert!(num == v); + assert!(num <= v); + assert!(num >= v); + + if v != u64::MAX { + assert!(num < v + 1); + assert!(!(num > v + 1)); + assert!(num != v + 1); + assert!(!(num == v + 1)); + assert!(num <= v + 1); + assert!(!(num >= v + 1)); + } + + if v != u64::MIN { + assert!(!(num < v - 1)); + assert!(num > v - 1); + assert!(num != v - 1); + assert!(!(num == v - 1)); + assert!(!(num <= v - 1)); + assert!(num >= v - 1); + } + + let round_trip: u64 = (&num).into(); + assert_eq!(round_trip, v); +} + +#[test] +fn test_round_trip_u64() { + for v in 0..=0x100 { + roundtrip_u64(v); + } + + for v in 0x7ffe..=0x8001 { + roundtrip_u64(v); + } + + for v in 0xfffe..=0x10000 { + roundtrip_u64(v); + } + + for v in 0x7ffffffe..=0x80000001 { + roundtrip_u64(v); + } + for v in 0xfffffffe..=0x100000000 { + roundtrip_u64(v); + } + + for v in 0x7ffffffffffffffe..=0x8000000000000001 { + roundtrip_u64(v); + } + + for v in 0xfffffffffffffffe..=0xffffffffffffffff { + roundtrip_u64(v); + } +} + +#[cfg(test)] +fn roundtrip_i64(v: i64) { + let num: Number = v.into(); + if v == 0 { + assert!(num.sign() == Sign::NoSign); + } else if v < 0 { + assert!(num.sign() == Sign::Minus); + } else if v > 0 { + assert!(num.sign() == Sign::Plus); + } + + assert!(num.bits() <= 64); + let round_trip: i64 = (&num).into(); + assert_eq!(round_trip, v); +} + +#[test] +fn test_round_trip_i64() { + for v in -0x100..=0x100 { + roundtrip_i64(v); + } + + for v in 0x7ffe..=0x8001 { + roundtrip_i64(v); + } + + for v in -0x8001..-0x7ffe { + roundtrip_i64(v); + } + + for v in 0xfffe..=0x10000 { + roundtrip_i64(v); + } + + for v in -0x10000..-0xfffe { + roundtrip_i64(v); + } + + for v in 0x7ffffffe..=0x80000001 { + roundtrip_i64(v); + } + + for v in -0x80000001..-0x7ffffffe { + roundtrip_i64(v); + } + + for v in 0xfffffffe..=0x100000000 { + roundtrip_i64(v); + } + + for v in -0x100000000..-0xfffffffe { + roundtrip_i64(v); + } + + for v in 0x7ffffffffffffffe..=0x7fffffffffffffff { + roundtrip_i64(v); + } + + for v in -0x8000000000000000..-0x7ffffffffffffffe { + roundtrip_i64(v); + } +} + +#[cfg(test)] +fn bits(b: &[u8]) -> u64 { + Number::from_signed_bytes_be(b).bits() as u64 +} + +#[test] +fn test_bits() { + assert_eq!(bits(&[]), 0); + assert_eq!(bits(&[0]), 0); + assert_eq!(bits(&[0b01111111]), 7); + assert_eq!(bits(&[0b00111111]), 6); + assert_eq!(bits(&[0b00011111]), 5); + assert_eq!(bits(&[0b00001111]), 4); + assert_eq!(bits(&[0b00000111]), 3); + assert_eq!(bits(&[0b00000011]), 2); + assert_eq!(bits(&[0b00000001]), 1); + assert_eq!(bits(&[0b00000000]), 0); + + assert_eq!(bits(&[0b01111111, 0xff]), 15); + assert_eq!(bits(&[0b00111111, 0xff]), 14); + assert_eq!(bits(&[0b00011111, 0xff]), 13); + assert_eq!(bits(&[0b00001111, 0xff]), 12); + assert_eq!(bits(&[0b00000111, 0xff]), 11); + assert_eq!(bits(&[0b00000011, 0xff]), 10); + assert_eq!(bits(&[0b00000001, 0xff]), 9); + assert_eq!(bits(&[0b00000000, 0xff]), 8); + + assert_eq!(bits(&[0b11111111]), 1); + assert_eq!(bits(&[0b11111110]), 2); + assert_eq!(bits(&[0b11111100]), 3); + assert_eq!(bits(&[0b11111000]), 4); + assert_eq!(bits(&[0b11110000]), 5); + assert_eq!(bits(&[0b11100000]), 6); + assert_eq!(bits(&[0b11000000]), 7); + assert_eq!(bits(&[0b10000000]), 8); + + assert_eq!(bits(&[0b11111111, 0]), 9); + assert_eq!(bits(&[0b11111110, 0]), 10); + assert_eq!(bits(&[0b11111100, 0]), 11); + assert_eq!(bits(&[0b11111000, 0]), 12); + assert_eq!(bits(&[0b11110000, 0]), 13); + assert_eq!(bits(&[0b11100000, 0]), 14); + assert_eq!(bits(&[0b11000000, 0]), 15); + assert_eq!(bits(&[0b10000000, 0]), 16); +} diff --git a/src/number_traits.rs b/src/number_traits.rs new file mode 100644 index 000000000..7a5f6bc3f --- /dev/null +++ b/src/number_traits.rs @@ -0,0 +1,19 @@ +pub trait NumberTraits { + fn from_unsigned_bytes_be(v: &[u8]) -> Self; + fn to_signed_bytes(&self) -> Vec; + fn zero() -> Self; + fn from_u8(v: &[u8]) -> Self; + fn to_u64(&self) -> u64; + fn div_mod_floor(&self, denominator: &Self) -> (Self, Self) + where + Self: Sized; + fn mod_floor(&self, denominator: &Self) -> Self; + fn equal(&self, other: i64) -> bool; + fn not_equal(&self, other: i64) -> bool; + fn greater_than(&self, other: u64) -> bool; +} + +#[cfg(test)] +pub trait TestNumberTraits { + fn from_str_radix(s: &str, radix: i32) -> Self; +} diff --git a/src/op_utils.rs b/src/op_utils.rs index 9fcb5a3cc..6502a5cf4 100644 --- a/src/op_utils.rs +++ b/src/op_utils.rs @@ -1,6 +1,7 @@ use crate::err_utils::err; use crate::node::Node; -use crate::number::{number_from_u8, Number}; +use crate::number::Number; +use crate::number_traits::NumberTraits; use crate::reduction::EvalErr; pub fn check_arg_count(args: &Node, expected: usize, name: &str) -> Result<(), EvalErr> { @@ -87,7 +88,7 @@ pub fn two_ints(args: &Node, op_name: &str) -> Result<(Number, usize, Number, us let a1 = args.rest()?.first()?; let n0 = int_atom(&a0, op_name)?; let n1 = int_atom(&a1, op_name)?; - Ok((number_from_u8(n0), n0.len(), number_from_u8(n1), n1.len())) + Ok((Number::from_u8(n0), n0.len(), Number::from_u8(n1), n1.len())) } fn u32_from_u8_impl(buf: &[u8], signed: bool) -> Option { diff --git a/src/test_ops.rs b/src/test_ops.rs index b2b574470..8711136d5 100644 --- a/src/test_ops.rs +++ b/src/test_ops.rs @@ -7,9 +7,9 @@ use crate::more_ops::{ op_point_add, op_pubkey_for_exp, op_sha256, op_softfork, op_strlen, op_substr, op_subtract, }; use crate::number::{ptr_from_number, Number}; +use crate::number_traits::TestNumberTraits; use crate::reduction::{Reduction, Response}; use hex::FromHex; -use num_traits::Num; use std::collections::HashMap; static TEST_CASES: &str = r#" @@ -743,7 +743,7 @@ fn parse_atom(a: &mut Allocator, v: &str) -> NodePtr { } if v.starts_with("-") || "0123456789".contains(v.get(0..1).unwrap()) { - let num = Number::from_str_radix(v, 10).unwrap(); + let num = Number::from_str_radix(v, 10); return ptr_from_number(a, &num).unwrap(); }