rust-index: add support for `reachableroots2`
authorRaphaël Gomès <rgomes@octobus.net>
Mon, 30 Oct 2023 11:57:36 +0100
changeset 51222 fc05dd74e907
parent 51221 5a7d5fd6808c
child 51223 42c8dbdb88ad
rust-index: add support for `reachableroots2` Exposition in `hg-cpython` done in regular impl block, again for rustfmt support etc.
rust/hg-core/src/revlog/index.rs
rust/hg-cpython/src/revlog.rs
--- a/rust/hg-core/src/revlog/index.rs	Thu Nov 02 12:17:06 2023 +0100
+++ b/rust/hg-core/src/revlog/index.rs	Mon Oct 30 11:57:36 2023 +0100
@@ -983,6 +983,56 @@
         }
         min_rev
     }
+
+    /// Return `(heads(::(<roots> and <roots>::<heads>)))`
+    /// If `include_path` is `true`, return `(<roots>::<heads>)`."""
+    ///
+    /// `min_root` and `roots` are unchecked since they are just used as
+    /// a bound or for comparison and don't need to represent a valid revision.
+    /// In practice, the only invalid revision passed is the working directory
+    /// revision ([`i32::MAX`]).
+    pub fn reachable_roots(
+        &self,
+        min_root: UncheckedRevision,
+        mut heads: Vec<Revision>,
+        roots: HashSet<UncheckedRevision>,
+        include_path: bool,
+    ) -> Result<HashSet<Revision>, GraphError> {
+        if roots.is_empty() {
+            return Ok(HashSet::new());
+        }
+        let mut reachable = HashSet::new();
+        let mut seen = HashMap::new();
+
+        while let Some(rev) = heads.pop() {
+            if roots.contains(&rev.into()) {
+                reachable.insert(rev);
+                if !include_path {
+                    continue;
+                }
+            }
+            let parents = self.parents(rev)?;
+            seen.insert(rev, parents);
+            for parent in parents {
+                if parent.0 >= min_root.0 && !seen.contains_key(&parent) {
+                    heads.push(parent);
+                }
+            }
+        }
+        if !include_path {
+            return Ok(reachable);
+        }
+        let mut revs: Vec<_> = seen.keys().collect();
+        revs.sort_unstable();
+        for rev in revs {
+            for parent in seen[rev] {
+                if reachable.contains(&parent) {
+                    reachable.insert(*rev);
+                }
+            }
+        }
+        Ok(reachable)
+    }
 }
 
 /// Set of roots of all non-public phases
--- a/rust/hg-cpython/src/revlog.rs	Thu Nov 02 12:17:06 2023 +0100
+++ b/rust/hg-cpython/src/revlog.rs	Mon Oct 30 11:57:36 2023 +0100
@@ -7,7 +7,7 @@
 
 use crate::{
     cindex,
-    conversion::rev_pyiter_collect,
+    conversion::{rev_pyiter_collect, rev_pyiter_collect_or_else},
     utils::{node_from_py_bytes, node_from_py_object},
     PyRevision,
 };
@@ -251,7 +251,23 @@
 
     /// reachableroots
     def reachableroots2(&self, *args, **kw) -> PyResult<PyObject> {
-        self.call_cindex(py, "reachableroots2", args, kw)
+        let rust_res = self.inner_reachableroots2(
+            py,
+            UncheckedRevision(args.get_item(py, 0).extract(py)?),
+            args.get_item(py, 1),
+            args.get_item(py, 2),
+            args.get_item(py, 3).extract(py)?,
+        )?;
+
+        let c_res = self.call_cindex(py, "reachableroots2", args, kw)?;
+        // ordering of C result depends on how the computation went, and
+        // Rust result ordering is arbitrary. Hence we compare after
+        // sorting the results (in Python to avoid reconverting everything
+        // back to Rust structs).
+        assert_py_eq_normalized(py, "reachableroots2", &rust_res, &c_res,
+                                |v| format!("sorted({})", v))?;
+
+        Ok(rust_res)
     }
 
     /// get head revisions
@@ -929,6 +945,37 @@
             Ok(PyList::new(py, &res).into_object())
         }
     }
+
+    fn inner_reachableroots2(
+        &self,
+        py: Python,
+        min_root: UncheckedRevision,
+        heads: PyObject,
+        roots: PyObject,
+        include_path: bool,
+    ) -> PyResult<PyObject> {
+        let index = &*self.index(py).borrow();
+        let heads = rev_pyiter_collect_or_else(py, &heads, index, |_rev| {
+            PyErr::new::<IndexError, _>(py, "head out of range")
+        })?;
+        let roots: Result<_, _> = roots
+            .iter(py)?
+            .map(|r| {
+                r.and_then(|o| match o.extract::<PyRevision>(py) {
+                    Ok(r) => Ok(UncheckedRevision(r.0)),
+                    Err(e) => Err(e),
+                })
+            })
+            .collect();
+        let as_set = index
+            .reachable_roots(min_root, heads, roots?, include_path)
+            .map_err(|e| graph_error(py, e))?;
+        let as_vec: Vec<PyObject> = as_set
+            .iter()
+            .map(|r| PyRevision::from(*r).into_py_object(py).into_object())
+            .collect();
+        Ok(PyList::new(py, &as_vec).into_object())
+    }
 }
 
 fn revlog_error(py: Python) -> PyErr {