~luyu/ndwfc-rs

537ade54b48479c6373b122cff59d5fc2cb09cd7 — Luyu Cheng a month ago d8932a7 main
refactor: remove a clone of `wave.keys`
2 files changed, 59 insertions(+), 68 deletions(-)

M images/output.png
M src/wfc.rs
M images/output.png => images/output.png +0 -0
M src/wfc.rs => src/wfc.rs +59 -68
@@ 80,68 80,6 @@ impl<const N: usize> WaveFunctionCollapse<N> {
        one_hot(i, self.weights.len())
    }

    // Propogate changes from the given position.
    fn propagate(&mut self, source: &Coordinate<N>) {
        let mut stack = vec![source.clone()];
        while let Some(p) = stack.pop() {
            // Iterate all adjacent positions of `p`.
            for (index, orientation) in create_directions::<N>() {
                let mut q = p.clone();
                q.apply_direction(index, orientation);
                if let Some(pattern_p) = self.wave.get(&p) {
                    if let Some(ys) = self.wavefront.get(&q) {
                        let mut modified = false;
                        for (pattern_q, y) in ys.borrow_mut().iter_mut().enumerate() {
                            if *y > 0.0
                                && !neighborable(
                                    &self.rules,
                                    index,
                                    orientation,
                                    *pattern_p,
                                    pattern_q,
                                )
                            {
                                *y = 0.0;
                                modified = true;
                            }
                        }
                        if modified {
                            stack.push(q);
                        }
                    }
                } else if let Some(xs) = self.wavefront.get(&p).map(|xs| xs.clone()) {
                    if let Some(ys) = self.wavefront.get(&q) {
                        let mut modified = false;
                        for (pattern_q, prob_y) in ys.borrow_mut().iter_mut().enumerate() {
                            if *prob_y == 0.0 {
                                continue;
                            }
                            if !xs.borrow().iter().enumerate().any(|(pattern_p, prob_x)| {
                                *prob_x > 0.0
                                    && *prob_y > 0.0
                                    && neighborable(
                                        &self.rules,
                                        index,
                                        orientation,
                                        pattern_p,
                                        pattern_q,
                                    )
                            }) {
                                *prob_y = 0.0;
                                modified = true;
                            }
                        }
                        if modified {
                            stack.push(q);
                        }
                    }
                } else {
                    println!("Invalid propagation parameter.");
                }
            }
        }
    }

    pub fn readout(&mut self, collapse: bool) -> HashMap<Coordinate<N>, usize> {
        if !collapse {
            let mut result = HashMap::new();


@@ 186,10 124,8 @@ impl<const N: usize> WaveFunctionCollapse<N> {
            self.wavefront
                .insert(coordinate, RefCell::new(vec![1.0; self.weights.len()]));
        }
        // TODO: remove clones here.
        let wave_keys = self.wave.keys().cloned().collect::<Vec<_>>();
        for coordinate in wave_keys.iter() {
            self.propagate(coordinate);
        for coordinate in self.wave.keys() {
            propagate(&self.wave, &mut self.wavefront, &self.rules, &coordinate);
        }
    }



@@ 204,7 140,7 @@ impl<const N: usize> WaveFunctionCollapse<N> {
                //     *value.borrow_mut() = vec![1.0; self.weights.len()];
                // }
                // for coordinate in self.wavefront.keys() {
                //     self.propagate(&coordinate);
                //     propagate(&self.wave, &mut self.wavefront, &self.rules, &coordinate);
                // }
                println!("Entropy is not a number.");
                return false;


@@ 227,7 163,7 @@ impl<const N: usize> WaveFunctionCollapse<N> {
            let values = self.wavefront.get(&coordinate).unwrap().borrow();
            if let Some(wave) = collapse(values, &self.weights) {
                *self.wavefront.get(&coordinate).unwrap().borrow_mut() = wave;
                self.propagate(&coordinate);
                propagate(&self.wave, &mut self.wavefront, &self.rules, &coordinate);
            } else {
                println!("collapse failed at {}", coordinate);
            }


@@ 249,6 185,61 @@ impl<const N: usize> WaveFunctionCollapse<N> {
    }
}

// Propogate changes from the given position.
fn propagate<const N: usize>(
    wave: &HashMap<Coordinate<N>, usize>,
    wavefront: &mut HashMap<Coordinate<N>, RefCell<Vec<f32>>>,
    rules: &Vec<(usize, usize, usize)>,
    source: &Coordinate<N>,
) {
    let mut stack = vec![source.clone()];
    while let Some(p) = stack.pop() {
        // Iterate all adjacent positions of `p`.
        for (index, orientation) in create_directions::<N>() {
            let mut q = p.clone();
            q.apply_direction(index, orientation);
            if let Some(pattern_p) = wave.get(&p) {
                if let Some(ys) = wavefront.get(&q) {
                    let mut modified = false;
                    for (pattern_q, y) in ys.borrow_mut().iter_mut().enumerate() {
                        if *y > 0.0
                            && !neighborable(&rules, index, orientation, *pattern_p, pattern_q)
                        {
                            *y = 0.0;
                            modified = true;
                        }
                    }
                    if modified {
                        stack.push(q);
                    }
                }
            } else if let Some(xs) = wavefront.get(&p).map(|xs| xs.clone()) {
                if let Some(ys) = wavefront.get(&q) {
                    let mut modified = false;
                    for (pattern_q, prob_y) in ys.borrow_mut().iter_mut().enumerate() {
                        if *prob_y == 0.0 {
                            continue;
                        }
                        if !xs.borrow().iter().enumerate().any(|(pattern_p, prob_x)| {
                            *prob_x > 0.0
                                && *prob_y > 0.0
                                && neighborable(&rules, index, orientation, pattern_p, pattern_q)
                        }) {
                            *prob_y = 0.0;
                            modified = true;
                        }
                    }
                    if modified {
                        stack.push(q);
                    }
                }
            } else {
                println!("Invalid propagation parameter.");
            }
        }
    }
}

fn neighborable(
    rules: &Vec<(usize, usize, usize)>,
    index: usize,