Skip to main content

dfir_lang/graph/
graph_algorithms.rs

//! General graph algorithm utility functions
2
3use std::collections::{BTreeSet, HashMap, HashSet};
4use std::hash::Hash;
5
6use slotmap::{Key, SecondaryMap, SparseSecondaryMap};
7
/// Orders `node_ids` (and every node reachable from them through `preds_fn`) so that each node
/// appears after all of its predecessors.
///
/// `preds_fn(id)` yields the direct predecessors of `id`. The result is a depth-first post-order
/// over predecessor edges, starting from the roots in the order `node_ids` yields them.
///
/// # Errors
///
/// If the graph is not a directed acyclic graph (DAG), returns `Err` holding one cycle that was
/// encountered. Every node of that cycle is listed exactly once, ordered along the cycle's edges.
///
/// <https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search>
pub fn topo_sort<Id, PredsIter>(
    node_ids: impl IntoIterator<Item = Id>,
    mut preds_fn: impl FnMut(Id) -> PredsIter,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Hash,
    PredsIter: IntoIterator<Item = Id>,
{
    /// Depth-first visitation state of a node.
    #[derive(Clone, Copy)]
    enum Mark {
        /// The node is on the current DFS path (its visit has not finished yet).
        InProgress,
        /// The node and all of its predecessors have already been emitted.
        Done,
    }

    /// Emits every not-yet-visited predecessor of `node`, then `node` itself, into `out`
    /// (post-order).
    ///
    /// When a cycle is detected, `out` is reused as a cycle trace instead. It is cleared and then
    /// filled, while the recursion unwinds, as `[r, .., r]`, where `r` is the node that was
    /// re-entered.
    fn visit<Id, PredsIter>(
        node: Id,
        preds_fn: &mut impl FnMut(Id) -> PredsIter,
        marks: &mut HashMap<Id, Mark>,
        out: &mut Vec<Id>,
    ) -> Result<(), ()>
    where
        Id: Copy + Eq + Hash,
        PredsIter: IntoIterator<Item = Id>,
    {
        match marks.get(&node).copied() {
            Some(Mark::Done) => return Ok(()),
            Some(Mark::InProgress) => {
                // Re-entered a node that is still on the DFS path: this is a cycle.
                out.clear();
                out.push(node);
                return Err(());
            }
            None => {}
        }

        marks.insert(node, Mark::InProgress);
        for pred in preds_fn(node) {
            if visit(pred, preds_fn, marks, out).is_err() {
                // Keep recording the unwinding path until the trace has come back around to `r`.
                let trace_closed = out.len() > 1 && out.first() == out.last();
                if !trace_closed {
                    out.push(node);
                }
                return Err(());
            }
        }
        out.push(node);
        marks.insert(node, Mark::Done);
        Ok(())
    }

    let mut marks = HashMap::new();
    let mut sorted = Vec::new();
    for root in node_ids {
        if visit(root, &mut preds_fn, &mut marks, &mut sorted).is_err() {
            // `sorted` now holds the trace `[r, .., r]`. Drop everything up to and including the
            // leading `r` so each cycle member appears exactly once.
            let repeated = *sorted.last().unwrap();
            let start = sorted.iter().position(|&n| n == repeated).unwrap();
            sorted.drain(..=start);
            return Err(sorted);
        }
    }
    Ok(sorted)
}
73
74/// Datastructure for merging subgraphs while maintaining topological sort order.
75///
76/// Maintains a global topo-sorted Vec of all operators. Each subgraph (merged group)
77/// occupies a contiguous range in this Vec. Merging two groups combines their ranges
78/// and re-sorts the affected window so groups remain contiguous and correctly ordered.
79pub struct SubgraphMerge<K>
80where
81    K: Key,
82{
83    /// Predecessor edges in the quotient DAG (per representative).
84    subgraph_preds: SecondaryMap<K, Vec<K>>,
85    /// All operators in global topo-sort order (fixed length, reshuffled in windows).
86    /// Invariant: subgraphs are contiguous & non-overlapping ranges in this vec.
87    toposort_node: Vec<K>,
88    /// Reverse index: SG representative node -> index (in toposort_node).
89    /// Invariant: `K` is both the representative node and the first node in the SG.
90    sg_idx: SparseSecondaryMap<K, usize>,
91    /// SG representative node -> SG len.
92    /// The subgraph's nodes are `toposort_node[index..index+len]`.
93    /// Invariant: the subgraph ranges are complete and non-overlapping.
94    sg_len: SparseSecondaryMap<K, usize>,
95
96    /// Union-find for subgraph membership.
97    subgraph_unionfind: crate::union_find::UnionFind<K>,
98
99    /// Per-representative: set of representatives that this subgraph must not merge with.
100    /// Maintained symmetrically: if `enemies[a]` contains `b`, then `enemies[b]` contains `a`.
101    enemies: SecondaryMap<K, HashSet<K>>,
102}
103
104impl<K> SubgraphMerge<K>
105where
106    K: Key,
107{
108    /// Creates a new `SubgraphMerge` from nodes and their predecessor edges.
109    ///
110    /// `enemies` specifies pairs of nodes that must never be placed in the same subgraph.
111    /// These are checked in O(1) during [`Self::try_merge`] and maintained as representatives
112    /// change.
113    ///
114    /// Returns `Err` with a cycle if the input graph is not a DAG.
115    pub fn new<PredsIter>(
116        keys: impl IntoIterator<Item = K>,
117        mut preds_fn: impl FnMut(K) -> PredsIter,
118        enemies_iter: impl IntoIterator<Item = (K, K)>,
119    ) -> Result<Self, Vec<K>>
120    where
121        PredsIter: IntoIterator<Item = K>,
122    {
123        let subgraph_preds = keys
124            .into_iter()
125            .map(|k| (k, (preds_fn)(k).into_iter().collect()))
126            .collect::<SecondaryMap<K, Vec<K>>>();
127        let toposort_node =
128            topo_sort(subgraph_preds.keys(), |k| subgraph_preds[k].iter().copied())?;
129        let sg_idx = toposort_node
130            .iter()
131            .enumerate()
132            .map(|(i, &k)| (k, i))
133            .collect();
134        let sg_len = toposort_node.iter().map(|&k| (k, 1)).collect();
135        let subgraph_unionfind = crate::union_find::UnionFind::with_capacity(toposort_node.len());
136
137        let mut enemies = SecondaryMap::<K, HashSet<K>>::new();
138        for (a, b) in enemies_iter {
139            assert_ne!(a, b, "no-merge pair must not contain the same node twice");
140            enemies.entry(a).unwrap().or_default().insert(b);
141            enemies.entry(b).unwrap().or_default().insert(a);
142        }
143
144        Ok(Self {
145            subgraph_preds,
146            toposort_node,
147            sg_idx,
148            sg_len,
149            subgraph_unionfind,
150            enemies,
151        })
152    }
153
154    /// Find the representative of the subgraph containing `k`.
155    pub fn find(&mut self, k: K) -> K {
156        self.subgraph_unionfind.find(k)
157    }
158
159    /// Returns true if `u` and `v` are in the same subgraph.
160    pub fn same_set(&mut self, u: K, v: K) -> bool {
161        self.subgraph_unionfind.same_set(u, v)
162    }
163
164    /// Iterates all subgraph representatives with their topo-sorted operator slices,
165    /// in topological order (by position in `toposort_node`).
166    pub fn subgraphs(&self) -> impl Iterator<Item = &[K]> {
167        let mut i = 0;
168        std::iter::from_fn(move || {
169            let Some(&sg_node) = self.toposort_node.get(i) else {
170                debug_assert_eq!(i, self.toposort_node.len());
171                return None;
172            };
173            debug_assert_eq!(i, self.sg_idx[sg_node]);
174            let sg_len = self.sg_len[sg_node];
175            let sg_slice = &self.toposort_node[i..i + sg_len];
176            i += sg_len;
177            Some(sg_slice)
178        })
179    }
180
181    /// Attempts to merge the subgraphs containing `u` and `v`.
182    /// Returns `false` if merging would create a cycle in the subgraph DAG,
183    /// or if the merge is forbidden by a no-merge constraint.
184    pub fn try_merge(&mut self, u: K, v: K) -> bool {
185        // 0. Set up `u` and `v` to be in order, and subgraph representatives.
186
187        // Ensure `u` and `v` are subgraph representatives.
188        let u = self.subgraph_unionfind.find(u);
189        let v = self.subgraph_unionfind.find(v);
190        if u == v {
191            // Short circuit no-op case. Guards against weird `u == v` aliasing.
192            return true;
193        }
194
195        // O(1) no-merge constraint check.
196        if self
197            .enemies
198            .get(u)
199            .is_some_and(|enemy_set| enemy_set.contains(&v))
200        {
201            return false;
202        }
203
204        // Ensure `u` is before `v` in topo order.
205        let (u, v) = if self.sg_idx[u] < self.sg_idx[v] {
206            (u, v)
207        } else {
208            (v, u)
209        };
210        // Get the member nodes of `u` and `v`, and the `window`. Pulling references here does ensure that
211        // `toposort_node` remains unchanged until we properly merge `u_nodes` and `v_nodes`.
212        let (u_nodes, v_nodes, window) = {
213            let (u_idx, u_len) = (self.sg_idx[u], self.sg_len[u]);
214            let (v_idx, v_len) = (self.sg_idx[v], self.sg_len[v]);
215            (
216                &self.toposort_node[u_idx..u_idx + u_len],
217                &self.toposort_node[v_idx..v_idx + v_len],
218                u_idx..v_idx + v_len,
219            )
220        };
221
222        // 1. Cycle check: can `v` reach `u` via predecessor edges?
223        // Only groups within `window` can be on such a path. Direct predecessor edges from `v` to `u` become
224        // self-loops after merge and are not real cycles, so we skip direct `u -> v` edges.
225
226        let mut stack = vec![v];
227        let mut visited = HashSet::<_>::from_iter([v]);
228
229        while let Some(x) = stack.pop() {
230            for &p in self.subgraph_preds[x].iter() {
231                let root_p = self.subgraph_unionfind.find(p);
232
233                if root_p == u {
234                    if x == v {
235                        // Ignore `u -> v` direct edge, not a real cycle.
236                        continue;
237                    }
238                    // Cycle found, return false.
239                    return false;
240                }
241
242                // Prune: group must be within the `window`.
243                if window.contains(&self.sg_idx[root_p]) && visited.insert(root_p) {
244                    stack.push(root_p);
245                }
246            }
247        }
248
249        // 2. Perform merge in union-find and append predecessors.
250        // `u` will be the new representative.
251        {
252            // `UnionFind::union` ensures the first arg's representative will represent the new merged group. `u` is before
253            // `v` in the topo order, and `u` is already its own representative. This ensures that `u` stays at the *start*
254            // of its subgraph group, so the `idx..idx+len` slice is the whole subgraph.
255            let _new_root = self.subgraph_unionfind.union(u, v);
256            debug_assert_eq!(u, _new_root);
257            let v_preds = &mut self.subgraph_preds.remove(v).unwrap();
258            let u_preds = &mut self.subgraph_preds[u];
259            u_preds.append(v_preds);
260            // Update all preds to be representatives (from past unioning). Delete any self-edges.
261            u_preds.retain_mut(|x| {
262                *x = self.subgraph_unionfind.find(*x);
263                *x != u // Retain only non-self edges.
264            });
265            // Remove any duplicates (may have be created from past unioning).
266            u_preds.sort_unstable();
267            u_preds.dedup();
268        }
269        // Remove subsumed `v` and grow `u`'s length.
270        {
271            self.sg_idx.remove(v).unwrap();
272            let v_len = self.sg_len.remove(v).unwrap();
273            // Set `u`'s len to the combined size. (Note: `sg_idx[u]` still needs updating, below after re-sort).
274            self.sg_len[u] += v_len;
275        }
276        // Merge enemies: remap v's enemies to point to u.
277        for w in self.enemies.remove(v).into_iter().flatten() {
278            debug_assert_ne!(
279                w, u,
280                "`w` in an enemy of `v`, so it can't be `w == u`, since we are merging `u` and `v`"
281            );
282            // Add `w`` to `u`'s enemies.
283            self.enemies.entry(u).unwrap().or_default().insert(w);
284            // Add `u` to `w`'s enemies. Remove `v`.
285            // `w` enemies guaranteed to exist by the symmetric invariant: if `v`'s enemies contain `w``, then `w`'s
286            // enemies contain `v`.
287            let w_enemies = self.enemies.get_mut(w).unwrap();
288            let _removed = w_enemies.remove(&v);
289            debug_assert!(_removed);
290            w_enemies.insert(u);
291        }
292
293        // 3. Re-sort groups in `window`.
294        // Topo-sort groups in the window by their quotient edges.
295        {
296            let sorted_groups = {
297                let reps_in_window = self.toposort_node[window.clone()]
298                    .iter()
299                    .map(|&k| self.subgraph_unionfind.find(k))
300                    .collect::<BTreeSet<_>>();
301
302                // We borrow fields separately to allow the closure to call `find()` (which needs `&mut`) while also reading
303                // `subgraph_preds` and `sg_idx` (via `&`).
304                // Only predecessor groups whose range overlaps the window are included - groups entirely outside the window
305                // have their ordering already satisfied.
306                let subgraph_preds = &self.subgraph_preds;
307                let subgraph_unionfind = &mut self.subgraph_unionfind;
308                let sg_idx = &self.sg_idx;
309                topo_sort(reps_in_window, |k| {
310                    subgraph_preds[k]
311                    .iter()
312                    .map(|&p| subgraph_unionfind.find(p))
313                    .filter(|&p| window.contains(&sg_idx[p])) // Prune to window.
314                    .collect::<Vec<_>>()
315                    .into_iter()
316                })
317                .expect("bug: cycle check passed but re-toposort found cycle")
318            };
319
320            // Rebuild the window: lay out each group's operators in sorted group order.
321            // All groups except `u` (new root) have contiguous operators at their current range. `u`'s operators will be
322            // `u_nodes` *and* `v_nodes`.
323            let mut buf = Vec::with_capacity(window.len());
324            for &group in &sorted_groups {
325                if group == u {
326                    buf.extend_from_slice(u_nodes);
327                    buf.extend_from_slice(v_nodes);
328                } else {
329                    let g_idx = self.sg_idx[group];
330                    let g_len = self.sg_len[group];
331                    buf.extend_from_slice(&self.toposort_node[g_idx..g_idx + g_len]);
332                }
333            }
334            self.toposort_node[window.clone()].copy_from_slice(&buf);
335
336            // Update reverse index `sg_idx` start positions (`sg_len` already correct).
337            let mut pos = window.start;
338            for &group in &sorted_groups {
339                self.sg_idx[group] = pos;
340                pos += self.sg_len[group];
341            }
342            debug_assert_eq!(window.end, pos);
343        }
344
345        true
346    }
347}
348
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, BTreeSet};

    use itertools::Itertools;
    use slotmap::SlotMap;

    use super::*;

    #[test]
    pub fn test_toposort() {
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];

        // https://commons.wikimedia.org/wiki/File:Directed_acyclic_graph_2.svg
        let nodes = [2, 3, 5, 7, 8, 9, 10, 11];
        let sort = topo_sort(nodes, |v| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        })
        .unwrap_or_else(|cycle| panic!("Did not expect cycle: {:?}", cycle));

        // Every node appears exactly once, and every edge points forward in the order.
        assert_eq!(nodes.len(), sort.len());
        let position: BTreeMap<_, _> = sort.iter().enumerate().map(|(i, &x)| (x, i)).collect();
        assert_eq!(nodes.len(), position.len());
        for (src, dst) in edges.iter() {
            assert!(position[src] < position[dst]);
        }
    }

    #[test]
    pub fn test_toposort_empty() {
        let sort = topo_sort(std::iter::empty::<u32>(), |_| Vec::<u32>::new());
        assert_eq!(Ok(Vec::new()), sort);
    }

    #[test]
    pub fn test_toposort_cycle() {
        // https://commons.wikimedia.org/wiki/File:Directed_graph,_cyclic.svg
        //          ┌────►C──────┐
        //          │            │
        //          │            ▼
        // A───────►B            E ─────►F
        //          ▲            │
        //          │            │
        //          └─────D◄─────┘
        let edges = [
            ('A', 'B'),
            ('B', 'C'),
            ('C', 'E'),
            ('D', 'B'),
            ('E', 'F'),
            ('E', 'D'),
        ];
        let ids = edges
            .iter()
            .flat_map(|&(a, b)| [a, b])
            .collect::<BTreeSet<_>>();
        let cycle_rotations = BTreeSet::from_iter([
            ['B', 'C', 'E', 'D'],
            ['C', 'E', 'D', 'B'],
            ['E', 'D', 'B', 'C'],
            ['D', 'B', 'C', 'E'],
        ]);

        let permutations = ids.iter().copied().permutations(ids.len());
        for permutation in permutations {
            let result = topo_sort(permutation.iter().copied(), |v| {
                edges
                    .iter()
                    .filter(move |&&(_, dst)| v == dst)
                    .map(|&(src, _)| src)
            });
            assert!(result.is_err());
            let cycle = result.unwrap_err();
            assert!(
                cycle_rotations.contains(&*cycle),
                "cycle: {:?}, vertex order: {:?}",
                cycle,
                permutation
            );
        }
    }

    #[test]
    pub fn test_toposort_self_loop() {
        // A node that is its own predecessor is a cycle of length one.
        let sort = topo_sort([1], |v| if v == 1 { vec![1] } else { vec![] });
        assert_eq!(Err(vec![1]), sort);
    }

    #[test]
    pub fn test_toposort_cycle_behind_tail() {
        // 1 ◄──► 2 ───► 3
        // Starting from `3`, the tail node must not be reported as part of the cycle.
        let edges = [(1, 2), (2, 1), (2, 3)];
        let result = topo_sort([3], |v| {
            edges
                .iter()
                .filter(move |&&(_, dst)| v == dst)
                .map(|&(src, _)| src)
        });
        let cycle = result.unwrap_err();
        assert!(cycle == [1, 2] || cycle == [2, 1], "cycle: {:?}", cycle);
    }

    #[test]
    pub fn test_subgraph_merge_basic() {
        let mut preds = SlotMap::new();

        let a = preds.insert(vec![]);
        let b = preds.insert(vec![]);
        let c = preds.insert(vec![]);
        let d = preds.insert(vec![]);
        let e = preds.insert(vec![]);
        let f = preds.insert(vec![]);

        preds[b].push(a);
        preds[c].push(b);
        preds[d].push(b);
        preds[e].push(c);
        preds[e].push(d);
        preds[f].push(e);

        let mut merge = SubgraphMerge::new(
            preds.keys(),
            |v| preds[v].iter().copied(),
            std::iter::empty(),
        )
        .unwrap();

        assert!(merge.try_merge(a, a)); // No-op.
        //        ┌──► C ──┐
        //        │        ▼
        // A ───► B        E ───► F
        //        │        ▲
        //        └──► D ──┘
        assert!(merge.try_merge(b, c));
        assert!(merge.try_merge(b, c)); // No-op.
        // A ───► BC ────► E ───► F
        //        │        ▲
        //        └──► D ──┘
        assert!(!merge.try_merge(c, e)); // Rejected due to `D` outside-cycle.

        assert!(merge.try_merge(d, e));
        assert!(merge.try_merge(c, e)); // Now valid since `D` is no longer outside.
    }

    #[test]
    pub fn test_subgraph_merge_reorders_window() {
        let mut preds = SlotMap::new();

        // A       X ───► C
        let a = preds.insert(vec![]);
        let x = preds.insert(vec![]);
        let c = preds.insert(vec![]);

        preds[c].push(x);

        let mut merge = SubgraphMerge::new(
            preds.keys(),
            |v| preds[v].iter().copied(),
            std::iter::empty(),
        )
        .unwrap();

        assert!(merge.try_merge(a, c));
        assert!(merge.same_set(a, c));

        // `X` is a predecessor of `C`, so its subgraph must come before the merged `AC` subgraph,
        // and every subgraph must still be laid out contiguously.
        let subgraphs = merge
            .subgraphs()
            .map(|sg| sg.iter().copied().collect::<BTreeSet<_>>())
            .collect::<Vec<_>>();
        assert_eq!(vec![BTreeSet::from([x]), BTreeSet::from([a, c])], subgraphs);

        // The representative is the first node of its subgraph's slice.
        let rep = merge.find(a);
        assert_eq!(rep, merge.subgraphs().nth(1).unwrap()[0]);
    }

    #[test]
    pub fn test_subgraph_merge_enemies() {
        let mut preds = SlotMap::new();

        // A ───► B ───► C ───► D
        let a = preds.insert(vec![]);
        let b = preds.insert(vec![]);
        let c = preds.insert(vec![]);
        let d = preds.insert(vec![]);

        preds[b].push(a);
        preds[c].push(b);
        preds[d].push(c);

        // B and C are enemies (must not merge).
        let mut merge =
            SubgraphMerge::new(preds.keys(), |v| preds[v].iter().copied(), [(b, c)]).unwrap();

        // Direct enemy pair: rejected.
        assert!(!merge.try_merge(b, c));

        // Non-enemy pairs: allowed.
        assert!(merge.try_merge(a, b));

        // Now A and B are merged. C is still an enemy of the AB group.
        assert!(!merge.try_merge(a, c));
        assert!(!merge.try_merge(b, c));

        // D is not an enemy of anyone.
        assert!(merge.try_merge(c, d));

        // After C and D merge, the CD group is still an enemy of AB.
        assert!(!merge.try_merge(a, d));
        assert!(!merge.try_merge(b, d));
    }
}