MeanField

Mean-field solvers and EI-cluster specializations.

The package exposes a generic fixed-point solver in RateSystem and a concrete clustered E/I specialization in EIClusterNetwork.

Minimal example:

from MeanField import EIClusterNetwork

system = EIClusterNetwork(parameter, v_focus=0.2)
x, residual, success = system.solve()
rates = system.full_rates_numpy(x)

The helpers in rate_system are used by the figure pipelines to trace event-rate functions, store fixpoint bundles, and reuse cached results.

Regenerating docs:

python scripts/generate_api_docs.py
 1"""Mean-field solvers and EI-cluster specializations.
 2
 3The package exposes a generic fixed-point solver in `RateSystem` and a concrete
 4clustered E/I specialization in `EIClusterNetwork`.
 5
 6Minimal example:
 7
 8```python
 9from MeanField import EIClusterNetwork
10
11system = EIClusterNetwork(parameter, v_focus=0.2)
12x, residual, success = system.solve()
13rates = system.full_rates_numpy(x)
14```
15
16The helpers in `rate_system` are used by the figure pipelines to trace
17event-rate functions, store fixpoint bundles, and reuse cached results.
18
19Regenerating docs:
20
21```bash
22python scripts/generate_api_docs.py
23```
24"""
25
26from .rate_system import (
27    ERFResult,
28    RateSystem,
29    aggregate_data,
30    ensure_output_folder,
31    serialize_erf,
32)
33from .ei_cluster_network import EIClusterNetwork
34
35__all__ = [
36    "RateSystem",
37    "ERFResult",
38    "EIClusterNetwork",
39    "ensure_output_folder",
40    "serialize_erf",
41    "aggregate_data",
42]
class RateSystem:
    """General mean-field solver with helper utilities for ERF and fixpoints.

    Subclasses implement `_build_dynamics(...)` and then inherit the fixed-point
    solver, ERF sweep, and fixpoint analysis helpers.
    """

    # Minimum allowed variance to prevent numerical instabilities.
    # Chosen small enough to avoid affecting dynamics, only avoids divide-by-zero.
    VAR_EPS = 1e-12
    # Accept near-roots even when the underlying SciPy solver does not set success=True.
    RESIDUAL_ACCEPT_TOL = 5e-4
    # Smallest continuation step introduced when adaptively subdividing the ERF sweep.
    ADAPTIVE_CONTINUATION_MIN_STEP = 1e-3
 85    def __init__(
 86        self,
 87        parameter: Dict,
 88        v_focus: float,
 89        *,
 90        focus_population: Optional[Union[int, Sequence[int]]] = None,
 91        prefer_jax: bool = True,
 92        root_tol: float = 1e-9,
 93        max_function_evals: int = 4000,
 94        max_newton_steps: Optional[int] = 1000,
 95        **network_kwargs,
 96    ) -> None:
 97        self.parameter = dict(parameter)
 98        self.v_focus = float(v_focus)
 99        self.prefer_jax = bool(prefer_jax)
100        self.root_tol = root_tol
101        self.max_function_evals = max_function_evals
102        default_steps = 256
103        if max_newton_steps is None:
104            self.max_steps = default_steps
105        else:
106            steps = int(max_newton_steps)
107            if steps <= 0:
108                raise ValueError("max_newton_steps must be positive.")
109            self.max_steps = steps
110        self.network_kwargs = dict(network_kwargs)
111        dynamics = self._build_dynamics(self.parameter, **self.network_kwargs)
112        if len(dynamics) == 4:
113            self.A, self.B, self.bias, self.tau = dynamics
114            self.C = np.zeros_like(self.B)
115        elif len(dynamics) == 5:
116            self.A, self.B, self.C, self.bias, self.tau = dynamics
117            if self.C is None:
118                self.C = np.zeros_like(self.B)
119        else:
120            raise ValueError("Expected _build_dynamics to return (A, B, bias, tau) or (A, B, C, bias, tau).")
121        self.population_count = int(self.A.shape[0])
122        if self.A.shape != self.B.shape or self.A.shape != self.C.shape:
123            raise ValueError("Connectivity mean and variance matrices must match in shape.")
124        if self.bias.shape[0] != self.population_count or self.tau.shape[0] != self.population_count:
125            raise ValueError("Bias and tau vectors must match the matrix dimensions.")
126        if focus_population is not None:
127            focus_config = focus_population
128        elif isinstance(self.parameter.get("focus_population"), (list, tuple, range)):
129            focus_config = self.parameter.get("focus_population")
130        else:
131            count = int(self.parameter.get("focus_count", 1) or 1)
132            focus_config = list(range(max(count, 1)))
133        self.focus_indices = self._resolve_focus_indices(focus_config)
134        self._initialize_group_constraints()
135        self.use_jax = self.prefer_jax and HAS_JAX and optx is not None
136        self._jax_args = None
137
138    # --- abstract hooks -------------------------------------------------
139    def _build_dynamics(
140        self, parameter: Dict, **network_kwargs
141    ) -> Union[
142        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
143        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
144    ]:
145        raise NotImplementedError("Derived classes must implement '_build_dynamics'.")
146
147    def _build_population_groups(self, focus: np.ndarray) -> List[np.ndarray]:
148        focus_set = set(int(idx) for idx in focus.tolist())
149        if not focus_set:
150            focus_set = {0}
151        groups: List[np.ndarray] = [np.array(sorted(focus_set), dtype=int)]
152        remaining = [idx for idx in range(self.population_count) if idx not in focus_set]
153        for idx in remaining:
154            groups.append(np.array([idx], dtype=int))
155        return groups
156
157    # --- solver core ----------------------------------------------------
158    def _initialize_group_constraints(self) -> None:
159        groups = self._build_population_groups(self.focus_indices)
160        covered = np.concatenate(groups)
161        if covered.size != self.population_count or not np.all(np.sort(covered) == np.arange(self.population_count)):
162            raise ValueError("Population groups must cover each population exactly once.")
163        self.groups: List[np.ndarray] = groups
164        self.group_count = len(groups)
165        self.focus_group_index = 0
166        self.solve_groups = [idx for idx in range(self.group_count) if idx != self.focus_group_index]
167        self.dim = len(self.solve_groups)
168        self.focus_population_mask = np.zeros(self.population_count, dtype=bool)
169        self.focus_population_mask[self.focus_indices] = True
170        self.group_membership = np.zeros((self.population_count, self.group_count), dtype=float)
171        for idx, members in enumerate(self.groups):
172            self.group_membership[members, idx] = 1.0
173        self.group_sizes = np.array([len(members) for members in self.groups], dtype=float)
174        self.group_inverse_sizes = np.reciprocal(self.group_sizes)
175        self.focus_vector = np.zeros(self.group_count, dtype=float)
176        self.focus_vector[self.focus_group_index] = 1.0
177        self.selector_matrix = np.zeros((self.group_count, self.dim), dtype=float)
178        for col, group_idx in enumerate(self.solve_groups):
179            self.selector_matrix[group_idx, col] = 1.0
180        self.residual_matrix = np.zeros((self.dim, self.population_count), dtype=float)
181        for row, group_idx in enumerate(self.solve_groups):
182            members = self.groups[group_idx]
183            self.residual_matrix[row, members] = 1.0 / len(members)
184
185    def _resolve_focus_indices(self, focus_config) -> np.ndarray:
186        if focus_config is None:
187            entries = [0]
188        elif isinstance(focus_config, (int, np.integer)):
189            entries = [int(focus_config)]
190        else:
191            entries = []
192            for value in focus_config:
193                if isinstance(value, slice):
194                    start = 0 if value.start is None else value.start
195                    stop = self.population_count if value.stop is None else value.stop
196                    step = 1 if value.step is None else value.step
197                    entries.extend(range(start, stop, step))
198                else:
199                    entries.append(int(value))
200        focus = sorted(set(entries))
201        if not focus:
202            focus = [0]
203        for idx in focus:
204            if idx < 0 or idx >= self.population_count:
205                raise ValueError(f"Focus index {idx} out of bounds for population size {self.population_count}.")
206        return np.asarray(focus, dtype=int)
207
208    def _full_rates_numpy(self, x: np.ndarray) -> np.ndarray:
209        arr = np.asarray(x, dtype=float).ravel()
210        if arr.size == self.population_count:
211            return arr
212        if arr.size == self.group_count:
213            group_values = arr.copy()
214            group_values[self.focus_group_index] = self.v_focus
215        else:
216            if arr.size != self.dim:
217                raise ValueError(
218                    f"Expected vector of length {self.dim}, {self.group_count}, or {self.population_count}, got {arr.size}."
219                )
220            group_values = self.selector_matrix @ arr
221            group_values += self.focus_vector * self.v_focus
222        return self.group_membership @ group_values
223
224    def _phi_numpy(self, full_rates: np.ndarray) -> np.ndarray:
225        mean = self.A.dot(full_rates) + self.bias
226        rates_sq = full_rates * full_rates
227        var = np.maximum(self.B.dot(full_rates) + self.C.dot(rates_sq), self.VAR_EPS)
228        return 0.5 * (1 - special.erf(-mean / np.sqrt(2.0 * var)))
229
230    def _phi_jacobian_numpy(self, full_rates: np.ndarray) -> np.ndarray:
231        mean = self.A.dot(full_rates) + self.bias
232        rates_sq = full_rates * full_rates
233        var = np.maximum(self.B.dot(full_rates) + self.C.dot(rates_sq), self.VAR_EPS)
234        inv_sqrt = 1.0 / np.sqrt(2.0 * var)
235        exp_term = np.exp(-(mean ** 2) / (2.0 * var))
236        coeff = (1.0 / np.sqrt(np.pi)) * exp_term * inv_sqrt
237        correction = mean / (2.0 * var)
238        dvar_dm = self.B + self.C * (2.0 * full_rates[None, :])
239        return coeff[:, None] * (self.A - correction[:, None] * dvar_dm)
240
241    def residual_numpy(self, x: np.ndarray) -> np.ndarray:
242        rates = self._full_rates_numpy(x)
243        phi = self._phi_numpy(rates)
244        residual = (phi - rates) / self.tau
245        if self.dim == 0:
246            return np.zeros((0,), dtype=float)
247        return self.residual_matrix @ residual
248
249    @staticmethod
250    def _jax_residual(x, v_focus, args):
251        A, B, C, bias, tau, focus_vector, selector, membership, reduction, var_eps = args
252        group_values = selector @ x + focus_vector * v_focus
253        full_rates = membership @ group_values
254        mean = A @ full_rates + bias
255        var = jnp.maximum(B @ full_rates + C @ (full_rates * full_rates), var_eps)
256        phi = 0.5 * (1 - jspecial.erf(-mean / jnp.sqrt(2.0 * var)))
257        residual = (phi - full_rates) / tau
258        return reduction @ residual
259
260    def solve(self, initial_guess: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, bool]:
261        """Solve for a fixed point of the reduced mean-field system.
262
263        Returns
264        -------
265        tuple[np.ndarray, np.ndarray, bool]
266            `(x, residual, success)` where `x` is the reduced solution vector.
267
268        Expected output
269        ---------------
270        `success` is `True` when the nonlinear solver converged and `residual`
271        is close to zero.
272        """
273        initial = self._coerce_initial(initial_guess)
274        if self.dim == 0:
275            residual = self.residual_numpy(initial)
276            return initial, residual, True
277        if self.use_jax:
278            try:
279                return self._solve_with_optimistix(initial)
280            except SolverConvergenceError:
281                raise
282            except (ValueError, RuntimeError, TypeError, AttributeError) as e:
283                logger.warning(
284                    "Optimistix solver failed with %s: %s. Falling back to scipy solver.",
285                    type(e).__name__,
286                    str(e),
287                )
288                self.use_jax = False
289        return self._solve_with_scipy(initial)
290
291    @staticmethod
292    def _residual_norm(residual: np.ndarray) -> float:
293        arr = np.asarray(residual, dtype=float).ravel()
294        if arr.size == 0:
295            return 0.0
296        return float(np.linalg.norm(arr))
297
298    def _normalize_solver_result(
299        self,
300        x: np.ndarray,
301        residual: np.ndarray,
302        success: bool,
303        *,
304        method: str,
305    ) -> Tuple[np.ndarray, np.ndarray, bool]:
306        x_arr = np.asarray(x, dtype=float).ravel()
307        residual_arr = np.asarray(residual, dtype=float).ravel()
308        finite = np.isfinite(x_arr).all() and np.isfinite(residual_arr).all()
309        accepted = bool(success) and finite
310        if (not accepted) and finite:
311            residual_norm = self._residual_norm(residual_arr)
312            if residual_norm <= self.RESIDUAL_ACCEPT_TOL:
313                logger.debug(
314                    "Accepting near-root from %s at v_focus %.6f with residual norm %.3e.",
315                    method,
316                    float(self.v_focus),
317                    residual_norm,
318                )
319                accepted = True
320        return x_arr, residual_arr, accepted
321
322    def _solve_with_scipy(self, initial: np.ndarray) -> Tuple[np.ndarray, np.ndarray, bool]:
323        best_x = np.asarray(initial, dtype=float).ravel()
324        best_residual = np.asarray(self.residual_numpy(best_x), dtype=float).ravel()
325        best_norm = self._residual_norm(best_residual)
326
327        def remember(x: np.ndarray, residual: np.ndarray) -> None:
328            nonlocal best_x, best_residual, best_norm
329            if not (np.isfinite(x).all() and np.isfinite(residual).all()):
330                return
331            norm = self._residual_norm(residual)
332            if norm < best_norm:
333                best_x = np.asarray(x, dtype=float).ravel()
334                best_residual = np.asarray(residual, dtype=float).ravel()
335                best_norm = norm
336
337        hybr = optimize.root(
338            self.residual_numpy,
339            initial,
340            method="hybr",
341            tol=self.root_tol,
342            options={"maxfev": self.max_function_evals},
343        )
344        x_hybr, residual_hybr, success_hybr = self._normalize_solver_result(
345            hybr.x,
346            hybr.fun,
347            bool(hybr.success),
348            method="scipy.root(hybr)",
349        )
350        if success_hybr:
351            return x_hybr, residual_hybr, True
352        remember(x_hybr, residual_hybr)
353
354        lm_initial = x_hybr if np.isfinite(x_hybr).all() else np.asarray(initial, dtype=float).ravel()
355        lm = optimize.root(
356            self.residual_numpy,
357            lm_initial,
358            method="lm",
359            tol=self.root_tol,
360            options={"maxiter": self.max_function_evals},
361        )
362        x_lm, residual_lm, success_lm = self._normalize_solver_result(
363            lm.x,
364            lm.fun,
365            bool(lm.success),
366            method="scipy.root(lm)",
367        )
368        if success_lm:
369            return x_lm, residual_lm, True
370        remember(x_lm, residual_lm)
371
372        ls_initial = x_lm if np.isfinite(x_lm).all() else lm_initial
373        ls = optimize.least_squares(
374            self.residual_numpy,
375            ls_initial,
376            method="trf",
377            xtol=self.root_tol,
378            ftol=self.root_tol,
379            gtol=self.root_tol,
380            max_nfev=self.max_function_evals,
381        )
382        x_ls, residual_ls, success_ls = self._normalize_solver_result(
383            ls.x,
384            ls.fun,
385            bool(ls.success),
386            method="scipy.least_squares(trf)",
387        )
388        if success_ls:
389            return x_ls, residual_ls, True
390        remember(x_ls, residual_ls)
391
392        return best_x, best_residual, False
393
394    def _solve_with_optimistix(self, initial: np.ndarray) -> Tuple[np.ndarray, np.ndarray, bool]:
395        args = self._prepare_jax_args()
396        solver_entry = self._get_jax_solver()
397        if solver_entry is None:
398            raise RuntimeError("Optimistix solver is unavailable.")
399        x0 = jnp.asarray(initial, dtype=jnp.float64)
400        v_focus = jnp.asarray(float(self.v_focus), dtype=jnp.float64)
401        value, status, _ = solver_entry["single"](x0, v_focus, args)
402        success = bool(np.asarray(status, dtype=bool))
403        if not success:
404            raise SolverConvergenceError(self.v_focus, self.max_steps)
405        value_np = np.asarray(value, dtype=float)
406        return value_np, self.residual_numpy(value_np), success
407
408    def _prepare_jax_args(self):
409        if self._jax_args is None:
410            self._jax_args = (
411                jnp.asarray(self.A, dtype=jnp.float64),
412                jnp.asarray(self.B, dtype=jnp.float64),
413                jnp.asarray(self.C, dtype=jnp.float64),
414                jnp.asarray(self.bias, dtype=jnp.float64),
415                jnp.asarray(self.tau, dtype=jnp.float64),
416                jnp.asarray(self.focus_vector, dtype=jnp.float64),
417                jnp.asarray(self.selector_matrix, dtype=jnp.float64),
418                jnp.asarray(self.group_membership, dtype=jnp.float64),
419                jnp.asarray(self.residual_matrix, dtype=jnp.float64),
420                float(self.VAR_EPS),
421            )
422        return self._jax_args
423
424    def _get_jax_solver(self):
425        if not (self.use_jax and HAS_JAX and optx is not None):
426            return None
427        if self.dim == 0:
428            return None
429        key = (self.dim, float(self.root_tol), int(self.max_steps))
430        entry = JAX_SOLVER_CACHE.get(key)
431        if entry is None:
432            entry = _build_jax_solver_entry(self.dim, float(self.root_tol), int(self.max_steps))
433            JAX_SOLVER_CACHE[key] = entry
434        return entry
435
436    def phi_numpy(self, x: np.ndarray) -> np.ndarray:
437        """Evaluate the transfer function on the full-rate state implied by `x`.
438
439        Expected output
440        ---------------
441        Returns one activity value per population in the interval `[0, 1]`.
442        """
443        rates = self._full_rates_numpy(np.asarray(x, dtype=float))
444        return self._phi_numpy(rates)
445
446    def full_rates_numpy(self, x: np.ndarray) -> np.ndarray:
447        """Expand a reduced solver vector into one rate per population.
448
449        Expected output
450        ---------------
451        The returned array has length `population_count`.
452        """
453        return self._full_rates_numpy(np.asarray(x, dtype=float))
454
455    def jacobian_numpy(self, x: np.ndarray) -> np.ndarray:
456        """Return the Jacobian of the full mean-field residual at `x`.
457
458        Expected output
459        ---------------
460        The returned matrix has shape `(population_count, population_count)`.
461        """
462        rates = self._full_rates_numpy(np.asarray(x, dtype=float))
463        return (self._phi_jacobian_numpy(rates) - np.eye(self.population_count)) / self.tau[:, np.newaxis]
464
465    def focus_output(self, rates: np.ndarray) -> float:
466        """Average the rates over the configured focus populations."""
467        values = np.asarray(rates, dtype=float)[self.focus_indices]
468        return float(values.mean()) if values.size else float(rates[0])
469
470    def _reduce_full_rates(self, full_rates: np.ndarray) -> np.ndarray:
471        if self.dim == 0:
472            return np.zeros((0,), dtype=float)
473        full = np.asarray(full_rates, dtype=float).reshape((self.population_count,))
474        group_sums = self.group_membership.T @ full
475        group_means = group_sums * self.group_inverse_sizes
476        return group_means[self.solve_groups]
477
478    def _coerce_initial(self, initial_guess: Optional[np.ndarray]) -> np.ndarray:
479        if self.dim == 0:
480            return np.zeros((0,), dtype=float)
481        if initial_guess is None:
482            return np.full((self.dim,), 0.1, dtype=float)
483        arr = np.asarray(initial_guess, dtype=float).ravel()
484        if arr.size == self.dim:
485            return arr
486        if arr.size == self.group_count:
487            copy = arr.copy()
488            copy[self.focus_group_index] = self.v_focus
489            return copy[self.solve_groups]
490        if arr.size == self.population_count:
491            return self._reduce_full_rates(arr)
492        non_focus_count = self.population_count - len(self.focus_indices)
493        if arr.size == non_focus_count:
494            full_rates = np.empty(self.population_count, dtype=float)
495            full_rates[self.focus_population_mask] = self.v_focus
496            full_rates[~self.focus_population_mask] = arr
497            return self._reduce_full_rates(full_rates)
498        raise ValueError(
499            f"Initial guess must have length {self.dim}, {self.group_count}, {self.population_count}, or {non_focus_count}, got {arr.size}."
500        )
501
502    def solve_sequence(
503        self,
504        v_focus_values: Sequence[float],
505        initial_guess: Optional[np.ndarray] = None,
506    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
507        """Solve a sequence of focus inputs with the accelerated JAX path.
508
509        Expected output
510        ---------------
511        Returns `(solutions, success_flags)` when the JAX solver is available,
512        otherwise `None`.
513        """
514        values = np.asarray(list(v_focus_values), dtype=float)
515        if values.size == 0:
516            return np.zeros((0, self.dim)), np.zeros((0,), dtype=bool)
517        if self.dim == 0:
518            zeros = np.zeros((values.size, 0), dtype=float)
519            return zeros, np.ones((values.size,), dtype=bool)
520        if not (self.use_jax and HAS_JAX and optx is not None):
521            return None
522        solver_entry = self._get_jax_solver()
523        if solver_entry is None:
524            return None
525        args = self._prepare_jax_args()
526        initial = self._coerce_initial(initial_guess)
527        x0 = jnp.asarray(initial, dtype=jnp.float64)
528        v_seq = jnp.asarray(values, dtype=jnp.float64)
529        try:
530            solutions, statuses, _ = solver_entry["scan"](x0, v_seq, args)
531        except (ValueError, RuntimeError, TypeError, AttributeError) as exc:
532            logger.warning(
533                "Optimistix sweep failed with %s: %s. Falling back to sequential solver.",
534                type(exc).__name__,
535                str(exc),
536            )
537            self.use_jax = False
538            return None
539        solution_np = np.asarray(solutions, dtype=float)
540        success = np.asarray(statuses, dtype=bool)
541        return solution_np, success
542
543    # --- class helpers --------------------------------------------------
544    @classmethod
545    def generate_erf_curve(
546        cls,
547        parameter: Dict,
548        *,
549        start: float = 0.02,
550        end: float = 1.0,
551        step_number: int = 20,
552        retry_step: Optional[float] = None,
553        initial_guess: Optional[np.ndarray] = None,
554        fallback_initials: Optional[Sequence[float]] = None,
555        **network_kwargs,
556    ) -> ERFResult:
557        """Compute the ERF for ``parameter``.
558
559        When ``retry_step`` is provided, the solver retries inputs separated by
560        that value whenever convergence fails. If the ERF cannot be completed
561        the ``completed`` flag is ``False`` and no serialization should happen.
562        Parameters
563        ----------
564        start : float, optional
565            Lower bound of the input range for the ERF (default: 0.02).
566        end : float, optional
567            Upper bound of the input range for the ERF (default: 1.0).
568        step_number : int, optional
569            Number of steps between ``start`` and ``end`` (default: 20).
570        ...
571
572        Expected output
573        ---------------
574        Returns an :class:`ERFResult` whose `x_data` stores the driven focus
575        inputs and whose `y_data` stores the corresponding mean-field outputs.
576        """
577        ERF_EPS = 1e-12  # Tolerance for floating-point comparison of v_in against end
578        x_data: List[float] = []
579        y_data: List[float] = []
580        solves: List[np.ndarray] = []
581        step = (end - start) / max(step_number, 1)
582        adaptive_min_step = min(abs(step), cls.ADAPTIVE_CONTINUATION_MIN_STEP) if step != 0 else cls.ADAPTIVE_CONTINUATION_MIN_STEP
583        aborted = False
584        fallback_values = list(fallback_initials) if fallback_initials is not None else [0.02, 0.2, 0.5, 0.8, 0.98]
585
586        def solve_with_fallback_initials(
587            system: "RateSystem",
588            initial: Optional[np.ndarray],
589        ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
590            best_solution: Optional[np.ndarray] = None
591            best_residual: Optional[np.ndarray] = None
592            best_norm = float("inf")
593            candidates: List[Optional[np.ndarray]] = [initial]
594            for seed in fallback_values:
595                if system.dim == 0:
596                    candidate = np.zeros((0,), dtype=float)
597                else:
598                    candidate = np.full((system.dim,), float(seed), dtype=float)
599                candidates.append(candidate)
600            for candidate in candidates:
601                try:
602                    solution, residual, success = system.solve(candidate)
603                except SolverConvergenceError as exc:
604                    logger.debug("Skipping v_in %.6f: %s", float(system.v_focus), str(exc))
605                    continue
606                residual_arr = np.asarray(residual, dtype=float).ravel()
607                if success:
608                    return np.asarray(solution, dtype=float).ravel(), residual_arr, True
609                if np.isfinite(residual_arr).all():
610                    residual_norm = cls._residual_norm(residual_arr)
611                    if residual_norm < best_norm:
612                        best_solution = np.asarray(solution, dtype=float).ravel()
613                        best_residual = residual_arr
614                        best_norm = residual_norm
615            return best_solution, best_residual, False
616
617        vector_values: Optional[List[float]] = None
618        if retry_step is None:
619            vector_values = []
620            if step == 0:
621                vector_values.append(float(start))
622            else:
623                cursor = float(start)
624                while cursor <= end + ERF_EPS:
625                    vector_values.append(cursor)
626                    cursor += step
627        initial_focus = float(start)
628        if vector_values:
629            initial_focus = float(vector_values[0])
630        system = cls(parameter, initial_focus, **network_kwargs)
631        current_initial = initial_guess
632        prefetched_solutions: Optional[np.ndarray] = None
633        prefix_limit = 0
634        if vector_values and system.use_jax:
635            seq_result = system.solve_sequence(vector_values, initial_guess=current_initial)
636            if seq_result is not None:
637                prefetched_solutions, success_flags = seq_result
638                failure = np.flatnonzero(~success_flags)
639                prefix_limit = int(failure[0]) if failure.size else len(vector_values)
640        if vector_values is not None:
641            idx = 0
642            while idx < len(vector_values):
643                v_in = vector_values[idx]
644                system.v_focus = float(v_in)
645                use_prefetched = prefetched_solutions is not None and idx < prefix_limit
646                if use_prefetched:
647                    solution = np.asarray(prefetched_solutions[idx], dtype=float)
648                    success = True
649                else:
650                    solution, residual, success = solve_with_fallback_initials(system, current_initial)
651                    if not success:
652                        prev_v = x_data[-1] if x_data else None
653                        if prev_v is not None:
654                            gap = float(v_in) - float(prev_v)
655                            if gap > adaptive_min_step + ERF_EPS:
656                                midpoint = float(prev_v) + 0.5 * gap
657                                vector_values.insert(idx, midpoint)
658                                prefetched_solutions = None
659                                prefix_limit = 0
660                                continue
661                        aborted = True
662                        break
663                phi_values = system.phi_numpy(solution)
664                x_data.append(system.v_focus)
665                y_data.append(system.focus_output(phi_values))
666                solves.append(solution)
667                current_initial = solution
668                idx += 1
669            completed = (not aborted) and (len(x_data) == len(vector_values))
670        else:
671            v_in = float(start)
672            next_value = v_in
673            while v_in <= end + ERF_EPS:
674                system.v_focus = float(v_in)
675                solution, residual, success = solve_with_fallback_initials(system, current_initial)
676                if not success:
677                    if retry_step is not None:
678                        v_in += retry_step
679                        next_value = v_in
680                        if solution is not None:
681                            current_initial = solution
682                        continue
683                    aborted = True
684                    break
685                phi_values = system.phi_numpy(solution)
686                x_data.append(system.v_focus)
687                y_data.append(system.focus_output(phi_values))
688                solves.append(solution)
689                current_initial = solution
690                if step == 0:
691                    next_value = float("inf")
692                    break
693                v_in = system.v_focus + step
694                next_value = v_in
695            completed = (not aborted) and (step == 0 or next_value > end + ERF_EPS)
696        return ERFResult(x_data=x_data, y_data=y_data, solves=solves, completed=completed)
697
@classmethod
def compute_fixpoints(
    cls,
    sweep_entry: Sequence,
    *,
    tol: float = 1e-3,
    interpolation_steps: int = 10_000,
    **network_kwargs,
) -> Dict[float, Dict[str, Any]]:
    """
    Compute fixed points from an ERF sweep.
    Fixed points with slope larger than 1 at the intersection with the identity line are considered unstable in the 1D map approximation.

    Parameters
    ----------
    sweep_entry : sequence
        Tuple (x_data, y_data, solves, parameter) as returned by generate_erf_curve.
    tol : float, optional
        Tolerance for detecting crossings of the identity line (x = y) and for
        merging nearby crossings. Crossings where |x - y| <= tol are treated as
        fixed points, and crossings within tol of each other are merged. Default
        is 1e-3.
    interpolation_steps : int, optional
        Number of interpolation points used to refine the ERF before searching
        for crossings. Larger values increase accuracy but also cost. Default
        is 10_000.
    **network_kwargs
        Forwarded to the class constructor when re-solving at each candidate
        fixed point.

    Expected output
    ---------------
    Returns a dictionary keyed by fixed-point location. Each value contains
    at least `stability`, `rates`, `residual_norm`, and `solver_success`.
    """
    # |d(ERF)/dv| < 1 => stable in 1D; > 1 => unstable.
    # NOTE(review): only slope > 1 is flagged below; a slope < -1 would also
    # be unstable in the 1D map -- confirm the ERF is assumed monotone.
    SLOPE_STABILITY_THRESHOLD = 1.0
    x_data, y_data, solves, parameter = sweep_entry
    x_interp, y_interp = interpolate_curve(x_data, y_data, steps=interpolation_steps)
    if x_interp.size == 0 or y_interp.size == 0:
        print("Skipping fixpoint analysis: empty ERF data.")
        return {}
    # Find crossings of the identity line: a sample within tol of it, or a
    # sign flip of (x - y) between neighbours (refined by linear
    # interpolation). Crossings within tol of each other are merged.
    diff = x_interp - y_interp
    crossings = []
    prev_diff = diff[0]
    for idx in range(1, len(diff)):
        curr_diff = diff[idx]
        cross_val = None
        if np.abs(curr_diff) <= tol:
            cross_val = y_interp[idx]
        elif np.abs(prev_diff) <= tol:
            cross_val = y_interp[idx - 1]
        elif prev_diff * curr_diff < 0:
            weight = prev_diff / (prev_diff - curr_diff)
            cross_val = y_interp[idx - 1] + weight * (y_interp[idx] - y_interp[idx - 1])
        if cross_val is not None:
            if crossings and np.abs(cross_val - crossings[-1][0]) <= tol:
                crossings[-1] = (float(cross_val), idx)
            else:
                crossings.append((float(cross_val), idx))
        prev_diff = curr_diff
    fixpoints: Dict[float, Dict[str, Any]] = {}
    if not crossings:
        return fixpoints
    solves_array = [np.asarray(s, dtype=float) for s in solves]
    v_out_old = np.asarray(y_data, dtype=float)
    for cross_point, idx in crossings:
        print(f"Cross-Point: {cross_point}")
        if idx <= 0:
            slope = np.inf
        else:
            slope = (y_interp[idx] - y_interp[idx - 1]) / (x_interp[idx] - x_interp[idx - 1])
        entry: Dict[str, Any] = {
            "stability": "unstable",
            "rates": None,
            "residual_norm": float("inf"),
            "solver_success": False,
            "slope": float(slope) if np.isfinite(slope) else float("inf"),
            "included": False,
        }
        slope_unstable = not np.isfinite(slope) or slope > SLOPE_STABILITY_THRESHOLD
        if len(solves_array) == 0:
            entry["reason"] = "missing_erf_solution"
            fixpoints[cross_point] = entry
            continue
        # Warm-start the solver with the ERF solution whose output is
        # closest to the crossing value.
        closest_idx = int(np.argmin(np.abs(v_out_old - cross_point)))
        closest_idx = min(max(closest_idx, 0), len(solves_array) - 1)
        initial = solves_array[closest_idx]
        system = cls(parameter, cross_point, **network_kwargs)
        try:
            solve, residual, success = system.solve(initial)
        except SolverConvergenceError as exc:
            entry["solver_success"] = False
            entry["residual_norm"] = float("inf")
            entry["rates"] = None
            entry["stability"] = "unstable"
            entry["reason"] = "solver_failed"
            entry["error"] = str(exc)
            fixpoints[cross_point] = entry
            continue
        residual = np.asarray(residual, dtype=float)
        residual_norm = float(np.linalg.norm(residual)) if residual.size else 0.0
        if not np.isfinite(residual).all():
            residual_norm = float("inf")
        entry["solver_success"] = bool(success)
        entry["residual_norm"] = residual_norm
        if success and np.isfinite(residual).all():
            entry["rates"] = system.full_rates_numpy(solve)
        else:
            entry["rates"] = None
        if slope_unstable or not success or not np.isfinite(residual).all():
            if not success or not np.isfinite(residual).all():
                print("Warning: convergence problems near cross-point")
            stability = "unstable"
        else:
            jacobian = system.jacobian_numpy(solve)
            if not np.isfinite(jacobian).all():
                stability = "unstable"
            else:
                try:
                    eigval = np.linalg.eigvals(jacobian)
                except np.linalg.LinAlgError:
                    stability = "unstable"
                else:
                    # BUGFIX: eigvals may return complex values, and ordering
                    # comparisons on complex arrays raise TypeError in NumPy.
                    # Linear stability requires all eigenvalue REAL PARTS to
                    # be negative.
                    stability = "stable" if (eigval.real < 0).all() else "unstable"
        entry["stability"] = stability
        if "reason" not in entry:
            entry["reason"] = "slope" if slope_unstable else None
        fixpoints[cross_point] = entry
    return fixpoints

General mean-field solver with helper utilities for ERF and fixpoints.

Subclasses implement _build_dynamics(...) and then inherit the fixed-point solver, ERF sweep, and fixpoint analysis helpers.

RateSystem( parameter: Dict, v_focus: float, *, focus_population: Union[int, Sequence[int], NoneType] = None, prefer_jax: bool = True, root_tol: float = 1e-09, max_function_evals: int = 4000, max_newton_steps: Optional[int] = 1000, **network_kwargs)
 85    def __init__(
 86        self,
 87        parameter: Dict,
 88        v_focus: float,
 89        *,
 90        focus_population: Optional[Union[int, Sequence[int]]] = None,
 91        prefer_jax: bool = True,
 92        root_tol: float = 1e-9,
 93        max_function_evals: int = 4000,
 94        max_newton_steps: Optional[int] = 1000,
 95        **network_kwargs,
 96    ) -> None:
 97        self.parameter = dict(parameter)
 98        self.v_focus = float(v_focus)
 99        self.prefer_jax = bool(prefer_jax)
100        self.root_tol = root_tol
101        self.max_function_evals = max_function_evals
102        default_steps = 256
103        if max_newton_steps is None:
104            self.max_steps = default_steps
105        else:
106            steps = int(max_newton_steps)
107            if steps <= 0:
108                raise ValueError("max_newton_steps must be positive.")
109            self.max_steps = steps
110        self.network_kwargs = dict(network_kwargs)
111        dynamics = self._build_dynamics(self.parameter, **self.network_kwargs)
112        if len(dynamics) == 4:
113            self.A, self.B, self.bias, self.tau = dynamics
114            self.C = np.zeros_like(self.B)
115        elif len(dynamics) == 5:
116            self.A, self.B, self.C, self.bias, self.tau = dynamics
117            if self.C is None:
118                self.C = np.zeros_like(self.B)
119        else:
120            raise ValueError("Expected _build_dynamics to return (A, B, bias, tau) or (A, B, C, bias, tau).")
121        self.population_count = int(self.A.shape[0])
122        if self.A.shape != self.B.shape or self.A.shape != self.C.shape:
123            raise ValueError("Connectivity mean and variance matrices must match in shape.")
124        if self.bias.shape[0] != self.population_count or self.tau.shape[0] != self.population_count:
125            raise ValueError("Bias and tau vectors must match the matrix dimensions.")
126        if focus_population is not None:
127            focus_config = focus_population
128        elif isinstance(self.parameter.get("focus_population"), (list, tuple, range)):
129            focus_config = self.parameter.get("focus_population")
130        else:
131            count = int(self.parameter.get("focus_count", 1) or 1)
132            focus_config = list(range(max(count, 1)))
133        self.focus_indices = self._resolve_focus_indices(focus_config)
134        self._initialize_group_constraints()
135        self.use_jax = self.prefer_jax and HAS_JAX and optx is not None
136        self._jax_args = None
VAR_EPS = 1e-12
RESIDUAL_ACCEPT_TOL = 0.0005
ADAPTIVE_CONTINUATION_MIN_STEP = 0.001
parameter
v_focus
prefer_jax
root_tol
max_function_evals
network_kwargs
population_count
focus_indices
use_jax
def residual_numpy(self, x: numpy.ndarray) -> numpy.ndarray:
def residual_numpy(self, x: np.ndarray) -> np.ndarray:
    """Evaluate the reduced residual `(phi(r) - r) / tau` projected through
    the residual matrix; returns an empty vector when `dim == 0`."""
    full_rates = self._full_rates_numpy(x)
    drive = self._phi_numpy(full_rates)
    per_population = (drive - full_rates) / self.tau
    if self.dim == 0:
        # Degenerate reduced system: there is nothing to solve for.
        return np.zeros((0,), dtype=float)
    return self.residual_matrix @ per_population
def solve( self, initial_guess: Optional[numpy.ndarray] = None) -> Tuple[numpy.ndarray, numpy.ndarray, bool]:
def solve(self, initial_guess: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, bool]:
    """Solve for a fixed point of the reduced mean-field system.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, bool]
        `(x, residual, success)` where `x` is the reduced solution vector.

    Expected output
    ---------------
    `success` is `True` when the nonlinear solver converged and `residual`
    is close to zero.
    """
    x0 = self._coerce_initial(initial_guess)
    # A zero-dimensional reduced system is trivially solved.
    if self.dim == 0:
        return x0, self.residual_numpy(x0), True
    if self.use_jax:
        try:
            return self._solve_with_optimistix(x0)
        except SolverConvergenceError:
            # Genuine convergence failures propagate unchanged.
            raise
        except (ValueError, RuntimeError, TypeError, AttributeError) as e:
            # Any other JAX-path failure disables that path for this
            # instance and falls back to scipy.
            logger.warning(
                "Optimistix solver failed with %s: %s. Falling back to scipy solver.",
                type(e).__name__,
                str(e),
            )
            self.use_jax = False
    return self._solve_with_scipy(x0)

Solve for a fixed point of the reduced mean-field system.

Returns
  • tuple[np.ndarray, np.ndarray, bool]: (x, residual, success) where x is the reduced solution vector.
Expected output

success is True when the nonlinear solver converged and residual is close to zero.

def phi_numpy(self, x: numpy.ndarray) -> numpy.ndarray:
def phi_numpy(self, x: np.ndarray) -> np.ndarray:
    """Evaluate the transfer function on the full-rate state implied by `x`.

    Expected output
    ---------------
    Returns one activity value per population in the interval `[0, 1]`.
    """
    state = np.asarray(x, dtype=float)
    return self._phi_numpy(self._full_rates_numpy(state))

Evaluate the transfer function on the full-rate state implied by x.

Expected output

Returns one activity value per population in the interval [0, 1].

def full_rates_numpy(self, x: numpy.ndarray) -> numpy.ndarray:
def full_rates_numpy(self, x: np.ndarray) -> np.ndarray:
    """Expand a reduced solver vector into one rate per population.

    Expected output
    ---------------
    The returned array has length `population_count`.
    """
    reduced = np.asarray(x, dtype=float)
    return self._full_rates_numpy(reduced)

Expand a reduced solver vector into one rate per population.

Expected output

The returned array has length population_count.

def jacobian_numpy(self, x: numpy.ndarray) -> numpy.ndarray:
def jacobian_numpy(self, x: np.ndarray) -> np.ndarray:
    """Return the Jacobian of the full mean-field residual at `x`.

    Expected output
    ---------------
    The returned matrix has shape `(population_count, population_count)`.
    """
    full_rates = self._full_rates_numpy(np.asarray(x, dtype=float))
    identity = np.eye(self.population_count)
    # d/dr [(phi(r) - r) / tau], with tau applied row-wise.
    return (self._phi_jacobian_numpy(full_rates) - identity) / self.tau[:, np.newaxis]

Return the Jacobian of the full mean-field residual at x.

Expected output

The returned matrix has shape (population_count, population_count).

def focus_output(self, rates: numpy.ndarray) -> float:
def focus_output(self, rates: np.ndarray) -> float:
    """Average the rates over the configured focus populations.

    Falls back to the first entry of `rates` when the focus indices select
    no values.
    """
    selected = np.asarray(rates, dtype=float)[self.focus_indices]
    if selected.size:
        return float(selected.mean())
    return float(rates[0])

Average the rates over the configured focus populations.

def solve_sequence( self, v_focus_values: Sequence[float], initial_guess: Optional[numpy.ndarray] = None) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:
def solve_sequence(
    self,
    v_focus_values: Sequence[float],
    initial_guess: Optional[np.ndarray] = None,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Solve a sequence of focus inputs with the accelerated JAX path.

    Expected output
    ---------------
    Returns `(solutions, success_flags)` when the JAX solver is available,
    otherwise `None`.
    """
    targets = np.asarray(list(v_focus_values), dtype=float)
    if targets.size == 0:
        return np.zeros((0, self.dim)), np.zeros((0,), dtype=bool)
    if self.dim == 0:
        # Nothing to solve; every entry trivially succeeds.
        return (
            np.zeros((targets.size, 0), dtype=float),
            np.ones((targets.size,), dtype=bool),
        )
    if not (self.use_jax and HAS_JAX and optx is not None):
        return None
    solver_entry = self._get_jax_solver()
    if solver_entry is None:
        return None
    args = self._prepare_jax_args()
    x0 = jnp.asarray(self._coerce_initial(initial_guess), dtype=jnp.float64)
    v_seq = jnp.asarray(targets, dtype=jnp.float64)
    try:
        solutions, statuses, _ = solver_entry["scan"](x0, v_seq, args)
    except (ValueError, RuntimeError, TypeError, AttributeError) as exc:
        # A broken sweep disables the accelerated path for this instance.
        logger.warning(
            "Optimistix sweep failed with %s: %s. Falling back to sequential solver.",
            type(exc).__name__,
            str(exc),
        )
        self.use_jax = False
        return None
    return np.asarray(solutions, dtype=float), np.asarray(statuses, dtype=bool)

Solve a sequence of focus inputs with the accelerated JAX path.

Expected output

Returns (solutions, success_flags) when the JAX solver is available, otherwise None.

@classmethod
def generate_erf_curve( cls, parameter: Dict, *, start: float = 0.02, end: float = 1.0, step_number: int = 20, retry_step: Optional[float] = None, initial_guess: Optional[numpy.ndarray] = None, fallback_initials: Optional[Sequence[float]] = None, **network_kwargs) -> ERFResult:
@classmethod
def generate_erf_curve(
    cls,
    parameter: Dict,
    *,
    start: float = 0.02,
    end: float = 1.0,
    step_number: int = 20,
    retry_step: Optional[float] = None,
    initial_guess: Optional[np.ndarray] = None,
    fallback_initials: Optional[Sequence[float]] = None,
    **network_kwargs,
) -> ERFResult:
    """Compute the ERF for ``parameter``.

    When ``retry_step`` is provided, the solver retries inputs separated by
    that value whenever convergence fails. If the ERF cannot be completed
    the ``completed`` flag is ``False`` and no serialization should happen.
    Parameters
    ----------
    start : float, optional
        Lower bound of the input range for the ERF (default: 0.02).
    end : float, optional
        Upper bound of the input range for the ERF (default: 1.0).
    step_number : int, optional
        Number of steps between ``start`` and ``end`` (default: 20).
    retry_step : float, optional
        When given, failed solves advance the input by this value instead
        of aborting the sweep (default: None).
    initial_guess : np.ndarray, optional
        Starting vector for the first solve; later solves reuse the
        previous solution as a warm start.
    fallback_initials : sequence of float, optional
        Constant seed values tried whenever the warm start fails to
        converge (default: [0.02, 0.2, 0.5, 0.8, 0.98]).
    **network_kwargs
        Forwarded to the class constructor.

    Expected output
    ---------------
    Returns an :class:`ERFResult` whose `x_data` stores the driven focus
    inputs and whose `y_data` stores the corresponding mean-field outputs.
    """
    ERF_EPS = 1e-12  # Tolerance for floating-point comparison of v_in against end
    x_data: List[float] = []
    y_data: List[float] = []
    solves: List[np.ndarray] = []
    step = (end - start) / max(step_number, 1)
    # Smallest gap the adaptive continuation is allowed to bisect down to.
    adaptive_min_step = min(abs(step), cls.ADAPTIVE_CONTINUATION_MIN_STEP) if step != 0 else cls.ADAPTIVE_CONTINUATION_MIN_STEP
    aborted = False
    fallback_values = list(fallback_initials) if fallback_initials is not None else [0.02, 0.2, 0.5, 0.8, 0.98]

    def solve_with_fallback_initials(
        system: "RateSystem",
        initial: Optional[np.ndarray],
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
        # Try the warm start first, then each constant fallback seed; on
        # total failure return the best (lowest-residual) attempt seen.
        best_solution: Optional[np.ndarray] = None
        best_residual: Optional[np.ndarray] = None
        best_norm = float("inf")
        candidates: List[Optional[np.ndarray]] = [initial]
        for seed in fallback_values:
            if system.dim == 0:
                candidate = np.zeros((0,), dtype=float)
            else:
                candidate = np.full((system.dim,), float(seed), dtype=float)
            candidates.append(candidate)
        for candidate in candidates:
            try:
                solution, residual, success = system.solve(candidate)
            except SolverConvergenceError as exc:
                logger.debug("Skipping v_in %.6f: %s", float(system.v_focus), str(exc))
                continue
            residual_arr = np.asarray(residual, dtype=float).ravel()
            if success:
                return np.asarray(solution, dtype=float).ravel(), residual_arr, True
            if np.isfinite(residual_arr).all():
                residual_norm = cls._residual_norm(residual_arr)
                if residual_norm < best_norm:
                    best_solution = np.asarray(solution, dtype=float).ravel()
                    best_residual = residual_arr
                    best_norm = residual_norm
        return best_solution, best_residual, False

    # Without retry_step, pre-compute the full grid of inputs so the JAX
    # sweep (and adaptive midpoint insertion) can operate on it.
    vector_values: Optional[List[float]] = None
    if retry_step is None:
        vector_values = []
        if step == 0:
            vector_values.append(float(start))
        else:
            cursor = float(start)
            while cursor <= end + ERF_EPS:
                vector_values.append(cursor)
                cursor += step
    initial_focus = float(start)
    if vector_values:
        initial_focus = float(vector_values[0])
    system = cls(parameter, initial_focus, **network_kwargs)
    current_initial = initial_guess
    # Optionally prefetch the whole sweep in one JAX scan; only the prefix
    # of successful solves (up to the first failure) is trusted.
    prefetched_solutions: Optional[np.ndarray] = None
    prefix_limit = 0
    if vector_values and system.use_jax:
        seq_result = system.solve_sequence(vector_values, initial_guess=current_initial)
        if seq_result is not None:
            prefetched_solutions, success_flags = seq_result
            failure = np.flatnonzero(~success_flags)
            prefix_limit = int(failure[0]) if failure.size else len(vector_values)
    if vector_values is not None:
        # Grid mode: walk the precomputed inputs, bisecting the gap to the
        # last good point when a solve fails (down to adaptive_min_step).
        idx = 0
        while idx < len(vector_values):
            v_in = vector_values[idx]
            system.v_focus = float(v_in)
            use_prefetched = prefetched_solutions is not None and idx < prefix_limit
            if use_prefetched:
                solution = np.asarray(prefetched_solutions[idx], dtype=float)
                success = True
            else:
                solution, residual, success = solve_with_fallback_initials(system, current_initial)
                if not success:
                    prev_v = x_data[-1] if x_data else None
                    if prev_v is not None:
                        gap = float(v_in) - float(prev_v)
                        if gap > adaptive_min_step + ERF_EPS:
                            # Insert a midpoint and retry; the prefetched
                            # solutions no longer align with the grid.
                            midpoint = float(prev_v) + 0.5 * gap
                            vector_values.insert(idx, midpoint)
                            prefetched_solutions = None
                            prefix_limit = 0
                            continue
                    aborted = True
                    break
            phi_values = system.phi_numpy(solution)
            x_data.append(system.v_focus)
            y_data.append(system.focus_output(phi_values))
            solves.append(solution)
            current_initial = solution
            idx += 1
        completed = (not aborted) and (len(x_data) == len(vector_values))
    else:
        # Retry mode: on failure skip forward by retry_step instead of
        # aborting, reusing the best partial solution as warm start.
        v_in = float(start)
        next_value = v_in
        while v_in <= end + ERF_EPS:
            system.v_focus = float(v_in)
            solution, residual, success = solve_with_fallback_initials(system, current_initial)
            if not success:
                if retry_step is not None:
                    v_in += retry_step
                    next_value = v_in
                    if solution is not None:
                        current_initial = solution
                    continue
                aborted = True
                break
            phi_values = system.phi_numpy(solution)
            x_data.append(system.v_focus)
            y_data.append(system.focus_output(phi_values))
            solves.append(solution)
            current_initial = solution
            if step == 0:
                next_value = float("inf")
                break
            v_in = system.v_focus + step
            next_value = v_in
        completed = (not aborted) and (step == 0 or next_value > end + ERF_EPS)
    return ERFResult(x_data=x_data, y_data=y_data, solves=solves, completed=completed)

Compute the ERF for parameter.

When retry_step is provided, the solver retries inputs separated by that value whenever convergence fails. If the ERF cannot be completed the completed flag is False and no serialization should happen.

Parameters
  • start (float, optional): Lower bound of the input range for the ERF (default: 0.02).
  • end (float, optional): Upper bound of the input range for the ERF (default: 1.0).
  • step_number (int, optional): Number of steps between start and end (default: 20).
  • retry_step (float, optional): When given, failed solves advance the input by this value instead of aborting the sweep (default: None).
  • initial_guess (np.ndarray, optional): Starting vector for the first solve; later solves reuse the previous solution as a warm start.
  • fallback_initials (sequence of float, optional): Constant seed values tried whenever the warm start fails to converge.
Expected output

Returns an ERFResult whose x_data stores the driven focus inputs and whose y_data stores the corresponding mean-field outputs.

@classmethod
def compute_fixpoints( cls, sweep_entry: Sequence, *, tol: float = 0.001, interpolation_steps: int = 10000, **network_kwargs) -> Dict[float, Dict[str, Any]]:
@classmethod
def compute_fixpoints(
    cls,
    sweep_entry: Sequence,
    *,
    tol: float = 1e-3,
    interpolation_steps: int = 10_000,
    **network_kwargs,
) -> Dict[float, Dict[str, Any]]:
    """
    Compute fixed points from an ERF sweep.
    Fixed points with slope larger than 1 at the intersection with the identity line are considered unstable in the 1D map approximation.

    Parameters
    ----------
    sweep_entry : sequence
        Tuple (x_data, y_data, solves, parameter) as returned by generate_erf_curve.
    tol : float, optional
        Tolerance for detecting crossings of the identity line (x = y) and for
        merging nearby crossings. Crossings where |x - y| <= tol are treated as
        fixed points, and crossings within tol of each other are merged. Default
        is 1e-3.
    interpolation_steps : int, optional
        Number of interpolation points used to refine the ERF before searching
        for crossings. Larger values increase accuracy but also cost. Default
        is 10_000.
    **network_kwargs
        Forwarded to the class constructor when re-solving at each candidate
        fixed point.

    Expected output
    ---------------
    Returns a dictionary keyed by fixed-point location. Each value contains
    at least `stability`, `rates`, `residual_norm`, and `solver_success`.
    """
    # |d(ERF)/dv| < 1 => stable in 1D; > 1 => unstable.
    # NOTE(review): only slope > 1 is flagged below; a slope < -1 would also
    # be unstable in the 1D map -- confirm the ERF is assumed monotone.
    SLOPE_STABILITY_THRESHOLD = 1.0
    x_data, y_data, solves, parameter = sweep_entry
    x_interp, y_interp = interpolate_curve(x_data, y_data, steps=interpolation_steps)
    if x_interp.size == 0 or y_interp.size == 0:
        print("Skipping fixpoint analysis: empty ERF data.")
        return {}
    # Find crossings of the identity line: a sample within tol of it, or a
    # sign flip of (x - y) between neighbours (refined by linear
    # interpolation). Crossings within tol of each other are merged.
    diff = x_interp - y_interp
    crossings = []
    prev_diff = diff[0]
    for idx in range(1, len(diff)):
        curr_diff = diff[idx]
        cross_val = None
        if np.abs(curr_diff) <= tol:
            cross_val = y_interp[idx]
        elif np.abs(prev_diff) <= tol:
            cross_val = y_interp[idx - 1]
        elif prev_diff * curr_diff < 0:
            weight = prev_diff / (prev_diff - curr_diff)
            cross_val = y_interp[idx - 1] + weight * (y_interp[idx] - y_interp[idx - 1])
        if cross_val is not None:
            if crossings and np.abs(cross_val - crossings[-1][0]) <= tol:
                crossings[-1] = (float(cross_val), idx)
            else:
                crossings.append((float(cross_val), idx))
        prev_diff = curr_diff
    fixpoints: Dict[float, Dict[str, Any]] = {}
    if not crossings:
        return fixpoints
    solves_array = [np.asarray(s, dtype=float) for s in solves]
    v_out_old = np.asarray(y_data, dtype=float)
    for cross_point, idx in crossings:
        print(f"Cross-Point: {cross_point}")
        if idx <= 0:
            slope = np.inf
        else:
            slope = (y_interp[idx] - y_interp[idx - 1]) / (x_interp[idx] - x_interp[idx - 1])
        entry: Dict[str, Any] = {
            "stability": "unstable",
            "rates": None,
            "residual_norm": float("inf"),
            "solver_success": False,
            "slope": float(slope) if np.isfinite(slope) else float("inf"),
            "included": False,
        }
        slope_unstable = not np.isfinite(slope) or slope > SLOPE_STABILITY_THRESHOLD
        if len(solves_array) == 0:
            entry["reason"] = "missing_erf_solution"
            fixpoints[cross_point] = entry
            continue
        # Warm-start the solver with the ERF solution whose output is
        # closest to the crossing value.
        closest_idx = int(np.argmin(np.abs(v_out_old - cross_point)))
        closest_idx = min(max(closest_idx, 0), len(solves_array) - 1)
        initial = solves_array[closest_idx]
        system = cls(parameter, cross_point, **network_kwargs)
        try:
            solve, residual, success = system.solve(initial)
        except SolverConvergenceError as exc:
            entry["solver_success"] = False
            entry["residual_norm"] = float("inf")
            entry["rates"] = None
            entry["stability"] = "unstable"
            entry["reason"] = "solver_failed"
            entry["error"] = str(exc)
            fixpoints[cross_point] = entry
            continue
        residual = np.asarray(residual, dtype=float)
        residual_norm = float(np.linalg.norm(residual)) if residual.size else 0.0
        if not np.isfinite(residual).all():
            residual_norm = float("inf")
        entry["solver_success"] = bool(success)
        entry["residual_norm"] = residual_norm
        if success and np.isfinite(residual).all():
            entry["rates"] = system.full_rates_numpy(solve)
        else:
            entry["rates"] = None
        if slope_unstable or not success or not np.isfinite(residual).all():
            if not success or not np.isfinite(residual).all():
                print("Warning: convergence problems near cross-point")
            stability = "unstable"
        else:
            jacobian = system.jacobian_numpy(solve)
            if not np.isfinite(jacobian).all():
                stability = "unstable"
            else:
                try:
                    eigval = np.linalg.eigvals(jacobian)
                except np.linalg.LinAlgError:
                    stability = "unstable"
                else:
                    # BUGFIX: eigvals may return complex values, and ordering
                    # comparisons on complex arrays raise TypeError in NumPy.
                    # Linear stability requires all eigenvalue REAL PARTS to
                    # be negative.
                    stability = "stable" if (eigval.real < 0).all() else "unstable"
        entry["stability"] = stability
        if "reason" not in entry:
            entry["reason"] = "slope" if slope_unstable else None
        fixpoints[cross_point] = entry
    return fixpoints

Compute fixed points from an ERF sweep. Fixed points with slope larger than 1 at the intersection with the identity line are considered unstable in the 1D map approximation.

Parameters
  • sweep_entry (sequence): Tuple (x_data, y_data, solves, parameter) as returned by generate_erf_curve.
  • tol (float, optional): Tolerance for detecting crossings of the identity line (x = y) and for merging nearby crossings. Crossings where |x - y| <= tol are treated as fixed points, and crossings within tol of each other are merged. Default is 1e-3.
  • interpolation_steps (int, optional): Number of interpolation points used to refine the ERF before searching for crossings. Larger values increase accuracy but also cost. Default is 10_000.
Expected output

Returns a dictionary keyed by fixed-point location. Each value contains at least stability, rates, residual_norm, and solver_success.

@dataclass
class ERFResult:
@dataclass
class ERFResult:
    """Result bundle for one event-rate-function (ERF) sweep.

    Attributes
    ----------
    x_data : list of float
        Focus inputs that were driven during the sweep.
    y_data : list of float
        Mean-field outputs corresponding to each input.
    solves : list of np.ndarray
        Reduced solver vectors backing each sample.
    completed : bool
        True when the sweep covered the whole requested input range.

    Examples
    --------
    >>> result = ERFResult(x_data=[0.1, 0.2], y_data=[0.12, 0.22], solves=[], completed=True)
    >>> result.completed
    True
    """

    x_data: List[float]
    y_data: List[float]
    solves: List[np.ndarray]
    completed: bool

Container for an event-rate-function sweep.

Examples
>>> result = ERFResult(x_data=[0.1, 0.2], y_data=[0.12, 0.22], solves=[], completed=True)
>>> result.completed
True
ERFResult( x_data: List[float], y_data: List[float], solves: List[numpy.ndarray], completed: bool)
x_data: List[float]
y_data: List[float]
solves: List[numpy.ndarray]
completed: bool
class EIClusterNetwork(MeanField.RateSystem):
 20class EIClusterNetwork(RateSystem):
 21    """Specialized mean-field system for the clustered E/I network.
 22
 23    Examples
 24    --------
 25    >>> parameter = {
 26    ...     "Q": 2,
 27    ...     "N_E": 4000,
 28    ...     "N_I": 1000,
 29    ...     "V_th": 1.0,
 30    ...     "g": 1.0,
 31    ...     "p0_ee": 0.2,
 32    ...     "p0_ei": 0.2,
 33    ...     "p0_ie": 0.2,
 34    ...     "p0_ii": 0.2,
 35    ...     "R_Eplus": 1.0,
 36    ...     "R_j": 0.8,
 37    ...     "m_X": 0.1,
 38    ...     "tau_e": 1.0,
 39    ...     "tau_i": 2.0,
 40    ... }
 41    >>> system = EIClusterNetwork(parameter, v_focus=0.2, prefer_jax=False)
 42    >>> system.population_count
 43    4
 44    """
 45
 46    def __init__(
 47        self,
 48        parameter: Dict,
 49        v_focus: float,
 50        *,
 51        kappa: Optional[float] = None,
 52        connection_type: Optional[str] = None,
 53        use_temporal_variance: bool = True,
 54        use_quadratic_variance: bool = True,
 55        focus_population=None,
 56        prefer_jax: bool = True,
 57        max_steps: int = 256,
 58    ) -> None:
 59        self.Q = int(parameter["Q"])
 60        self._explicit_kappa = kappa
 61        self._explicit_connection = connection_type
 62        self.collapse_types = bool(parameter.get("collapse_types", True))
 63        super().__init__(
 64            parameter,
 65            v_focus,
 66            focus_population=focus_population,
 67            prefer_jax=prefer_jax,
 68            max_steps=max_steps,
 69            kappa=kappa,
 70            connection_type=connection_type,
 71            use_temporal_variance=use_temporal_variance,
 72            use_quadratic_variance=use_quadratic_variance,
 73        )
 74
 75    def _build_dynamics(
 76        self, parameter: Dict, **network_kwargs
 77    ) -> Union[
 78        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
 79        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
 80    ]:
 81        kappa_value = network_kwargs.get("kappa")
 82        if kappa_value is None:
 83            kappa_value = parameter.get("kappa", 0.0)
 84        conn_kind = network_kwargs.get("connection_type") or parameter.get("connection_type", "bernoulli")
 85        use_temporal_variance = bool(network_kwargs.get("use_temporal_variance", True))
 86        use_quadratic_variance = bool(network_kwargs.get("use_quadratic_variance", True))
 87
 88        connectivity = self._build_connectivity(parameter, float(kappa_value), str(conn_kind), use_temporal_variance, use_quadratic_variance)
 89        tau = np.ones(2 * self.Q, dtype=float)
 90        tau[: self.Q] *= parameter["tau_e"]
 91        tau[self.Q :] *= parameter["tau_i"]
 92        if use_temporal_variance:
 93            A, B, C, bias = connectivity
 94            return A, B, C, bias, tau
 95        A, B, bias = connectivity
 96        return A, B, bias, tau
 97
 98    def _build_population_groups(self, focus: np.ndarray) -> List[np.ndarray]:
 99        focus_set = set(int(idx) for idx in focus.tolist())
100        if not focus_set:
101            focus_set = {0}
102        if not self.collapse_types:
103            return self._build_full_focus_groups(focus_set)
104        return self._build_collapsed_groups(focus_set)
105
106    def _build_full_focus_groups(self, focus_set: set[int]) -> List[np.ndarray]:
107        groups: List[np.ndarray] = [np.array(sorted(focus_set), dtype=int)]
108        excit_focus = sorted(idx for idx in focus_set if 0 <= idx < self.Q)
109        paired_inhib = sorted(
110            {
111                idx + self.Q
112                for idx in excit_focus
113                if idx + self.Q < 2 * self.Q and (idx + self.Q) not in focus_set
114            }
115        )
116        if paired_inhib:
117            groups.append(np.array(paired_inhib, dtype=int))
118        remaining_excit = [idx for idx in range(self.Q) if idx not in focus_set]
119        for idx in remaining_excit:
120            groups.append(np.array([idx], dtype=int))
121        remaining_inhib = [
122            idx for idx in range(self.Q, 2 * self.Q) if idx not in focus_set and idx not in paired_inhib
123        ]
124        for idx in remaining_inhib:
125            groups.append(np.array([idx], dtype=int))
126        return groups
127
128    def _build_collapsed_groups(self, focus_set: set[int]) -> List[np.ndarray]:
129        groups: List[np.ndarray] = [np.array(sorted(focus_set), dtype=int)]
130        paired_inhib = sorted(
131            idx + self.Q
132            for idx in focus_set
133            if idx < self.Q and (idx + self.Q) not in focus_set
134        )
135        paired_excit = sorted(
136            idx - self.Q
137            for idx in focus_set
138            if idx >= self.Q and (idx - self.Q) not in focus_set
139        )
140        other_excit = sorted(idx for idx in range(self.Q) if idx not in focus_set and idx not in paired_excit)
141        other_inhib = sorted(
142            idx for idx in range(self.Q, 2 * self.Q) if idx not in focus_set and idx not in paired_inhib
143        )
144        if other_excit:
145            groups.append(np.array(other_excit, dtype=int))
146        if paired_inhib:
147            groups.append(np.array(paired_inhib, dtype=int))
148        if other_inhib:
149            groups.append(np.array(other_inhib, dtype=int))
150        if paired_excit:
151            groups.append(np.array(paired_excit, dtype=int))
152        return groups
153
154    def _build_connectivity(
155        self, parameter: Dict, kappa: float, connection_type: str, use_temporal_variance: bool, use_quadratic_variance: bool,
156    ) -> Union[
157        Tuple[np.ndarray, np.ndarray, np.ndarray],
158        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
159    ]:
160        N_E = parameter["N_E"]
161        N_I = parameter["N_I"]
162        N = _total_population(parameter)
163        V_th = parameter["V_th"]
164        g = parameter["g"]
165        p0_ee = parameter["p0_ee"]
166        p0_ie = parameter["p0_ie"]
167        p0_ei = parameter["p0_ei"]
168        p0_ii = parameter["p0_ii"]
169        m_X = parameter["m_X"]
170        R_Eplus = parameter["R_Eplus"]
171        R_j = parameter["R_j"]
172
173        n_er = N_E / N
174        n_ir = N_I / N
175        n_e = N_E / self.Q
176        n_i = N_I / self.Q
177
178        theta_E = V_th
179        theta_I = V_th
180        V_th_vec = np.array([theta_E] * self.Q + [theta_I] * self.Q, dtype=float)
181
182        R_Iplus = 1 + R_j * (R_Eplus - 1)
183
184        j_EE = theta_E / math.sqrt(p0_ee * n_er)
185        j_IE = theta_I / math.sqrt(p0_ie * n_er)
186        j_EI = -g * j_EE * p0_ee * n_er / (p0_ei * n_ir)
187        j_II = -j_IE * p0_ie * n_er / (p0_ii * n_ir)
188
189        scale = 1.0 / math.sqrt(N)
190        j_EE *= scale
191        j_IE *= scale
192        j_EI *= scale
193        j_II *= scale
194
195        def mix_scales(R_plus: float) -> tuple[float, float, float, float]:
196            """
197            Compute scaling factors for in-cluster and out-of-cluster connections.
198            
199            When Q > 1: Distributes connection strengths between in-cluster (prob_in, weight_in) 
200            and out-of-cluster (prob_out, weight_out) based on the clustering strength R_plus.
201            
202            When Q = 1: Only one cluster exists, so there are no out-of-cluster connections.
203            In this case, prob_out and weight_out are set equal to prob_in and weight_in,
204            which ensures that all connection parameters use the same scaling without 
205            division by zero. The structured matrix builder will handle the single-cluster 
206            case appropriately.
207            """
208            prob_in = R_plus ** (1.0 - kappa)
209            weight_in = R_plus ** kappa
210            if self.Q == 1:
211                # With a single cluster, all connections are in-cluster
212                prob_out = prob_in
213                weight_out = weight_in
214            else:
215                prob_out = (self.Q - prob_in) / (self.Q - 1)
216                weight_out = (self.Q - weight_in) / (self.Q - 1)
217            return prob_in, prob_out, weight_in, weight_out
218
219        P_scale_in_E, P_scale_out_E, J_scale_in_E, J_scale_out_E = mix_scales(R_Eplus)
220        P_scale_in_I, P_scale_out_I, J_scale_in_I, J_scale_out_I = mix_scales(R_Iplus)
221
222        P_EE = p0_ee * P_scale_in_E
223        p_ee = p0_ee * P_scale_out_E
224        P_IE = p0_ie * P_scale_in_I
225        p_ie = p0_ie * P_scale_out_I
226        P_EI = p0_ei * P_scale_in_I
227        p_ei = p0_ei * P_scale_out_I
228        P_II = p0_ii * P_scale_in_I
229        p_ii = p0_ii * P_scale_out_I
230
231        J_EE = j_EE * J_scale_in_E
232        j_ee = j_EE * J_scale_out_E
233        J_IE = j_IE * J_scale_in_I
234        j_ie = j_IE * J_scale_out_I
235        J_EI = j_EI * J_scale_in_I
236        j_ei = j_EI * J_scale_out_I
237        J_II = j_II * J_scale_in_I
238        j_ii = j_II * J_scale_out_I
239
240        EE_IN = J_EE * P_EE * n_e
241        EE_OUT = j_ee * p_ee * n_e
242        IE_IN = J_IE * P_IE * n_e
243        IE_OUT = j_ie * p_ie * n_e
244        EI_IN = J_EI * P_EI * n_i
245        EI_OUT = j_ei * p_ei * n_i
246        II_IN = J_II * P_II * n_i
247        II_OUT = j_ii * p_ii * n_i
248
249        mean_values = dict(EE_IN=EE_IN, EE_OUT=EE_OUT, IE_IN=IE_IN, IE_OUT=IE_OUT, EI_IN=EI_IN, EI_OUT=EI_OUT,
250                           II_IN=II_IN, II_OUT=II_OUT)
251        A = self._structured_matrix(mean_values)
252
253        var_coeffs = dict(
254            EE_IN=self._variance_coeffs(P_EE, J_EE, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
255            EE_OUT=self._variance_coeffs(p_ee, j_ee, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
256            IE_IN=self._variance_coeffs(P_IE, J_IE, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
257            IE_OUT=self._variance_coeffs(p_ie, j_ie, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
258            EI_IN=self._variance_coeffs(P_EI, J_EI, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
259            EI_OUT=self._variance_coeffs(p_ei, j_ei, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
260            II_IN=self._variance_coeffs(P_II, J_II, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
261            II_OUT=self._variance_coeffs(p_ii, j_ii, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
262        )
263        b_values = {key: coeffs[0] for key, coeffs in var_coeffs.items()}
264        B = self._structured_matrix(b_values)
265        if use_temporal_variance:
266            # C encodes the temporal m(1-m) gating term.
267            c_values = {key: coeffs[1] for key, coeffs in var_coeffs.items()}
268            C = self._structured_matrix(c_values)
269
270        J_EX = math.sqrt(p0_ee * N_E)
271        J_IX = 0.8 * J_EX
272        u_extE = J_EX * m_X
273        u_extI = J_IX * m_X
274        u_ext = np.array([u_extE] * self.Q + [u_extI] * self.Q, dtype=float)
275        bias = u_ext - V_th_vec
276        if use_temporal_variance:
277            return A, B, C, bias
278        return A, B, bias
279
280    def _structured_matrix(self, values: Dict[str, float]) -> np.ndarray:
281        size = 2 * self.Q
282        matrix = np.zeros((size, size), dtype=float)
283        for target in range(size):
284            tgt_type = "E" if target < self.Q else "I"
285            tgt_cluster = target % self.Q
286            for source in range(size):
287                src_type = "E" if source < self.Q else "I"
288                src_cluster = source % self.Q
289                suffix = "_IN" if tgt_cluster == src_cluster else "_OUT"
290                key = f"{tgt_type}{src_type}{suffix}"
291                matrix[target, source] = values[key]
292        return matrix
293
294    @staticmethod
295    def _variance(prob: float, weight: float, population: float, connection_type: str) -> float:
296        conn_kind = connection_type.lower()
297        if conn_kind == "poisson":
298            return prob * weight ** 2 * population
299        if conn_kind == "fixed-indegree":
300            return prob * (1 - (1 / population)) * weight ** 2 * population
301        # else: Bernoulli
302        return prob * (1 - prob) * weight ** 2 * population
303
304    @staticmethod
305    def _variance_coeffs(
306        prob: float,
307        weight: float,
308        population: float,
309        connection_type: str,
310        use_temporal_variance: bool,
311        use_quadratic_variance: bool,
312    ) -> tuple[float, float]:
313        #if not use_temporal_variance:
314        #    return EIClusterNetwork._variance(prob, weight, population, connection_type), 0.0
315        conn_kind = connection_type.lower()
316        if conn_kind == "poisson":
317            if use_quadratic_variance:
318                b= (prob + prob ** 2) * weight ** 2 * population
319            else:
320                b = (prob) * weight ** 2 * population
321            c = -(prob ** 2) * weight ** 2 * population
322            return b, c
323        if conn_kind == "fixed-indegree":
324            k_eff = (prob * population) * (1 - (1 / population))  # finite-size correction for fixed indegree
325            b = k_eff * weight ** 2
326            c = -k_eff * weight ** 2
327            return b, c
328        # else: Bernoulli
329        b = prob * weight ** 2 * population
330        c = -(prob ** 2) * weight ** 2 * population
331        return b, c

Specialized mean-field system for the clustered E/I network.

Examples
>>> parameter = {
...     "Q": 2,
...     "N_E": 4000,
...     "N_I": 1000,
...     "V_th": 1.0,
...     "g": 1.0,
...     "p0_ee": 0.2,
...     "p0_ei": 0.2,
...     "p0_ie": 0.2,
...     "p0_ii": 0.2,
...     "R_Eplus": 1.0,
...     "R_j": 0.8,
...     "m_X": 0.1,
...     "tau_e": 1.0,
...     "tau_i": 2.0,
... }
>>> system = EIClusterNetwork(parameter, v_focus=0.2, prefer_jax=False)
>>> system.population_count
4
EIClusterNetwork( parameter: Dict, v_focus: float, *, kappa: Optional[float] = None, connection_type: Optional[str] = None, use_temporal_variance: bool = True, use_quadratic_variance: bool = True, focus_population=None, prefer_jax: bool = True, max_steps: int = 256)
46    def __init__(
47        self,
48        parameter: Dict,
49        v_focus: float,
50        *,
51        kappa: Optional[float] = None,
52        connection_type: Optional[str] = None,
53        use_temporal_variance: bool = True,
54        use_quadratic_variance: bool = True,
55        focus_population=None,
56        prefer_jax: bool = True,
57        max_steps: int = 256,
58    ) -> None:
59        self.Q = int(parameter["Q"])
60        self._explicit_kappa = kappa
61        self._explicit_connection = connection_type
62        self.collapse_types = bool(parameter.get("collapse_types", True))
63        super().__init__(
64            parameter,
65            v_focus,
66            focus_population=focus_population,
67            prefer_jax=prefer_jax,
68            max_steps=max_steps,
69            kappa=kappa,
70            connection_type=connection_type,
71            use_temporal_variance=use_temporal_variance,
72            use_quadratic_variance=use_quadratic_variance,
73        )
Q
collapse_types
def ensure_output_folder(parameter: Dict, *, tag: Optional[str] = None) -> str:
def ensure_output_folder(parameter: Dict, *, tag: Optional[str] = None) -> str:
    """Create and return the cache folder for a mean-field parameter set.

    The folder layout is data/<ConnectionType>/<RjLabel>/<tag>; when `tag` is
    omitted it is derived from the parameter set (excluding "R_Eplus", so all
    sweep points of one configuration share a folder).

    Examples
    --------
    >>> folder = ensure_output_folder({"connection_type": "bernoulli", "R_j": 0.8, "Q": 2})
    >>> folder.startswith("data/Bernoulli/Rj00_80/")
    True
    """
    connection = str(parameter.get("connection_type", "bernoulli")).strip().capitalize()
    rj = float(parameter.get("R_j", 0.0))
    # e.g. R_j=0.8 -> "Rj00_80" (dots are unfriendly in folder names).
    rj_dir = f"Rj{rj:05.2f}".replace(".", "_")
    if tag is None:
        tag = sim_tag_from_cfg({k: v for k, v in parameter.items() if k != "R_Eplus"})
    folder = os.path.join("data", connection, rj_dir, tag)
    os.makedirs(folder, exist_ok=True)
    return folder

Create and return the cache folder for a mean-field parameter set.

Examples
>>> folder = ensure_output_folder({"connection_type": "bernoulli", "R_j": 0.8, "Q": 2})
>>> folder.startswith("data/Bernoulli/Rj00_80/")
True
def serialize_erf( file_path: str, parameter: Dict, result: ERFResult, *, focus_count: Optional[int] = None) -> Optional[str]:
def serialize_erf(
    file_path: str,
    parameter: Dict,
    result: "ERFResult",
    *,
    focus_count: Optional[int] = None,
) -> Optional[str]:
    """Serialize a completed ERF sweep to a pickle file.

    Parameters
    ----------
    file_path : str
        Destination pickle path; parent directories are created as needed.
    parameter : Dict
        Parameter set of the sweep; must contain "R_Eplus".
    result : ERFResult
        Sweep result; only written when `result.completed` is True.
    focus_count : int, optional
        Overrides `parameter["focus_count"]` in the payload key.

    Expected output
    ---------------
    Returns `file_path` on success and `None` when `result.completed` is
    `False`.
    """
    import pickle

    if not result.completed:
        return None
    R_value = float(parameter["R_Eplus"])
    focus_value = focus_count if focus_count is not None else parameter.get("focus_count", 1)
    focus_value = 1 if focus_value is None else int(focus_value)
    key = f"{R_value:.12g}_focus{focus_value}"
    payload = {key: [result.x_data, result.y_data, result.solves, parameter]}
    # os.path.dirname is "" for a bare filename, and makedirs("") raises
    # FileNotFoundError — only create directories when a parent is present.
    parent = os.path.dirname(file_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file_path, "wb") as file:
        pickle.dump(payload, file)
    return file_path

Serialize a completed ERF sweep to a pickle file.

Expected output

Returns file_path on success and None when result.completed is False.

def aggregate_data(folder: str) -> str:
def aggregate_data(folder: str) -> str:
    """Merge individual ERF pickle files into one combined pickle.

    Every `*.pkl` in `folder` (except a previous combined file) is loaded and
    merged into a single mapping, later files overriding earlier keys.

    Expected output
    ---------------
    Returns the path to `all_data_P_Eplus.pkl` inside `folder`.

    Raises
    ------
    FileNotFoundError
        When `folder` contains no `.pkl` files to merge.
    """
    import glob
    import pickle

    combined_name = "all_data_P_Eplus.pkl"
    sources = sorted(
        path
        for path in glob.glob(f"{folder}/*.pkl")
        if os.path.basename(path) != combined_name
    )
    if not sources:
        raise FileNotFoundError(f"No .pkl files found in {folder}")
    # Seed the merge with the first file, then fold in the rest.
    first, *rest = sources
    with open(first, "rb") as fh:
        merged = pickle.load(fh)
    for source in rest:
        with open(source, "rb") as fh:
            merged.update(pickle.load(fh))
    target = os.path.join(folder, combined_name)
    with open(target, "wb") as fh:
        pickle.dump(merged, fh)
    return target

Merge individual ERF pickle files into one combined pickle.

Expected output

Returns the path to all_data_P_Eplus.pkl inside folder.