division i type wip

This commit is contained in:
collin 2020-07-14 20:17:37 -07:00
parent eb5ab1fbe1
commit f52dd77373
7 changed files with 210 additions and 10 deletions

View File

@ -8,6 +8,9 @@ pub enum IntegerError {
#[error("Integer overflow")]
Overflow,
#[error("Division by zero")]
DivisionByZero,
#[error("{}", _0)]
SynthesisError(#[from] SynthesisError),
}

View File

@ -1,5 +1,26 @@
use crate::{errors::IntegerError, Int16, Int32, Int64, Int8};
use snarkos_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem};
use crate::{
binary::ComparatorGadget,
errors::IntegerError,
signed_integer::arithmetic::*,
Int,
Int128,
Int16,
Int32,
Int64,
Int8,
};
use snarkos_models::{
curves::PrimeField,
gadgets::{
r1cs::ConstraintSystem,
utilities::{
alloc::AllocGadget,
boolean::{AllocatedBit, Boolean},
select::CondSelectGadget,
},
},
};
/// Division for a signed integer gadget
pub trait Div<Rhs = Self>
@ -7,17 +28,151 @@ where
Self: std::marker::Sized,
{
#[must_use]
fn div<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<(), IntegerError>;
fn div<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<Self, IntegerError>;
}
macro_rules! div_int_impl {
($($t:ty)*) => ($(
impl Div for $t {
fn div<F: PrimeField, CS: ConstraintSystem<F>>(&self, _cs: CS, _other: &Self) -> Result<(), IntegerError> {
Ok(())
($($gadget:ident)*) => ($(
impl Div for $gadget {
fn div<F: PrimeField, CS: ConstraintSystem<F>>(
&self,
mut cs: CS,
other: &Self
) -> Result<Self, IntegerError> {
// N / D pseudocode:
//
// if D = 0 then error(DivisionByZeroException) end
// positive = msb(N) == msb(D) -- if msb's equal, return positive result
// Q := 0 -- Initialize quotient and remainder to zero
// R := 0
// for i := n 1 .. 0 do -- Where n is number of bits in N
// R := R << 1 -- Left-shift R by 1 bit
// R(0) := N(i) -- Set the least-significant bit of R equal to bit i of the numerator
// if R ≥ D then
// R := R D
// Q(i) := 1
// end
// end
// if positive { -- positive result
// Q
// }
// !Q -- negative result
if other.eq(&Self::constant(0 as <$gadget as Int>::IntegerType)) {
return Err(IntegerError::DivisionByZero);
}
let is_constant = Boolean::constant(Self::result_is_constant(&self, &other));
let allocated_true = Boolean::from(AllocatedBit::alloc(&mut cs.ns(|| "true"), || Ok(true)).unwrap());
let true_bit = Boolean::conditionally_select(
&mut cs.ns(|| "constant_or_allocated_true"),
&is_constant,
&Boolean::constant(true),
&allocated_true,
)?;
let allocated_one = Self::alloc(&mut cs.ns(|| "one"), || Ok(1 as <$gadget as Int>::IntegerType))?;
let one = Self::conditionally_select(
&mut cs.ns(|| "constant_or_allocated_1"),
&is_constant,
&Self::constant(1 as <$gadget as Int>::IntegerType),
&allocated_one,
)?;
let allocated_zero = Self::alloc(&mut cs.ns(|| "zero"), || Ok(0 as <$gadget as Int>::IntegerType))?;
let zero = Self::conditionally_select(
&mut cs.ns(|| "constant_or_allocated_0"),
&is_constant,
&Self::constant(0 as <$gadget as Int>::IntegerType),
&allocated_zero,
)?;
// If the most significant bits of both numbers are equal, the quotient will be positive
let a_msb = self.bits.last().unwrap();
let b_msb = other.bits.last().unwrap();
let positive = Boolean::and(cs.ns(|| "compare msb"), &a_msb, &b_msb)?;
let self_is_zero = Boolean::Constant(self.eq(&Self::constant(0 as <$gadget as Int>::IntegerType)));
let mut q = zero.clone();
let mut r = zero.clone();
for (i, bit) in self.bits.iter().rev().enumerate() {
if i == 0 {
// skip the sign bit
continue;
}
// Left shift remainder by 1
r = r.add(
&mut cs.ns(|| format!("shift_left_{}", i)),
&r
)?;
// Set the least-significant bit of remainder to bit i of the numerator
let r_new = r.add(
&mut cs.ns(|| format!("set_remainder_bit_{}", i)),
&one.clone(),
)?;
r = Self::conditionally_select(
&mut cs.ns(|| format!("increment_or_remainder_{}", i)),
&bit,
&r_new,
&r
)?;
let can_sub = r.greater_than_or_equal(
&mut cs.ns(|| format!("compare_remainder_{}", i)),
other
)?;
let sub = r.sub(
&mut cs.ns(|| format!("subtract_divisor_{}", i)),
other
)?;
r = Self::conditionally_select(
&mut cs.ns(|| format!("subtract_or_same_{}", i)),
&can_sub,
&sub,
&r
)?;
let index = <$gadget as Int>::SIZE -1 -i as usize;
let bit_value = (1 as <$gadget as Int>::IntegerType) << (index as <$gadget as Int>::IntegerType);
let mut q_new = q.clone();
q_new.bits[index] = true_bit.clone();
q_new.value = Some(q_new.value.unwrap() + bit_value);
q = Self::conditionally_select(
&mut cs.ns(|| format!("set_bit_or_same_{}", i)),
&can_sub,
&q_new,
&q,
)?;
}
let q_neg = q.twos_comp(&mut cs.ns(|| "twos comp"))?;
q = Self::conditionally_select(
&mut cs.ns(|| "positive or negative"),
&positive,
&q,
&q_neg,
)?;
Ok(Self::conditionally_select(
&mut cs.ns(|| "self_or_quotient"),
&self_is_zero,
self,
&q
)?)
}
}
)*)
}
div_int_impl!(Int8 Int16 Int32 Int64);
div_int_impl!(Int8 Int16 Int32 Int64 Int128);

View File

@ -8,6 +8,8 @@ pub trait Int: Debug + Clone {
fn one() -> Self;
fn zero() -> Self;
/// Returns true if all bits in this `Int` are constant
fn is_constant(&self) -> bool;
@ -59,6 +61,10 @@ macro_rules! int_impl {
Self::constant(1 as $type_)
}
fn zero() -> Self {
Self::constant(0 as $type_)
}
fn is_constant(&self) -> bool {
let mut constant = true;

View File

@ -34,6 +34,14 @@ macro_rules! eq_gadget_impl {
Ok(result)
}
}
impl PartialEq for $gadget {
fn eq(&self, other: &Self) -> bool {
!self.value.is_none() && !other.value.is_none() && self.value == other.value
}
}
impl Eq for $gadget {}
)*)
}

View File

@ -2,5 +2,5 @@
pub mod eq;
pub use self::eq::*;
pub mod lt;
pub use self::lt::*;
pub mod cmp;
pub use self::cmp::*;

View File

@ -272,3 +272,31 @@ fn test_int8_mul() {
assert!(!cs.is_satisfied());
}
}
#[test]
fn test_int8_div_constants() {
let mut rng = XorShiftRng::seed_from_u64(1231275789u64);
for _ in 0..1 {
let mut cs = TestConstraintSystem::<Fr>::new();
let a: i8 = rng.gen();
let b: i8 = rng.gen();
println!("{} / {}", a, b);
let expected = match a.checked_div(b) {
Some(valid) => valid,
None => continue,
};
let a_bit = Int8::constant(a);
let b_bit = Int8::constant(b);
let r = a_bit.div(cs.ns(|| "division"), &b_bit).unwrap();
assert!(r.value == Some(expected));
check_all_constant_bits(expected, r);
}
}