diff --git a/compiler/tests/integers/int_macro.rs b/compiler/tests/integers/int_macro.rs index d9760fe4ba..a493ee949c 100644 --- a/compiler/tests/integers/int_macro.rs +++ b/compiler/tests/integers/int_macro.rs @@ -118,7 +118,7 @@ macro_rules! test_int { // make sure that we can calculate the inverse of each number // Leo signed integer division is non-wrapping. Thus attempting to calculate a // division result that wraps should be ignored here. - if a.checked_neg().is_none() || b.checked_neg().is_none() { + if a.checked_neg().is_none() { continue; } diff --git a/gadgets/src/signed_integer/arithmetic/div.rs b/gadgets/src/signed_integer/arithmetic/div.rs index eae60bd9d5..d834c5ac11 100644 --- a/gadgets/src/signed_integer/arithmetic/div.rs +++ b/gadgets/src/signed_integer/arithmetic/div.rs @@ -9,7 +9,6 @@ use crate::{ Int64, Int8, }; - use snarkos_models::{ curves::PrimeField, gadgets::{ @@ -24,7 +23,7 @@ use snarkos_models::{ }; macro_rules! div_int_impl { - ($($gadget:ident)*) => ($( + ($($gadget:ident),*) => ($( impl Div for $gadget { type ErrorType = SignedIntegerError; @@ -86,15 +85,52 @@ macro_rules! div_int_impl { &allocated_zero, )?; + // if the numerator is 0, return 0 let self_is_zero = Boolean::Constant(self.eq(&Self::constant(0 as <$gadget as Int>::IntegerType))); + // if other is the minimum number, the result will be zero or one + // -128 / -128 = 1 + // x / -128 = 0 fractional result rounds to 0 + let min = Self::constant(<$gadget as Int>::IntegerType::MIN); + let other_is_min = other.evaluate_equal( + &mut cs.ns(|| "other_min_check"), + &min + )?; + let self_is_min = self.evaluate_equal( + &mut cs.ns(|| "self_min_check"), + &min + )?; + let both_min = Boolean::and( + &mut cs.ns(|| "both_min"), + &other_is_min, + &self_is_min + )?; + + + // if other is the minimum, set other to -1 so the calculation will not fail + let negative_one = allocated_one.neg(&mut cs.ns(|| "allocated_one"))?; + let a_valid = min.add(&mut cs.ns(||"a_valid"), &allocated_one); + let a_set = Self::conditionally_select( + &mut cs.ns(|| "a_set"), + &self_is_min, + &a_valid?, + &self + )?; + + let b_set = Self::conditionally_select( + &mut cs.ns(|| "b_set"), + &other_is_min, + &negative_one, + &other + )?; + // If the most significant bits of both numbers are equal, the quotient will be positive - let a_msb = self.bits.last().unwrap(); let b_msb = other.bits.last().unwrap(); + let a_msb = self.bits.last().unwrap(); let positive = a_msb.evaluate_equal(cs.ns(|| "compare_msb"), &b_msb)?; // Get the absolute value of each number - let a_comp = self.neg(&mut cs.ns(|| "a_neg"))?; + let a_comp = a_set.neg(&mut cs.ns(|| "a_neg"))?; let a = Self::conditionally_select( &mut cs.ns(|| "a_abs"), &a_msb, @@ -102,12 +138,12 @@ macro_rules! div_int_impl { &self )?; - let b_comp = other.neg(&mut cs.ns(|| "b_neg"))?; + let b_comp = b_set.neg(&mut cs.ns(|| "b_neg"))?; let b = Self::conditionally_select( &mut cs.ns(|| "b_abs"), &b_msb, &b_comp, - &other, + &b_set, )?; let mut q = zero.clone(); @@ -142,13 +178,11 @@ macro_rules! div_int_impl { &b )?; - let sub = r.sub( &mut cs.ns(|| format!("subtract_divisor_{}", i)), &b ); - r = Self::conditionally_select( &mut cs.ns(|| format!("subtract_or_same_{}", i)), &can_sub, @@ -182,6 +216,22 @@ macro_rules! div_int_impl { &q_neg, )?; + // set to zero if we know result is fractional + q = Self::conditionally_select( + &mut cs.ns(|| "fraction"), + &other_is_min, + &allocated_zero, + &q, + )?; + + // set to one if we know result is division of the minimum number by itself + q = Self::conditionally_select( + &mut cs.ns(|| "one_result"), + &both_min, + &allocated_one, + &q, + )?; + Ok(Self::conditionally_select( &mut cs.ns(|| "self_or_quotient"), &self_is_zero, @@ -193,4 +243,4 @@ macro_rules! div_int_impl { )*) } -div_int_impl!(Int8 Int16 Int32 Int64 Int128); +div_int_impl!(Int8, Int16, Int32, Int64, Int128); diff --git a/gadgets/tests/signed_integer/i8.rs b/gadgets/tests/signed_integer/i8.rs index 3ba64d99d1..9801f98e15 100644 --- a/gadgets/tests/signed_integer/i8.rs +++ b/gadgets/tests/signed_integer/i8.rs @@ -282,12 +282,16 @@ fn test_int8_div_constants() { for _ in 0..1000 { let mut cs = TestConstraintSystem::::new(); - let a: i8 = rng.gen_range(-127i8, i8::MAX); - let b: i8 = rng.gen_range(-127i8, i8::MAX); + let a: i8 = rng.gen(); + let b: i8 = rng.gen(); + + if a.checked_neg().is_none() { + return; + } let expected = match a.checked_div(b) { Some(valid) => valid, - None => continue, + None => return, }; let a_bit = Int8::constant(a); @@ -308,12 +312,16 @@ fn test_int8_div() { for _ in 0..100 { let mut cs = TestConstraintSystem::::new(); - let a: i8 = rng.gen_range(-127i8, i8::MAX); - let b: i8 = rng.gen_range(-127i8, i8::MAX); + let a: i8 = rng.gen(); + let b: i8 = rng.gen(); + + if a.checked_neg().is_none() { + continue; + } let expected = match a.checked_div(b) { Some(valid) => valid, - None => continue, + None => return, }; let a_bit = Int8::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap();