~samhsmith/quatmaths

3a1ae311495c6f0ca88492149edda0e920a7a14f — Sam H. Smith 2 years ago cc1e3a1
quaternion multiplication, conjugate and inverse
1 files changed, 95 insertions(+), 7 deletions(-)

M src/lib.rs
M src/lib.rs => src/lib.rs +95 -7
@@ 1,7 1,8 @@

use num_traits::float::Float;

struct Quaternion<T> where T : Float
#[derive(Clone, Copy, PartialEq)]
pub struct Quaternion<T> where T : Float
{
    w : T,
    x : T,


@@ 23,20 24,107 @@ impl<T> std::fmt::Display for Quaternion<T> where T : Float + std::fmt::Display
    }
}

impl<T> std::fmt::Debug for Quaternion<T> where T : Float + std::fmt::Debug
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error>
    {
        write!(f, "Quat{}[{:?}, {:?}, {:?}, {:?}]", std::any::type_name::<T>(), self.w, self.x, self.y, self.z)
    }
}

impl<T> Quaternion<T> where T : Float
{
    pub fn new(w : T, x : T, y : T, z : T) -> Self
    {
        Self { w, x, y, z }
    }
    pub fn approx_equal(a : Quaternion<T>, b : Quaternion<T>, close_enough : T) -> bool
    {
        (a.w - b.w).abs() < close_enough &&
        (a.x - b.x).abs() < close_enough &&
        (a.y - b.y).abs() < close_enough &&
        (a.z - b.z).abs() < close_enough
    }
    pub fn mul(a : Quaternion<T>, b : Quaternion<T>) -> Self
    {
        let mut c = Quaternion::<T>::default();
        c.w = a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z;
        c.x = a.x * b.w + a.w * b.x - a.z * b.y + a.y * b.z;
        c.y = a.y * b.w + a.w * b.y - a.x * b.z + a.z * b.x;
        c.z = a.z * b.w + a.w * b.z - a.y * b.x + a.x * b.y;
        c
    }
    pub fn mul_scalar(q : Quaternion<T>, s : T) -> Self
    {
        let mut f = Quaternion::<T>::default();
        f.w = q.w * s;
        f.x = q.x * s;
        f.y = q.y * s;
        f.z = q.z * s;
        f
    }
    pub fn div_scalar(q : Quaternion<T>, s : T) -> Self
    {
        let mut f = Quaternion::<T>::default();
        f.w = q.w / s;
        f.x = q.x / s;
        f.y = q.y / s;
        f.z = q.z / s;
        f
    }
    pub fn conj(mut self) -> Self
    {
        self.x = -self.x;
        self.y = -self.y;
        self.z = -self.z;
        self
    }
    pub fn inverse(mut self) -> Self
    {
        // the inverse is equal to the conjugate divided by the conjugate multiplied by the original   q' = q*/(q* x q)
        let mut divide_by = Quaternion::mul(self, self.conj()).w;
        Quaternion::div_scalar(self.conj(), divide_by)
    }
}

#[cfg(test)]
mod tests {
    use crate::Quaternion;

    #[test]
    fn it_works() {
        let result = 2 + 2;
        assert_eq!(result, 4);
    }

    #[test]
    fn test_display_and_default_identity()
    {
        assert_eq!("Quatf32[1, 0, 0, 0]", format!("{}", Quaternion::<f32>::default()));
        assert_eq!("Quatf64[1, 0, 0, 0]", format!("{}", Quaternion::<f64>::default()));
    }

    #[test]
    fn test_display_and_constructors()
    {
        assert_eq!("Quatf32[1, 2.1, 3, 4]", format!("{}", Quaternion::<f32>::new(1.0, 2.1, 3.0, 4.0)));
    }

    #[test]
    fn test_quaternion_multiplication()
    {
        assert_eq!(Quaternion::<f64>::mul(Quaternion::new(1.0,2.0,3.0,4.0), Quaternion::new(4.0,3.0,2.0,1.0)),
            Quaternion::<f64>::new(-12.0,6.0,24.0,12.0));

        assert_eq!(Quaternion::<f64>::mul(Quaternion::new(4.0,3.0,2.0,1.0), Quaternion::new(1.0,2.0,3.0,4.0)),
            Quaternion::<f64>::new(-12.0,16.0,4.0,22.0));

        assert_eq!(Quaternion::<f32>::new(1.0, 2.0, 3.0, 4.0), Quaternion::<f32>::mul_scalar(Quaternion::new(2.0, 4.0, 6.0, 8.0), 0.5));
    }

    #[test]
    fn test_quaternion_conjugate_and_inverse()
    {
        // multiplying a quaternion by it's conjugate gives us w² + x² + y² + z² as a real number in the w slot
        assert_eq!(Quaternion::<f32>::mul(Quaternion::new(4.0,3.0,2.0,1.0), Quaternion::new(4.0,3.0,2.0,1.0).conj()), Quaternion::new(30.0,0.0,0.0,0.0));

        // multiplying with a quaternion and then it's inverse should result in no change save for floating point errors
        assert!(Quaternion::approx_equal(
            Quaternion::<f64>::mul(Quaternion::mul(Quaternion::new(4.0,3.0,2.0,1.0), Quaternion::new(1.0,2.0,3.0,4.0)), Quaternion::new(1.0,2.0,3.0,4.0).inverse()),
            Quaternion::<f64>::new(4.0,3.0,2.0,1.0), 0.01));
    }
}