~mht/cra

e4ca49b32bbf6cfa17a2c76764e74e9ce1592aac — Martin Hafskjold Thoresen 2 years ago 8ad9e86
Don't needlessly allocate in `column_add`.

Probably still a  bit of allocation going on, when the simplices vectors
are too small.
2 files changed, 153 insertions(+), 173 deletions(-)

M cra/Cargo.toml
M cra/src/main.rs
M cra/Cargo.toml => cra/Cargo.toml +3 -0
@@ 5,3 5,6 @@ authors = ["Martin Hafskjold Thoresen <git@mht.technology>"]

[dependencies]
time = "0.1.42"

[profile.release]
debug = true

M cra/src/main.rs => cra/src/main.rs +150 -173
@@ 1,5 1,3 @@
#![allow(dead_code)]

use std::cmp::Ordering::*;
use std::collections::HashSet;
use std::fmt::Write as fmtWrite;


@@ 12,8 10,8 @@ const R_CUTOFF: f64 = 300.0;

fn f64_eq(a: f64, b: f64) -> bool {
    let u = 1_000_000.0;
    let a = (a * u) as usize;
    let b = (b * u) as usize;
    let a = (a * u).round() as usize;
    let b = (b * u).round() as usize;
    a == b
    // limit is ~0.00119
    // (a - b).abs() < 10_000.0 * std::f64::EPSILON


@@ 135,14 133,14 @@ impl Statistics {
            format!("{}", num_iters),
        ));

        let xor_cost =
        let column_add_cost =
            self.add_size_sum.iter().sum::<usize>() as f64 / self.add_size_sum.len() as f64;
        print_pairs.push((
            "Average cost of one column addition",
            format!("{}", xor_cost),
            format!("{}", column_add_cost),
        ));

        let adds = (xor_cost * self.col_adds as f64 + self.pops as f64) as usize;
        let adds = (column_add_cost * self.col_adds as f64 + self.pops as f64) as usize;
        print_pairs.push(("Estimate of number of adds", thousand_split(adds)));

        if self.ex_reductions.len() > 0 {


@@ 202,26 200,6 @@ fn thousand_split(mut num: usize) -> String {
    out
}

fn eprint_thousand_split(mut num: usize) {
    let mut nums = vec![];
    while num > 0 {
        nums.push(num % 1_000);
        num /= 1_000;
    }
    for i in (1..nums.len()).rev() {
        if i == nums.len() - 1 {
            eprint!("{},", nums[i]);
        } else {
            eprint!("{:03},", nums[i]);
        }
    }
    if nums.len() == 1 {
        eprint!("{}", nums[0]);
    } else {
        eprint!("{:03}", nums[0]);
    }
}

/// One simplex
#[derive(Debug, Clone)]
pub struct Simplex {


@@ 248,8 226,7 @@ pub struct Persistence {
    points: Vec<[f64; 2]>,
}

type Error = Box<std::error::Error>;
pub fn read_input_stdin2() -> Result<Persistence, Error> {
pub fn read_input_stdin2() -> Result<Persistence, Box<std::error::Error>> {
    let stdin = std::io::stdin();
    let mut lines = BufReader::new(stdin.lock()).lines();



@@ 506,9 483,10 @@ pub fn reduce(p: &Persistence, exhaustive: bool, stats: &mut Statistics) -> Vec<
                        // from the back and check if there is a simplex with that low.
                        let mut iter = 0;
                        'search: loop {
                            let list_len = simplex_with_low[low].get().len();
                            let this = simplex_with_low[low].get();
                            let list_len = this.len();
                            for (iter_i, face_i) in (0..list_len).rev().enumerate() {
                                let this_index = simplex_with_low[low].get()[face_i];
                                let this_index = this[face_i];
                                assert!(this_index != low);
                                let other = &simplex_with_low[this_index];
                                if other.is_null() {


@@ 528,11 506,10 @@ pub fn reduce(p: &Persistence, exhaustive: bool, stats: &mut Statistics) -> Vec<
                                    ours.remove(face_i); // Sort of a `pop`
                                }
                                if other.is_list() {
                                    xor(ours, other.get(), stats);
                                    column_add(ours, other.get(), stats);
                                }
                                iter += 1;
                                stats.ex_search(iter_i + 1);
                                // if ours.is_empty() { break; }
                                continue 'search;
                            }
                            stats.ex_search(list_len);


@@ 541,7 518,7 @@ pub fn reduce(p: &Persistence, exhaustive: bool, stats: &mut Statistics) -> Vec<
                        }
                        let curr = &mut simplex_with_low[low];
                        if curr.get().len() == 0 {
                            curr.implicit();
                            *curr = Implicit;
                        }
                    }
                }


@@ 554,7 531,7 @@ pub fn reduce(p: &Persistence, exhaustive: bool, stats: &mut Statistics) -> Vec<
                // Add `low` to `j`.
                let other = simplex_with_low[low].get();
                assert!(*other.iter().max().unwrap() <= low);
                xor(&mut current_simplex, other, stats);
                column_add(&mut current_simplex, other, stats);
                pop(&mut current_simplex, stats); // xor out the implicit value as well.
            }



@@ 624,14 601,6 @@ impl IndList {
        }
    }

    fn null(&mut self) {
        *self = IndList::Null;
    }

    fn implicit(&mut self) {
        *self = IndList::Null;
    }

    fn get(&self) -> &Vec<usize> {
        match self {
            List(ref v) => v,


@@ 652,13 621,21 @@ fn pop(this: &mut Vec<usize>, stats: &mut Statistics) {
    stats.pop();
}

fn xor(this: &mut Vec<usize>, other: &Vec<usize>, stats: &mut Statistics) {
fn column_add(this: &mut Vec<usize>, other: &Vec<usize>, stats: &mut Statistics) {
    // TODO: hehe
    static mut BUFFER: Option<Vec<usize>> = None;
    let buffer: &mut Vec<usize> = unsafe {
        if BUFFER.is_none() {
            BUFFER = Some(Vec::with_capacity(1_000));
        }
        BUFFER.as_mut().unwrap()
    };
    assert_eq!(buffer.len(), 0);
    stats.col_adds += 1;
    stats.add_sizes(this.len(), other.len());
    // For now, let's just walk through both vecs trying to find matches, add new ones into
    // self, and sort at the end.
    // PERF: Allocation in here! Can we do without?
    let mut new = Vec::new();
    let mut our_i = 0;
    let mut their_i = 0;
    let our_last = this.len();


@@ 668,20 645,22 @@ fn xor(this: &mut Vec<usize>, other: &Vec<usize>, stats: &mut Statistics) {
    while our_i < our_last && their_i < their_last {
        i += 1;
        if this[our_i] < other[their_i] {
            new.push(this[our_i]);
            buffer.push(this[our_i]);
            our_i += 1;
        } else if this[our_i] == other[their_i] {
            our_i += 1;
            their_i += 1;
        } else {
            new.push(other[their_i]);
            buffer.push(other[their_i]);
            their_i += 1;
        }
    }
    stats.add_iters(i);
    new.extend(&other[their_i..]);
    new.extend(&this[our_i..]);
    *this = new;
    buffer.extend(&other[their_i..]);
    buffer.extend(&this[our_i..]);
    this.clear();
    this.extend(&*buffer);
    buffer.clear();
}

#[derive(Copy, Clone, Debug)]


@@ 723,129 702,6 @@ impl Point {
    }
}

fn main() {
    let mut persistence = read_input_stdin2().unwrap();

    for (i, s) in persistence.simplices.iter().enumerate() {
        assert_eq!(i, s.j);
    }

    persistence.simplices.sort_by(|a, b| {
        if b.faces.contains(&a.j) {
            Less
        } else if a.faces.contains(&b.j) {
            Greater
        } else if f64_eq(a.r_when_born, b.r_when_born) {
            a.dim().cmp(&b.dim())
        } else {
            a.r_when_born.partial_cmp(&b.r_when_born).unwrap()
        }
    });

    // Map j to sorted index.
    let sorted_index_of_j = {
        let mut v = (0..persistence.simplices.len()).collect::<Vec<_>>();
        for (i, s) in persistence.simplices.iter().enumerate() {
            v[s.j] = i;
        }
        v
    };

    // Change all `r-values` to be in the sorted format, such that `simplex[a].j == a`.
    for (j, s) in persistence.simplices.iter_mut().enumerate() {
        for face in s.faces.iter_mut() {
            *face = sorted_index_of_j[*face];
        }
        s.j = j;
    }

    for (j, s) in persistence.simplices.iter().enumerate() {
        if let Some(&face) = s.faces.iter().max() {
            if face >= j {
                eprintln!("face is after simplex! {} >= {}", face, j);
                eprintln!("{:?}", persistence.simplices[face]);
                eprintln!("{:?}", persistence.simplices[j]);
                eprintln!(
                    "{:#?}",
                    &persistence.simplices
                        [(j.saturating_sub(3)..((face + 4).min(persistence.simplices.len())))]
                );
                panic!("The ordering is not right!");
            }
        }
    }

    let mut r_stats = Statistics::new();
    r_stats.time();
    let _pairings = reduce(&persistence, false, &mut r_stats);

    eprintln!("## Statistics for the =Regular= variant ##");
    r_stats.eprint_time();
    r_stats.eprint_avg();
    eprintln!("\n");

    let mut e_stats = Statistics::new();
    e_stats.time();
    let pairings = reduce(&persistence, true, &mut e_stats);

    eprintln!("## Statistics for the =Exhaustive= variant ##");
    e_stats.eprint_time();
    e_stats.eprint_avg();
    eprintln!("\n");

    output_histogram(
        &e_stats.ex_searches,
        "Iters for finding k=low(i) for any i",
        "ex_searches.pdf",
    );
    output_histogram(
        &e_stats.ex_reductions,
        "Iters for exhaustively reducing a column",
        "ex_reductions.pdf",
    );

    output_2histogram(
        &r_stats.add_iters,
        &e_stats.add_iters,
        "Loop iterations in 'xor'",
        "add_iters.pdf",
    );
    output_2histogram(
        &r_stats.add_size_sum,
        &e_stats.add_size_sum,
        "Cost estimate for column addition",
        "add_size_sum.pdf",
    );
    output_2histogram(
        &r_stats.num_iters,
        &e_stats.num_iters,
        "Number of iterations for reducing a column",
        "num_iters.pdf",
    );

    // eprintln!("{:#?}", persistence.simplices);
    // eprintln!("{:?}", pairings);

    for (a, b) in pairings {
        let birth = persistence.simplices[a].r_when_born;
        let death = persistence.simplices[b].r_when_born;
        // eprintln!("{}", birth);
        if birth > death && !f64_eq(death, birth) {
            panic!(
                "Birth cannot be _after_ death! {} {} diff={} {:?} {:?}",
                death,
                birth,
                (death - birth),
                persistence.simplices[a],
                persistence.simplices[b],
            );
        }
        println!("pair({}, {})", birth, death);
    }

    output_svg(&persistence);
}

fn output_svg(persistence: &Persistence) {
    let mut f = std::io::BufWriter::new(std::fs::File::create("lol.svg").unwrap());



@@ 1089,6 945,127 @@ plot 'kek.freq' using ($0 - 0.25):($1+1) with boxes lc rgb"gray20" title "Regula
    run_gnuplot(&plot_script, out_file);
}

fn main() {
    let mut persistence = read_input_stdin2().unwrap();

    for (i, s) in persistence.simplices.iter().enumerate() {
        assert_eq!(i, s.j);
    }

    persistence.simplices.sort_by(|a, b| {
        if b.faces.contains(&a.j) {
            Less
        } else if a.faces.contains(&b.j) {
            Greater
        } else if f64_eq(a.r_when_born, b.r_when_born) {
            a.dim().cmp(&b.dim())
        } else {
            a.r_when_born.partial_cmp(&b.r_when_born).unwrap()
        }
    });

    // Map j to sorted index.
    let sorted_index_of_j = {
        let mut v = (0..persistence.simplices.len()).collect::<Vec<_>>();
        for (i, s) in persistence.simplices.iter().enumerate() {
            v[s.j] = i;
        }
        v
    };

    // Change all `r-values` to be in the sorted format, such that `simplex[a].j == a`.
    for (j, s) in persistence.simplices.iter_mut().enumerate() {
        for face in s.faces.iter_mut() {
            *face = sorted_index_of_j[*face];
        }
        s.j = j;
    }

    for (j, s) in persistence.simplices.iter().enumerate() {
        if let Some(&face) = s.faces.iter().max() {
            if face >= j {
                eprintln!("face is after simplex! {} >= {}", face, j);
                eprintln!("{:?}", persistence.simplices[face]);
                eprintln!("{:?}", persistence.simplices[j]);
                eprintln!(
                    "{:#?}",
                    &persistence.simplices
                        [(j.saturating_sub(3)..((face + 4).min(persistence.simplices.len())))]
                );
                panic!("The ordering is not right!");
            }
        }
    }


    let mut e_stats = Statistics::new();
    e_stats.time();
    let pairings = reduce(&persistence, true, &mut e_stats);

    eprintln!("## Statistics for the =Exhaustive= variant ##");
    e_stats.eprint_time();
    e_stats.eprint_avg();
    eprintln!("\n");

    let mut r_stats = Statistics::new();
    r_stats.time();
    let pairings = reduce(&persistence, false, &mut r_stats);

    eprintln!("## Statistics for the =Regular= variant ##");
    r_stats.eprint_time();
    r_stats.eprint_avg();
    eprintln!("\n");

    output_histogram(
        &e_stats.ex_searches,
        "Iters for finding k=low(i) for any i",
        "ex_searches.pdf",
    );
    output_histogram(
        &e_stats.ex_reductions,
        "Iters for exhaustively reducing a column",
        "ex_reductions.pdf",
    );

    output_2histogram(
        &r_stats.add_iters,
        &e_stats.add_iters,
        "Loop iterations in 'column\\_add'",
        "add_iters.pdf",
    );
    output_2histogram(
        &r_stats.add_size_sum,
        &e_stats.add_size_sum,
        "Cost estimate for column addition",
        "add_size_sum.pdf",
    );
    output_2histogram(
        &r_stats.num_iters,
        &e_stats.num_iters,
        "Number of iterations for reducing a column",
        "num_iters.pdf",
    );

    for (a, b) in pairings {
        let birth = persistence.simplices[a].r_when_born;
        let death = persistence.simplices[b].r_when_born;
        // eprintln!("{}", birth);
        if birth > death && !f64_eq(death, birth) {
            panic!(
                "Birth cannot be _after_ death! {} {} diff={} {:?} {:?}",
                death,
                birth,
                (death - birth),
                persistence.simplices[a],
                persistence.simplices[b],
            );
        }
        println!("pair({}, {})", birth, death);
    }

    output_svg(&persistence);
}

#[cfg(test)]
mod test {
    use super::*;