diff --git a/compiler/src/value/char/char.rs b/compiler/src/value/char/char.rs index dd0e294cd2..11eade5c91 100644 --- a/compiler/src/value/char/char.rs +++ b/compiler/src/value/char/char.rs @@ -23,14 +23,11 @@ use crate::{ use leo_ast::{InputValue, Span}; use snarkvm_fields::PrimeField; -use snarkvm_gadgets::{ - fields::FpGadget, - utilities::{ - bits::comparator::{ComparatorGadget, EvaluateLtGadget}, - boolean::Boolean, - eq::{ConditionalEqGadget, EqGadget, EvaluateEqGadget}, - select::CondSelectGadget, - }, +use snarkvm_gadgets::utilities::{ + bits::comparator::{ComparatorGadget, EvaluateLtGadget}, + boolean::Boolean, + eq::{ConditionalEqGadget, EqGadget, EvaluateEqGadget, NEqGadget}, + select::CondSelectGadget, }; use snarkvm_r1cs::{ConstraintSystem, SynthesisError}; @@ -96,7 +93,44 @@ impl ConditionalEqGadget for Char { } fn cost() -> usize { - 2 * as CondSelectGadget>::cost() + as ConditionalEqGadget>::cost() + } +} + +impl NEqGadget for Char { + fn enforce_not_equal>(&self, cs: CS, other: &Self) -> Result<(), SynthesisError> { + self.field.enforce_not_equal(cs, &other.field) + } + + fn cost() -> usize { + as NEqGadget>::cost() + } +} + +impl CondSelectGadget for Char { + fn conditionally_select>( + cs: CS, + cond: &Boolean, + first: &Self, + second: &Self, + ) -> Result { + let field = FieldType::::conditionally_select(cs, cond, &first.field, &second.field)?; + + if field == first.field { + return Ok(Char { + character: first.character, + field, + }); + } + + Ok(Char { + character: second.character, + field, + }) + } + + fn cost() -> usize { + as CondSelectGadget>::cost() } } diff --git a/compiler/src/value/field/field_type.rs b/compiler/src/value/field/field_type.rs index f2d35bc552..ddf6177304 100644 --- a/compiler/src/value/field/field_type.rs +++ b/compiler/src/value/field/field_type.rs @@ -28,7 +28,7 @@ use snarkvm_gadgets::{ alloc::AllocGadget, bits::comparator::{ComparatorGadget, EvaluateLtGadget}, boolean::Boolean, - eq::{ConditionalEqGadget, EqGadget, EvaluateEqGadget}, + eq::{ConditionalEqGadget, EqGadget, EvaluateEqGadget, NEqGadget}, select::CondSelectGadget, uint::UInt8, ToBitsBEGadget, @@ -37,8 +37,6 @@ use snarkvm_gadgets::{ }, }; use snarkvm_r1cs::{ConstraintSystem, SynthesisError}; - -use snarkvm_gadgets::utilities::eq::NEqGadget; use std::{borrow::Borrow, cmp::Ordering}; #[derive(Clone, Debug)] diff --git a/compiler/src/value/value.rs b/compiler/src/value/value.rs index 9a71e7f5d6..1bb94ed34b 100644 --- a/compiler/src/value/value.rs +++ b/compiler/src/value/value.rs @@ -148,7 +148,7 @@ impl<'a, F: PrimeField, G: GroupType> ConditionalEqGadget for ConstrainedV bool_1.conditional_enforce_equal(cs, bool_2, condition) } (ConstrainedValue::Char(char_1), ConstrainedValue::Char(char_2)) => { - char_1.field.conditional_enforce_equal(cs, &char_2.field, condition) + char_1.conditional_enforce_equal(cs, char_2, condition) } (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { field_1.conditional_enforce_equal(cs, field_2, condition) @@ -195,7 +195,7 @@ impl<'a, F: PrimeField, G: GroupType> CondSelectGadget for ConstrainedValu ConstrainedValue::Boolean(Boolean::conditionally_select(cs, cond, bool_1, bool_2)?) } (ConstrainedValue::Char(char_1), ConstrainedValue::Char(char_2)) => { - ConstrainedValue::Field(FieldType::conditionally_select(cs, cond, &char_1.field, &char_2.field)?) + ConstrainedValue::Char(Char::conditionally_select(cs, cond, char_1, char_2)?) } (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { ConstrainedValue::Field(FieldType::conditionally_select(cs, cond, field_1, field_2)?)