MeanField
Mean-field solvers and EI-cluster specializations.
The package exposes a generic fixed-point solver in `RateSystem` and a concrete
clustered E/I specialization in `EIClusterNetwork`.
Minimal example:
from MeanField import EIClusterNetwork
system = EIClusterNetwork(parameter, v_focus=0.2)
x, residual, success = system.solve()
rates = system.full_rates_numpy(x)
The helpers in `rate_system` are used by the figure pipelines to trace
event-rate functions, store fixpoint bundles, and reuse cached results.
Regenerating docs:
python scripts/generate_api_docs.py
"""Mean-field solvers and EI-cluster specializations.

The package exposes a generic fixed-point solver, `RateSystem`, and a concrete
clustered E/I specialization, `EIClusterNetwork`.

Minimal example:

```python
from MeanField import EIClusterNetwork

system = EIClusterNetwork(parameter, v_focus=0.2)
x, residual, success = system.solve()
rates = system.full_rates_numpy(x)
```

The helpers in `rate_system` are used by the figure pipelines to trace
event-rate functions, store fixpoint bundles, and reuse cached results.

Regenerating docs:

```bash
python scripts/generate_api_docs.py
```
"""

# Generic solver, its ERF result container and the module-level helpers.
from .rate_system import ERFResult
from .rate_system import RateSystem
from .rate_system import aggregate_data
from .rate_system import ensure_output_folder
from .rate_system import serialize_erf

# Clustered E/I specialization of RateSystem.
from .ei_cluster_network import EIClusterNetwork

__all__ = [
    # Solver classes and result container.
    "RateSystem",
    "ERFResult",
    "EIClusterNetwork",
    # Helper functions re-exported from rate_system.
    "ensure_output_folder",
    "serialize_erf",
    "aggregate_data",
]
class RateSystem:
    """General mean-field solver with helper utilities for ERF and fixpoints.

    Subclasses implement `_build_dynamics(...)` and then inherit the fixed-point
    solver, ERF sweep, and fixpoint analysis helpers.

    The solver works on a reduced state. Populations are partitioned into
    groups: by default (see `_build_population_groups`), one group holds every
    focus population, and each remaining population is a singleton group. The
    focus group's rate is clamped to ``v_focus``, and the remaining ``dim``
    group rates are the unknowns.
    """
    # Minimum allowed variance to prevent numerical instabilities.
    # Chosen small enough to avoid affecting dynamics, only avoids divide-by-zero.
    VAR_EPS = 1e-12
    # Accept near-roots even when the underlying SciPy solver does not set success=True.
    RESIDUAL_ACCEPT_TOL = 5e-4
    # Smallest continuation step introduced when adaptively subdividing the ERF sweep.
    ADAPTIVE_CONTINUATION_MIN_STEP = 1e-3

    def __init__(
        self,
        parameter: Dict,
        v_focus: float,
        *,
        focus_population: Optional[Union[int, Sequence[int]]] = None,
        prefer_jax: bool = True,
        root_tol: float = 1e-9,
        max_function_evals: int = 4000,
        max_newton_steps: Optional[int] = 1000,
        **network_kwargs,
    ) -> None:
        """Build the dynamics and the group reduction for one focus input.

        Parameters
        ----------
        parameter : dict
            Model parameters. A copy is passed to `_build_dynamics`. The keys
            ``focus_population`` and ``focus_count`` are consulted when the
            ``focus_population`` argument is not given.
        v_focus : float
            Rate clamped on the focus populations.
        focus_population : int or sequence, optional
            Focus population index or indices. Sequences may contain slices.
        prefer_jax : bool
            Use the Optimistix/JAX solver when it is available.
        root_tol : float
            Tolerance forwarded to the nonlinear solvers.
        max_function_evals : int
            Evaluation budget for the SciPy solvers.
        max_newton_steps : int or None
            Step budget for the JAX solver. ``None`` selects 256.
        **network_kwargs
            Forwarded to `_build_dynamics`.

        Raises
        ------
        ValueError
            If ``max_newton_steps`` is not positive, the dynamics tuple has the
            wrong length, ``A`` is not square, the matrix/vector shapes
            disagree, or a focus index is out of range.
        """
        self.parameter = dict(parameter)
        self.v_focus = float(v_focus)
        self.prefer_jax = bool(prefer_jax)
        self.root_tol = root_tol
        self.max_function_evals = max_function_evals
        default_steps = 256
        if max_newton_steps is None:
            self.max_steps = default_steps
        else:
            steps = int(max_newton_steps)
            if steps <= 0:
                raise ValueError("max_newton_steps must be positive.")
            self.max_steps = steps
        self.network_kwargs = dict(network_kwargs)
        dynamics = self._build_dynamics(self.parameter, **self.network_kwargs)
        if len(dynamics) == 4:
            A, B, bias, tau = dynamics
            C = None
        elif len(dynamics) == 5:
            A, B, C, bias, tau = dynamics
        else:
            raise ValueError("Expected _build_dynamics to return (A, B, bias, tau) or (A, B, C, bias, tau).")
        # Normalise to float ndarrays so subclasses may return lists or other array types.
        self.A = np.asarray(A, dtype=float)
        self.B = np.asarray(B, dtype=float)
        # A missing C means the variance has no term quadratic in the rates.
        self.C = np.zeros_like(self.B) if C is None else np.asarray(C, dtype=float)
        self.bias = np.asarray(bias, dtype=float)
        self.tau = np.asarray(tau, dtype=float)
        if self.A.ndim != 2 or self.A.shape[0] != self.A.shape[1]:
            raise ValueError("Connectivity matrix A must be square.")
        self.population_count = int(self.A.shape[0])
        if self.A.shape != self.B.shape or self.A.shape != self.C.shape:
            raise ValueError("Connectivity mean and variance matrices must match in shape.")
        if self.bias.shape[0] != self.population_count or self.tau.shape[0] != self.population_count:
            raise ValueError("Bias and tau vectors must match the matrix dimensions.")
        parameter_focus = self.parameter.get("focus_population")
        # Accept the same forms as the ``focus_population`` argument: a single
        # index (bools excluded) or a 1-D collection of indices/slices.
        parameter_focus_valid = (
            isinstance(parameter_focus, (list, tuple, range))
            or (isinstance(parameter_focus, np.ndarray) and parameter_focus.ndim == 1)
            or (isinstance(parameter_focus, (int, np.integer)) and not isinstance(parameter_focus, bool))
        )
        if focus_population is not None:
            focus_config = focus_population
        elif parameter_focus_valid:
            focus_config = parameter_focus
        else:
            count = int(self.parameter.get("focus_count", 1) or 1)
            focus_config = list(range(max(count, 1)))
        self.focus_indices = self._resolve_focus_indices(focus_config)
        self._initialize_group_constraints()
        self.use_jax = self.prefer_jax and HAS_JAX and optx is not None
        # Lazily built device arrays for the JAX solver (see _prepare_jax_args).
        self._jax_args = None

    # --- abstract hooks -------------------------------------------------
    def _build_dynamics(
        self, parameter: Dict, **network_kwargs
    ) -> Union[
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    ]:
        """Return ``(A, B, bias, tau)`` or ``(A, B, C, bias, tau)``.

        The input mean is ``A @ r + bias``. The input variance is
        ``B @ r + C @ r**2``. The residual is divided by ``tau``
        (see `_phi_numpy` and `residual_numpy`).
        """
        raise NotImplementedError("Derived classes must implement '_build_dynamics'.")

    def _build_population_groups(self, focus: np.ndarray) -> List[np.ndarray]:
        """Partition the populations into groups.

        Group 0 is treated as the focus group by `_initialize_group_constraints`.
        By default, all focus indices form group 0 (population 0 if ``focus``
        is empty), and each remaining population is its own group.
        """
        focus_set = set(int(idx) for idx in focus.tolist())
        if not focus_set:
            focus_set = {0}
        groups: List[np.ndarray] = [np.array(sorted(focus_set), dtype=int)]
        remaining = [idx for idx in range(self.population_count) if idx not in focus_set]
        for idx in remaining:
            groups.append(np.array([idx], dtype=int))
        return groups

    # --- solver core ----------------------------------------------------
    def _initialize_group_constraints(self) -> None:
        """Precompute the linear maps between reduced, group and full-rate vectors.

        Raises
        ------
        ValueError
            If the groups do not cover every population exactly once.
        """
        groups = self._build_population_groups(self.focus_indices)
        covered = np.concatenate(groups)
        if covered.size != self.population_count or not np.all(np.sort(covered) == np.arange(self.population_count)):
            raise ValueError("Population groups must cover each population exactly once.")
        self.groups: List[np.ndarray] = groups
        self.group_count = len(groups)
        self.focus_group_index = 0
        self.solve_groups = [idx for idx in range(self.group_count) if idx != self.focus_group_index]
        self.dim = len(self.solve_groups)
        self.focus_population_mask = np.zeros(self.population_count, dtype=bool)
        self.focus_population_mask[self.focus_indices] = True
        # (population_count, group_count): broadcasts one value per group to its members.
        self.group_membership = np.zeros((self.population_count, self.group_count), dtype=float)
        for idx, members in enumerate(self.groups):
            self.group_membership[members, idx] = 1.0
        self.group_sizes = np.array([len(members) for members in self.groups], dtype=float)
        self.group_inverse_sizes = np.reciprocal(self.group_sizes)
        self.focus_vector = np.zeros(self.group_count, dtype=float)
        self.focus_vector[self.focus_group_index] = 1.0
        # (group_count, dim): embeds the reduced unknowns into the group vector.
        self.selector_matrix = np.zeros((self.group_count, self.dim), dtype=float)
        for col, group_idx in enumerate(self.solve_groups):
            self.selector_matrix[group_idx, col] = 1.0
        # (dim, population_count): averages the per-population residual over each free group.
        self.residual_matrix = np.zeros((self.dim, self.population_count), dtype=float)
        for row, group_idx in enumerate(self.solve_groups):
            members = self.groups[group_idx]
            self.residual_matrix[row, members] = 1.0 / len(members)

    def _resolve_focus_indices(self, focus_config) -> np.ndarray:
        """Turn a focus specification into a sorted, de-duplicated index array.

        Accepts ``None`` (population 0), an integer, or an iterable of integers
        and/or slices. Open slice bounds default to ``0`` / ``population_count``.
        An empty result falls back to ``[0]``.

        Raises
        ------
        ValueError
            If any resolved index lies outside ``[0, population_count)``.
        """
        if focus_config is None:
            entries = [0]
        elif isinstance(focus_config, (int, np.integer)):
            entries = [int(focus_config)]
        else:
            entries = []
            for value in focus_config:
                if isinstance(value, slice):
                    start = 0 if value.start is None else value.start
                    stop = self.population_count if value.stop is None else value.stop
                    step = 1 if value.step is None else value.step
                    entries.extend(range(start, stop, step))
                else:
                    entries.append(int(value))
        focus = sorted(set(entries))
        if not focus:
            focus = [0]
        for idx in focus:
            if idx < 0 or idx >= self.population_count:
                raise ValueError(f"Focus index {idx} out of bounds for population size {self.population_count}.")
        return np.asarray(focus, dtype=int)

    def _full_rates_numpy(self, x: np.ndarray) -> np.ndarray:
        """Expand ``x`` to one rate per population.

        ``x`` may have length ``population_count`` (returned unchanged),
        ``group_count`` (the focus entry is overwritten with ``v_focus``), or
        ``dim`` (the focus group is filled with ``v_focus``).
        """
        arr = np.asarray(x, dtype=float).ravel()
        if arr.size == self.population_count:
            return arr
        if arr.size == self.group_count:
            group_values = arr.copy()
            group_values[self.focus_group_index] = self.v_focus
        else:
            if arr.size != self.dim:
                raise ValueError(
                    f"Expected vector of length {self.dim}, {self.group_count}, or {self.population_count}, got {arr.size}."
                )
            group_values = self.selector_matrix @ arr
            group_values += self.focus_vector * self.v_focus
        return self.group_membership @ group_values

    def _phi_numpy(self, full_rates: np.ndarray) -> np.ndarray:
        """Gaussian transfer function ``0.5 * (1 + erf(mean / sqrt(2 var)))`` per population."""
        mean = self.A.dot(full_rates) + self.bias
        rates_sq = full_rates * full_rates
        var = np.maximum(self.B.dot(full_rates) + self.C.dot(rates_sq), self.VAR_EPS)
        return 0.5 * (1 - special.erf(-mean / np.sqrt(2.0 * var)))

    def _phi_jacobian_numpy(self, full_rates: np.ndarray) -> np.ndarray:
        """Return ``d phi_i / d r_j`` for the transfer function of `_phi_numpy`.

        With ``u = mean / sqrt(2 var)``:
        ``d phi = exp(-u**2) / sqrt(pi) * (d mean - mean / (2 var) * d var) / sqrt(2 var)``.
        Where the variance is clamped to ``VAR_EPS``, it no longer depends on
        the rates, so its derivative is zeroed there.
        """
        mean = self.A.dot(full_rates) + self.bias
        rates_sq = full_rates * full_rates
        raw_var = self.B.dot(full_rates) + self.C.dot(rates_sq)
        var = np.maximum(raw_var, self.VAR_EPS)
        inv_sqrt = 1.0 / np.sqrt(2.0 * var)
        exp_term = np.exp(-(mean ** 2) / (2.0 * var))
        coeff = (1.0 / np.sqrt(np.pi)) * exp_term * inv_sqrt
        correction = mean / (2.0 * var)
        dvar_dm = self.B + self.C * (2.0 * full_rates[None, :])
        # np.maximum(raw_var, eps) is constant in the rates wherever it clamps.
        dvar_dm = np.where((raw_var < self.VAR_EPS)[:, None], 0.0, dvar_dm)
        return coeff[:, None] * (self.A - correction[:, None] * dvar_dm)

    def residual_numpy(self, x: np.ndarray) -> np.ndarray:
        """Return the reduced residual: group-averaged ``(phi(r) - r) / tau`` over the free groups."""
        rates = self._full_rates_numpy(x)
        phi = self._phi_numpy(rates)
        residual = (phi - rates) / self.tau
        if self.dim == 0:
            return np.zeros((0,), dtype=float)
        return self.residual_matrix @ residual

    @staticmethod
    def _jax_residual(x, v_focus, args):
        """JAX counterpart of `residual_numpy`; ``args`` comes from `_prepare_jax_args`."""
        A, B, C, bias, tau, focus_vector, selector, membership, reduction, var_eps = args
        group_values = selector @ x + focus_vector * v_focus
        full_rates = membership @ group_values
        mean = A @ full_rates + bias
        var = jnp.maximum(B @ full_rates + C @ (full_rates * full_rates), var_eps)
        phi = 0.5 * (1 - jspecial.erf(-mean / jnp.sqrt(2.0 * var)))
        residual = (phi - full_rates) / tau
        return reduction @ residual

    def solve(self, initial_guess: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, bool]:
        """Solve for a fixed point of the reduced mean-field system.

        Returns
        -------
        tuple[np.ndarray, np.ndarray, bool]
            `(x, residual, success)` where `x` is the reduced solution vector.

        Raises
        ------
        SolverConvergenceError
            Propagated from the JAX path when it reports non-convergence.
            Other JAX failures disable JAX and fall back to SciPy.

        Expected output
        ---------------
        `success` is `True` when the nonlinear solver converged and `residual`
        is close to zero.
        """
        initial = self._coerce_initial(initial_guess)
        if self.dim == 0:
            residual = self.residual_numpy(initial)
            return initial, residual, True
        if self.use_jax:
            try:
                return self._solve_with_optimistix(initial)
            except SolverConvergenceError:
                raise
            except (ValueError, RuntimeError, TypeError, AttributeError) as e:
                logger.warning(
                    "Optimistix solver failed with %s: %s. Falling back to scipy solver.",
                    type(e).__name__,
                    str(e),
                )
                self.use_jax = False
        return self._solve_with_scipy(initial)

    @staticmethod
    def _residual_norm(residual: np.ndarray) -> float:
        """Euclidean norm of ``residual`` (0.0 for an empty vector)."""
        arr = np.asarray(residual, dtype=float).ravel()
        if arr.size == 0:
            return 0.0
        return float(np.linalg.norm(arr))

    def _normalize_solver_result(
        self,
        x: np.ndarray,
        residual: np.ndarray,
        success: bool,
        *,
        method: str,
    ) -> Tuple[np.ndarray, np.ndarray, bool]:
        """Flatten a solver result and decide acceptance.

        A result is accepted if the solver reported success and every value is
        finite, or if it is finite with residual norm <= ``RESIDUAL_ACCEPT_TOL``.
        """
        x_arr = np.asarray(x, dtype=float).ravel()
        residual_arr = np.asarray(residual, dtype=float).ravel()
        finite = np.isfinite(x_arr).all() and np.isfinite(residual_arr).all()
        accepted = bool(success) and finite
        if (not accepted) and finite:
            residual_norm = self._residual_norm(residual_arr)
            if residual_norm <= self.RESIDUAL_ACCEPT_TOL:
                logger.debug(
                    "Accepting near-root from %s at v_focus %.6f with residual norm %.3e.",
                    method,
                    float(self.v_focus),
                    residual_norm,
                )
                accepted = True
        return x_arr, residual_arr, accepted

    def _solve_with_scipy(self, initial: np.ndarray) -> Tuple[np.ndarray, np.ndarray, bool]:
        """Try SciPy ``hybr``, then ``lm``, then ``least_squares(trf)``, chaining iterates.

        Returns the first accepted result. Otherwise returns the finite
        candidate with the smallest residual norm (the initial guess
        included) with ``success=False``.
        """
        best_x = np.asarray(initial, dtype=float).ravel()
        best_residual = np.asarray(self.residual_numpy(best_x), dtype=float).ravel()
        best_norm = self._residual_norm(best_residual)

        def remember(x: np.ndarray, residual: np.ndarray) -> None:
            nonlocal best_x, best_residual, best_norm
            if not (np.isfinite(x).all() and np.isfinite(residual).all()):
                return
            norm = self._residual_norm(residual)
            if norm < best_norm:
                best_x = np.asarray(x, dtype=float).ravel()
                best_residual = np.asarray(residual, dtype=float).ravel()
                best_norm = norm

        hybr = optimize.root(
            self.residual_numpy,
            initial,
            method="hybr",
            tol=self.root_tol,
            options={"maxfev": self.max_function_evals},
        )
        x_hybr, residual_hybr, success_hybr = self._normalize_solver_result(
            hybr.x,
            hybr.fun,
            bool(hybr.success),
            method="scipy.root(hybr)",
        )
        if success_hybr:
            return x_hybr, residual_hybr, True
        remember(x_hybr, residual_hybr)

        lm_initial = x_hybr if np.isfinite(x_hybr).all() else np.asarray(initial, dtype=float).ravel()
        lm = optimize.root(
            self.residual_numpy,
            lm_initial,
            method="lm",
            tol=self.root_tol,
            options={"maxiter": self.max_function_evals},
        )
        x_lm, residual_lm, success_lm = self._normalize_solver_result(
            lm.x,
            lm.fun,
            bool(lm.success),
            method="scipy.root(lm)",
        )
        if success_lm:
            return x_lm, residual_lm, True
        remember(x_lm, residual_lm)

        ls_initial = x_lm if np.isfinite(x_lm).all() else lm_initial
        ls = optimize.least_squares(
            self.residual_numpy,
            ls_initial,
            method="trf",
            xtol=self.root_tol,
            ftol=self.root_tol,
            gtol=self.root_tol,
            max_nfev=self.max_function_evals,
        )
        x_ls, residual_ls, success_ls = self._normalize_solver_result(
            ls.x,
            ls.fun,
            bool(ls.success),
            method="scipy.least_squares(trf)",
        )
        if success_ls:
            return x_ls, residual_ls, True
        remember(x_ls, residual_ls)

        return best_x, best_residual, False

    def _solve_with_optimistix(self, initial: np.ndarray) -> Tuple[np.ndarray, np.ndarray, bool]:
        """Solve with the cached Optimistix solver.

        Raises
        ------
        RuntimeError
            If no JAX solver is available.
        SolverConvergenceError
            If the solver reports failure.
        """
        args = self._prepare_jax_args()
        solver_entry = self._get_jax_solver()
        if solver_entry is None:
            raise RuntimeError("Optimistix solver is unavailable.")
        x0 = jnp.asarray(initial, dtype=jnp.float64)
        v_focus = jnp.asarray(float(self.v_focus), dtype=jnp.float64)
        value, status, _ = solver_entry["single"](x0, v_focus, args)
        success = bool(np.asarray(status, dtype=bool))
        if not success:
            raise SolverConvergenceError(self.v_focus, self.max_steps)
        value_np = np.asarray(value, dtype=float)
        return value_np, self.residual_numpy(value_np), success

    def _prepare_jax_args(self):
        """Build (once) the argument tuple consumed by `_jax_residual`."""
        if self._jax_args is None:
            self._jax_args = (
                jnp.asarray(self.A, dtype=jnp.float64),
                jnp.asarray(self.B, dtype=jnp.float64),
                jnp.asarray(self.C, dtype=jnp.float64),
                jnp.asarray(self.bias, dtype=jnp.float64),
                jnp.asarray(self.tau, dtype=jnp.float64),
                jnp.asarray(self.focus_vector, dtype=jnp.float64),
                jnp.asarray(self.selector_matrix, dtype=jnp.float64),
                jnp.asarray(self.group_membership, dtype=jnp.float64),
                jnp.asarray(self.residual_matrix, dtype=jnp.float64),
                float(self.VAR_EPS),
            )
        return self._jax_args

    def _get_jax_solver(self):
        """Return the solver entry for ``(dim, root_tol, max_steps)`` or ``None``.

        Entries are stored in the module-level ``JAX_SOLVER_CACHE`` so that
        instances with the same configuration share them.
        """
        if not (self.use_jax and HAS_JAX and optx is not None):
            return None
        if self.dim == 0:
            return None
        key = (self.dim, float(self.root_tol), int(self.max_steps))
        entry = JAX_SOLVER_CACHE.get(key)
        if entry is None:
            entry = _build_jax_solver_entry(self.dim, float(self.root_tol), int(self.max_steps))
            JAX_SOLVER_CACHE[key] = entry
        return entry

    def phi_numpy(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the transfer function on the full-rate state implied by `x`.

        Expected output
        ---------------
        Returns one activity value per population in the interval `[0, 1]`.
        """
        rates = self._full_rates_numpy(np.asarray(x, dtype=float))
        return self._phi_numpy(rates)

    def full_rates_numpy(self, x: np.ndarray) -> np.ndarray:
        """Expand a reduced solver vector into one rate per population.

        Expected output
        ---------------
        The returned array has length `population_count`.
        """
        return self._full_rates_numpy(np.asarray(x, dtype=float))

    def jacobian_numpy(self, x: np.ndarray) -> np.ndarray:
        """Return the Jacobian of the full mean-field residual at `x`.

        Expected output
        ---------------
        The returned matrix has shape `(population_count, population_count)`.
        """
        rates = self._full_rates_numpy(np.asarray(x, dtype=float))
        return (self._phi_jacobian_numpy(rates) - np.eye(self.population_count)) / self.tau[:, np.newaxis]

    def focus_output(self, rates: np.ndarray) -> float:
        """Average the rates over the configured focus populations."""
        values = np.asarray(rates, dtype=float)[self.focus_indices]
        return float(values.mean()) if values.size else float(rates[0])

    def _reduce_full_rates(self, full_rates: np.ndarray) -> np.ndarray:
        """Project full rates onto the reduced vector (mean rate of each free group)."""
        if self.dim == 0:
            return np.zeros((0,), dtype=float)
        full = np.asarray(full_rates, dtype=float).reshape((self.population_count,))
        group_sums = self.group_membership.T @ full
        group_means = group_sums * self.group_inverse_sizes
        return group_means[self.solve_groups]

    def _coerce_initial(self, initial_guess: Optional[np.ndarray]) -> np.ndarray:
        """Convert an initial guess of any supported length to a reduced vector.

        ``None`` gives 0.1 everywhere. Supported lengths are ``dim``,
        ``group_count``, ``population_count``, and the number of non-focus
        populations.
        """
        if self.dim == 0:
            return np.zeros((0,), dtype=float)
        if initial_guess is None:
            return np.full((self.dim,), 0.1, dtype=float)
        arr = np.asarray(initial_guess, dtype=float).ravel()
        if arr.size == self.dim:
            return arr
        if arr.size == self.group_count:
            copy = arr.copy()
            copy[self.focus_group_index] = self.v_focus
            return copy[self.solve_groups]
        if arr.size == self.population_count:
            return self._reduce_full_rates(arr)
        non_focus_count = self.population_count - len(self.focus_indices)
        if arr.size == non_focus_count:
            full_rates = np.empty(self.population_count, dtype=float)
            full_rates[self.focus_population_mask] = self.v_focus
            full_rates[~self.focus_population_mask] = arr
            return self._reduce_full_rates(full_rates)
        raise ValueError(
            f"Initial guess must have length {self.dim}, {self.group_count}, {self.population_count}, or {non_focus_count}, got {arr.size}."
        )

    def solve_sequence(
        self,
        v_focus_values: Sequence[float],
        initial_guess: Optional[np.ndarray] = None,
    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        """Solve a sequence of focus inputs with the accelerated JAX path.

        Expected output
        ---------------
        Returns `(solutions, success_flags)` when the JAX solver is available,
        otherwise `None`. Trivial cases (no inputs, or `dim == 0`) return
        arrays without consulting JAX. If the JAX sweep raises, JAX is
        disabled on this instance and `None` is returned.
        """
        values = np.asarray(list(v_focus_values), dtype=float)
        if values.size == 0:
            return np.zeros((0, self.dim)), np.zeros((0,), dtype=bool)
        if self.dim == 0:
            zeros = np.zeros((values.size, 0), dtype=float)
            return zeros, np.ones((values.size,), dtype=bool)
        if not (self.use_jax and HAS_JAX and optx is not None):
            return None
        solver_entry = self._get_jax_solver()
        if solver_entry is None:
            return None
        args = self._prepare_jax_args()
        initial = self._coerce_initial(initial_guess)
        x0 = jnp.asarray(initial, dtype=jnp.float64)
        v_seq = jnp.asarray(values, dtype=jnp.float64)
        try:
            solutions, statuses, _ = solver_entry["scan"](x0, v_seq, args)
        except (ValueError, RuntimeError, TypeError, AttributeError) as exc:
            logger.warning(
                "Optimistix sweep failed with %s: %s. Falling back to sequential solver.",
                type(exc).__name__,
                str(exc),
            )
            self.use_jax = False
            return None
        solution_np = np.asarray(solutions, dtype=float)
        success = np.asarray(statuses, dtype=bool)
        return solution_np, success

    # --- class helpers --------------------------------------------------
    @classmethod
    def generate_erf_curve(
        cls,
        parameter: Dict,
        *,
        start: float = 0.02,
        end: float = 1.0,
        step_number: int = 20,
        retry_step: Optional[float] = None,
        initial_guess: Optional[np.ndarray] = None,
        fallback_initials: Optional[Sequence[float]] = None,
        **network_kwargs,
    ) -> ERFResult:
        """Compute the ERF for ``parameter``.

        Without ``retry_step``, the inputs form a fixed grid from ``start`` to
        ``end``. When a point fails, a midpoint is inserted as long as the gap
        to the last solved input exceeds ``min(|step|, ADAPTIVE_CONTINUATION_MIN_STEP)``.
        Otherwise the sweep aborts, and ``completed`` is ``False`` and no
        serialization should happen.

        When ``retry_step`` is provided, inputs whose solve fails are skipped
        and the sweep resumes ``retry_step`` further on. In this mode
        ``completed`` only records that the sweep ran past ``end``.

        Parameters
        ----------
        parameter : dict
            Model parameters forwarded to the constructor.
        start : float, optional
            Lower bound of the input range for the ERF (default: 0.02).
        end : float, optional
            Upper bound of the input range for the ERF (default: 1.0).
        step_number : int, optional
            Number of steps between ``start`` and ``end`` (default: 20).
        retry_step : float, optional
            Positive increment used to skip past non-converging inputs.
        initial_guess : array, optional
            Seed for the first solve. Later solves are seeded by the previous solution.
        fallback_initials : sequence of float, optional
            Constant seeds tried when the continuation seed fails
            (default: 0.02, 0.2, 0.5, 0.8, 0.98).
        **network_kwargs
            Forwarded to the constructor.

        Raises
        ------
        ValueError
            If ``retry_step`` is given and is not positive. A non-positive
            step would retry the same input forever.

        Expected output
        ---------------
        Returns an :class:`ERFResult` whose `x_data` stores the driven focus
        inputs and whose `y_data` stores the corresponding mean-field outputs.
        """
        if retry_step is not None and not retry_step > 0:
            raise ValueError("retry_step must be positive.")
        ERF_EPS = 1e-12  # Tolerance for floating-point comparison of v_in against end
        x_data: List[float] = []
        y_data: List[float] = []
        solves: List[np.ndarray] = []
        step = (end - start) / max(step_number, 1)
        adaptive_min_step = min(abs(step), cls.ADAPTIVE_CONTINUATION_MIN_STEP) if step != 0 else cls.ADAPTIVE_CONTINUATION_MIN_STEP
        aborted = False
        fallback_values = list(fallback_initials) if fallback_initials is not None else [0.02, 0.2, 0.5, 0.8, 0.98]

        def solve_with_fallback_initials(
            system: "RateSystem",
            initial: Optional[np.ndarray],
        ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
            # Try the continuation seed first, then each constant fallback seed;
            # return the first success or the finite result with the smallest residual.
            best_solution: Optional[np.ndarray] = None
            best_residual: Optional[np.ndarray] = None
            best_norm = float("inf")
            candidates: List[Optional[np.ndarray]] = [initial]
            for seed in fallback_values:
                if system.dim == 0:
                    candidate = np.zeros((0,), dtype=float)
                else:
                    candidate = np.full((system.dim,), float(seed), dtype=float)
                candidates.append(candidate)
            for candidate in candidates:
                try:
                    solution, residual, success = system.solve(candidate)
                except SolverConvergenceError as exc:
                    logger.debug("Skipping v_in %.6f: %s", float(system.v_focus), str(exc))
                    continue
                residual_arr = np.asarray(residual, dtype=float).ravel()
                if success:
                    return np.asarray(solution, dtype=float).ravel(), residual_arr, True
                if np.isfinite(residual_arr).all():
                    residual_norm = cls._residual_norm(residual_arr)
                    if residual_norm < best_norm:
                        best_solution = np.asarray(solution, dtype=float).ravel()
                        best_residual = residual_arr
                        best_norm = residual_norm
            return best_solution, best_residual, False

        vector_values: Optional[List[float]] = None
        if retry_step is None:
            vector_values = []
            if step == 0:
                vector_values.append(float(start))
            else:
                cursor = float(start)
                while cursor <= end + ERF_EPS:
                    vector_values.append(cursor)
                    cursor += step
        initial_focus = float(start)
        if vector_values:
            initial_focus = float(vector_values[0])
        system = cls(parameter, initial_focus, **network_kwargs)
        current_initial = initial_guess
        prefetched_solutions: Optional[np.ndarray] = None
        prefix_limit = 0
        if vector_values and system.use_jax:
            seq_result = system.solve_sequence(vector_values, initial_guess=current_initial)
            if seq_result is not None:
                prefetched_solutions, success_flags = seq_result
                # Only the leading run of successes can be reused verbatim.
                failure = np.flatnonzero(~success_flags)
                prefix_limit = int(failure[0]) if failure.size else len(vector_values)
        if vector_values is not None:
            idx = 0
            while idx < len(vector_values):
                v_in = vector_values[idx]
                system.v_focus = float(v_in)
                use_prefetched = prefetched_solutions is not None and idx < prefix_limit
                if use_prefetched:
                    solution = np.asarray(prefetched_solutions[idx], dtype=float)
                    success = True
                else:
                    solution, residual, success = solve_with_fallback_initials(system, current_initial)
                if not success:
                    prev_v = x_data[-1] if x_data else None
                    if prev_v is not None:
                        gap = float(v_in) - float(prev_v)
                        if gap > adaptive_min_step + ERF_EPS:
                            # Halve the continuation step and retry from the midpoint;
                            # prefetched indices no longer line up after the insert.
                            midpoint = float(prev_v) + 0.5 * gap
                            vector_values.insert(idx, midpoint)
                            prefetched_solutions = None
                            prefix_limit = 0
                            continue
                    aborted = True
                    break
                phi_values = system.phi_numpy(solution)
                x_data.append(system.v_focus)
                y_data.append(system.focus_output(phi_values))
                solves.append(solution)
                current_initial = solution
                idx += 1
            completed = (not aborted) and (len(x_data) == len(vector_values))
        else:
            v_in = float(start)
            next_value = v_in
            while v_in <= end + ERF_EPS:
                system.v_focus = float(v_in)
                solution, residual, success = solve_with_fallback_initials(system, current_initial)
                if not success:
                    # Skip this input; retry_step > 0 guarantees progress.
                    v_in += retry_step
                    next_value = v_in
                    if solution is not None:
                        current_initial = solution
                    continue
                phi_values = system.phi_numpy(solution)
                x_data.append(system.v_focus)
                y_data.append(system.focus_output(phi_values))
                solves.append(solution)
                current_initial = solution
                if step == 0:
                    next_value = float("inf")
                    break
                v_in = system.v_focus + step
                next_value = v_in
            completed = (not aborted) and (step == 0 or next_value > end + ERF_EPS)
        return ERFResult(x_data=x_data, y_data=y_data, solves=solves, completed=completed)

    @classmethod
    def compute_fixpoints(
        cls,
        sweep_entry: Sequence,
        *,
        tol: float = 1e-3,
        interpolation_steps: int = 10_000,
        **network_kwargs,
    ) -> Dict[float, Dict[str, Any]]:
        """
        Compute fixed points from an ERF sweep.

        A crossing whose ERF slope at the identity line exceeds 1 is marked
        unstable (1D map approximation). Otherwise, stability is decided by the
        eigenvalues of the full Jacobian at the re-solved fixed point.

        Parameters
        ----------
        sweep_entry : sequence
            Tuple ``(x_data, y_data, solves, parameter)``. The first three
            items correspond to the fields of the ERFResult returned by
            `generate_erf_curve`.
        tol : float, optional
            Tolerance for detecting crossings of the identity line (x = y) and
            for merging nearby crossings. Crossings where |x - y| <= tol are
            treated as fixed points, and crossings within tol of each other are
            merged. Default is 1e-3.
        interpolation_steps : int, optional
            Number of interpolation points used to refine the ERF before
            searching for crossings. Larger values increase accuracy but also
            cost. Default is 10_000.

        Expected output
        ---------------
        Returns a dictionary keyed by fixed-point location. Each value contains
        at least `stability`, `rates`, `residual_norm`, and `solver_success`.
        """
        # Crossings whose signed ERF slope exceeds this are flagged unstable.
        SLOPE_STABILITY_THRESHOLD = 1.0
        x_data, y_data, solves, parameter = sweep_entry
        x_interp, y_interp = interpolate_curve(x_data, y_data, steps=interpolation_steps)
        if x_interp.size == 0 or y_interp.size == 0:
            print("Skipping fixpoint analysis: empty ERF data.")
            return {}
        diff = x_interp - y_interp
        crossings = []
        prev_diff = diff[0]
        for idx in range(1, len(diff)):
            curr_diff = diff[idx]
            cross_val = None
            if np.abs(curr_diff) <= tol:
                cross_val = y_interp[idx]
            elif np.abs(prev_diff) <= tol:
                cross_val = y_interp[idx - 1]
            elif prev_diff * curr_diff < 0:
                # Linear interpolation of the sign change of x - y.
                weight = prev_diff / (prev_diff - curr_diff)
                cross_val = y_interp[idx - 1] + weight * (y_interp[idx] - y_interp[idx - 1])
            if cross_val is not None:
                if crossings and np.abs(cross_val - crossings[-1][0]) <= tol:
                    crossings[-1] = (float(cross_val), idx)
                else:
                    crossings.append((float(cross_val), idx))
            prev_diff = curr_diff
        fixpoints: Dict[float, Dict[str, Any]] = {}
        if not crossings:
            return fixpoints
        solves_array = [np.asarray(s, dtype=float) for s in solves]
        v_out_old = np.asarray(y_data, dtype=float)
        for cross_point, idx in crossings:
            print(f"Cross-Point: {cross_point}")
            if idx <= 0:
                slope = np.inf
            else:
                slope = (y_interp[idx] - y_interp[idx - 1]) / (x_interp[idx] - x_interp[idx - 1])
            entry: Dict[str, Any] = {
                "stability": "unstable",
                "rates": None,
                "residual_norm": float("inf"),
                "solver_success": False,
                "slope": float(slope) if np.isfinite(slope) else float("inf"),
                "included": False,
            }
            slope_unstable = not np.isfinite(slope) or slope > SLOPE_STABILITY_THRESHOLD
            if len(solves_array) == 0:
                entry["reason"] = "missing_erf_solution"
                fixpoints[cross_point] = entry
                continue
            # Seed the re-solve with the ERF solution whose output is closest to the crossing.
            closest_idx = int(np.argmin(np.abs(v_out_old - cross_point)))
            closest_idx = min(max(closest_idx, 0), len(solves_array) - 1)
            initial = solves_array[closest_idx]
            system = cls(parameter, cross_point, **network_kwargs)
            try:
                solve, residual, success = system.solve(initial)
            except SolverConvergenceError as exc:
                entry["solver_success"] = False
                entry["residual_norm"] = float("inf")
                entry["rates"] = None
                entry["stability"] = "unstable"
                entry["reason"] = "solver_failed"
                entry["error"] = str(exc)
                fixpoints[cross_point] = entry
                continue
            residual = np.asarray(residual, dtype=float)
            residual_norm = float(np.linalg.norm(residual)) if residual.size else 0.0
            if not np.isfinite(residual).all():
                residual_norm = float("inf")
            entry["solver_success"] = bool(success)
            entry["residual_norm"] = residual_norm
            if success and np.isfinite(residual).all():
                entry["rates"] = system.full_rates_numpy(solve)
            else:
                entry["rates"] = None
            if slope_unstable or not success or not np.isfinite(residual).all():
                if not success or not np.isfinite(residual).all():
                    print("Warning: convergence problems near cross-point")
                stability = "unstable"
            else:
                jacobian = system.jacobian_numpy(solve)
                if not np.isfinite(jacobian).all():
                    stability = "unstable"
                else:
                    try:
                        eigval = np.linalg.eigvals(jacobian)
                    except np.linalg.LinAlgError:
                        stability = "unstable"
                    else:
                        # Eigenvalues may be complex: stability requires every
                        # real part to be negative, so compare real parts explicitly.
                        stability = "stable" if np.all(np.real(eigval) < 0) else "unstable"
            entry["stability"] = stability
            if "reason" not in entry:
                entry["reason"] = "slope" if slope_unstable else None
            fixpoints[cross_point] = entry
        return fixpoints
General mean-field solver with helper utilities for ERF and fixpoints.
Subclasses implement `_build_dynamics(...)` and then inherit the fixed-point
solver, ERF sweep, and fixpoint analysis helpers.
def __init__(
    self,
    parameter: Dict,
    v_focus: float,
    *,
    focus_population: Optional[Union[int, Sequence[int]]] = None,
    prefer_jax: bool = True,
    root_tol: float = 1e-9,
    max_function_evals: int = 4000,
    max_newton_steps: Optional[int] = 1000,
    **network_kwargs,
) -> None:
    """Build the dynamics and the group reduction for one focus input.

    Parameters
    ----------
    parameter : dict
        Model parameters. A copy is passed to `_build_dynamics`. The keys
        ``focus_population`` and ``focus_count`` are consulted when the
        ``focus_population`` argument is not given.
    v_focus : float
        Rate clamped on the focus populations.
    focus_population : int or sequence, optional
        Focus population index or indices. Sequences may contain slices.
    prefer_jax : bool
        Use the Optimistix/JAX solver when it is available.
    root_tol : float
        Tolerance forwarded to the nonlinear solvers.
    max_function_evals : int
        Evaluation budget for the SciPy solvers.
    max_newton_steps : int or None
        Step budget for the JAX solver. ``None`` selects 256.
    **network_kwargs
        Forwarded to `_build_dynamics`.

    Raises
    ------
    ValueError
        If ``max_newton_steps`` is not positive, the dynamics tuple has the
        wrong length, ``A`` is not square, the matrix/vector shapes disagree,
        or a focus index is out of range.
    """
    self.parameter = dict(parameter)
    self.v_focus = float(v_focus)
    self.prefer_jax = bool(prefer_jax)
    self.root_tol = root_tol
    self.max_function_evals = max_function_evals
    default_steps = 256
    if max_newton_steps is None:
        self.max_steps = default_steps
    else:
        steps = int(max_newton_steps)
        if steps <= 0:
            raise ValueError("max_newton_steps must be positive.")
        self.max_steps = steps
    self.network_kwargs = dict(network_kwargs)
    dynamics = self._build_dynamics(self.parameter, **self.network_kwargs)
    if len(dynamics) == 4:
        A, B, bias, tau = dynamics
        C = None
    elif len(dynamics) == 5:
        A, B, C, bias, tau = dynamics
    else:
        raise ValueError("Expected _build_dynamics to return (A, B, bias, tau) or (A, B, C, bias, tau).")
    # Normalise to float ndarrays so subclasses may return lists or other array types.
    self.A = np.asarray(A, dtype=float)
    self.B = np.asarray(B, dtype=float)
    # A missing C means the variance has no term quadratic in the rates.
    self.C = np.zeros_like(self.B) if C is None else np.asarray(C, dtype=float)
    self.bias = np.asarray(bias, dtype=float)
    self.tau = np.asarray(tau, dtype=float)
    if self.A.ndim != 2 or self.A.shape[0] != self.A.shape[1]:
        raise ValueError("Connectivity matrix A must be square.")
    self.population_count = int(self.A.shape[0])
    if self.A.shape != self.B.shape or self.A.shape != self.C.shape:
        raise ValueError("Connectivity mean and variance matrices must match in shape.")
    if self.bias.shape[0] != self.population_count or self.tau.shape[0] != self.population_count:
        raise ValueError("Bias and tau vectors must match the matrix dimensions.")
    parameter_focus = self.parameter.get("focus_population")
    # Accept the same forms as the ``focus_population`` argument: a single
    # index (bools excluded) or a 1-D collection of indices/slices.
    parameter_focus_valid = (
        isinstance(parameter_focus, (list, tuple, range))
        or (isinstance(parameter_focus, np.ndarray) and parameter_focus.ndim == 1)
        or (isinstance(parameter_focus, (int, np.integer)) and not isinstance(parameter_focus, bool))
    )
    if focus_population is not None:
        focus_config = focus_population
    elif parameter_focus_valid:
        focus_config = parameter_focus
    else:
        count = int(self.parameter.get("focus_count", 1) or 1)
        focus_config = list(range(max(count, 1)))
    self.focus_indices = self._resolve_focus_indices(focus_config)
    self._initialize_group_constraints()
    self.use_jax = self.prefer_jax and HAS_JAX and optx is not None
    # Lazily built device arrays for the JAX solver.
    self._jax_args = None
def solve(self, initial_guess: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, bool]:
    """Find a fixed point of the reduced mean-field system.

    The JAX/Optimistix path is tried first when enabled. If it fails with a
    generic error, JAX is disabled for this instance and SciPy is used
    instead. A ``SolverConvergenceError`` is always re-raised.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, bool]
        `(x, residual, success)` where `x` is the reduced solution vector.

    Expected output
    ---------------
    `success` is `True` when the nonlinear solver converged and `residual`
    is close to zero.
    """
    start = self._coerce_initial(initial_guess)
    if self.dim == 0:
        # Empty reduced system: the starting point is trivially a solution.
        return start, self.residual_numpy(start), True
    if not self.use_jax:
        return self._solve_with_scipy(start)
    try:
        return self._solve_with_optimistix(start)
    except SolverConvergenceError:
        # Must be checked first so it is not caught by the generic handler below.
        raise
    except (ValueError, RuntimeError, TypeError, AttributeError) as err:
        logger.warning(
            "Optimistix solver failed with %s: %s. Falling back to scipy solver.",
            type(err).__name__,
            str(err),
        )
        self.use_jax = False
    return self._solve_with_scipy(start)
Solve for a fixed point of the reduced mean-field system.
Returns
- tuple[np.ndarray, np.ndarray, bool]:
(x, residual, success) where x is the reduced solution vector.
Expected output
success is True when the nonlinear solver converged and residual
is close to zero.
def phi_numpy(self, x: np.ndarray) -> np.ndarray:
    """Apply the transfer function to the full-rate state implied by `x`.

    `x` is converted to a float array and expanded to one rate per
    population. The transfer function is then evaluated on those rates.

    Expected output
    ---------------
    Returns one activity value per population in the interval `[0, 1]`.
    """
    state = np.asarray(x, dtype=float)
    full_rates = self._full_rates_numpy(state)
    return self._phi_numpy(full_rates)
Evaluate the transfer function on the full-rate state implied by x.
Expected output
Returns one activity value per population in the interval [0, 1].
def full_rates_numpy(self, x: np.ndarray) -> np.ndarray:
    """Map a reduced solver vector onto one rate per population.

    This is the public wrapper around ``_full_rates_numpy``; the input is
    converted to a float array first.

    Expected output
    ---------------
    The returned array has length `population_count`.
    """
    reduced = np.asarray(x, dtype=float)
    return self._full_rates_numpy(reduced)
Expand a reduced solver vector into one rate per population.
Expected output
The returned array has length population_count.
def jacobian_numpy(self, x: np.ndarray) -> np.ndarray:
    """Compute the Jacobian of the full mean-field residual at `x`.

    The result is ``(dphi - I) / tau``, where ``dphi`` is the transfer
    function Jacobian evaluated at the full rates. Row ``i`` is divided by
    ``tau[i]``.

    Expected output
    ---------------
    The returned matrix has shape `(population_count, population_count)`.
    """
    full_rates = self._full_rates_numpy(np.asarray(x, dtype=float))
    phi_jac = self._phi_jacobian_numpy(full_rates)
    identity = np.eye(self.population_count)
    # np.newaxis is None: turn tau into a column so each row gets its own time constant.
    return (phi_jac - identity) / self.tau[:, None]
Return the Jacobian of the full mean-field residual at x.
Expected output
The returned matrix has shape (population_count, population_count).
def focus_output(self, rates: np.ndarray) -> float:
    """Return the mean of `rates` over ``self.focus_indices``.

    When the focus selection is empty, the first entry of `rates` is
    returned instead.
    """
    selected = np.asarray(rates, dtype=float)[self.focus_indices]
    if selected.size == 0:
        return float(rates[0])
    return float(selected.mean())
Average the rates over the configured focus populations.
def solve_sequence(
    self,
    v_focus_values: Sequence[float],
    initial_guess: Optional[np.ndarray] = None,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Solve for many focus inputs in one pass on the accelerated JAX path.

    Empty inputs and empty reduced systems are answered directly. Otherwise
    the cached JAX scan solver is used. If that solver raises a generic
    error, JAX is disabled for this instance and ``None`` is returned, so
    the caller can fall back to solving points one by one.

    Expected output
    ---------------
    Returns `(solutions, success_flags)` when the JAX solver is available,
    otherwise `None`.
    """
    inputs = np.asarray(list(v_focus_values), dtype=float)
    count = inputs.size
    if count == 0:
        return np.zeros((0, self.dim)), np.zeros((0,), dtype=bool)
    if self.dim == 0:
        return np.zeros((count, 0), dtype=float), np.ones((count,), dtype=bool)
    jax_ready = self.use_jax and HAS_JAX and optx is not None
    if not jax_ready:
        return None
    entry = self._get_jax_solver()
    if entry is None:
        return None
    jax_args = self._prepare_jax_args()
    start = jnp.asarray(self._coerce_initial(initial_guess), dtype=jnp.float64)
    drive = jnp.asarray(inputs, dtype=jnp.float64)
    try:
        solutions, statuses, _ = entry["scan"](start, drive, jax_args)
    except (ValueError, RuntimeError, TypeError, AttributeError) as err:
        logger.warning(
            "Optimistix sweep failed with %s: %s. Falling back to sequential solver.",
            type(err).__name__,
            str(err),
        )
        self.use_jax = False
        return None
    return np.asarray(solutions, dtype=float), np.asarray(statuses, dtype=bool)
Solve a sequence of focus inputs with the accelerated JAX path.
Expected output
Returns (solutions, success_flags) when the JAX solver is available,
otherwise None.
@classmethod
def generate_erf_curve(
    cls,
    parameter: Dict,
    *,
    start: float = 0.02,
    end: float = 1.0,
    step_number: int = 20,
    retry_step: Optional[float] = None,
    initial_guess: Optional[np.ndarray] = None,
    fallback_initials: Optional[Sequence[float]] = None,
    **network_kwargs,
) -> ERFResult:
    """Compute the ERF for ``parameter``.

    When ``retry_step`` is provided, the solver retries inputs separated by
    that value whenever convergence fails. If the ERF cannot be completed
    the ``completed`` flag is ``False`` and no serialization should happen.
    A sweep in which no input converged is never reported as completed.

    Parameters
    ----------
    parameter : dict
        Model parameters used to construct ``cls``.
    start : float, optional
        Lower bound of the input range for the ERF (default: 0.02).
    end : float, optional
        Upper bound of the input range for the ERF (default: 1.0).
    step_number : int, optional
        Number of steps between ``start`` and ``end`` (default: 20).
    retry_step : float, optional
        When given, an input that fails to converge is skipped by advancing
        the input by this amount. When ``None``, the inputs form a fixed grid.
        On failure, midpoints are inserted between the last accepted input and
        the failing one until the gap is no larger than
        ``min(|step|, ADAPTIVE_CONTINUATION_MIN_STEP)``; the sweep is then
        aborted.
    initial_guess : np.ndarray, optional
        Initial reduced state for the first solve.
    fallback_initials : sequence of float, optional
        Constant seeds tried after the continuation guess fails
        (default: ``[0.02, 0.2, 0.5, 0.8, 0.98]``).
    **network_kwargs
        Forwarded to ``cls``.

    Expected output
    ---------------
    Returns an :class:`ERFResult` whose `x_data` stores the driven focus
    inputs and whose `y_data` stores the corresponding mean-field outputs.
    """
    ERF_EPS = 1e-12  # Tolerance for floating-point comparison of v_in against end
    x_data: List[float] = []
    y_data: List[float] = []
    solves: List[np.ndarray] = []
    step = (end - start) / max(step_number, 1)
    # Smallest gap the adaptive bisection may still split in the grid branch.
    if step != 0:
        adaptive_min_step = min(abs(step), cls.ADAPTIVE_CONTINUATION_MIN_STEP)
    else:
        adaptive_min_step = cls.ADAPTIVE_CONTINUATION_MIN_STEP
    aborted = False
    fallback_values = list(fallback_initials) if fallback_initials is not None else [0.02, 0.2, 0.5, 0.8, 0.98]

    def solve_with_fallback_initials(
        system: "RateSystem",
        initial: Optional[np.ndarray],
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
        """Solve at ``system.v_focus``, trying ``initial`` first and then each constant seed.

        Returns the first converged ``(solution, residual, True)``. If none
        converges, returns the candidate with the smallest finite residual
        norm (or ``None`` values when there is none) together with ``False``.
        """
        best_solution: Optional[np.ndarray] = None
        best_residual: Optional[np.ndarray] = None
        best_norm = float("inf")
        candidates: List[Optional[np.ndarray]] = [initial]
        for seed in fallback_values:
            if system.dim == 0:
                candidate = np.zeros((0,), dtype=float)
            else:
                candidate = np.full((system.dim,), float(seed), dtype=float)
            candidates.append(candidate)
        for candidate in candidates:
            try:
                solution, residual, success = system.solve(candidate)
            except SolverConvergenceError as exc:
                logger.debug("Skipping v_in %.6f: %s", float(system.v_focus), str(exc))
                continue
            residual_arr = np.asarray(residual, dtype=float).ravel()
            if success:
                return np.asarray(solution, dtype=float).ravel(), residual_arr, True
            if np.isfinite(residual_arr).all():
                residual_norm = cls._residual_norm(residual_arr)
                if residual_norm < best_norm:
                    best_solution = np.asarray(solution, dtype=float).ravel()
                    best_residual = residual_arr
                    best_norm = residual_norm
        return best_solution, best_residual, False

    vector_values: Optional[List[float]] = None
    if retry_step is None:
        # Fixed grid start, start + step, ... up to end (inclusive within ERF_EPS).
        vector_values = []
        if step == 0:
            vector_values.append(float(start))
        else:
            cursor = float(start)
            while cursor <= end + ERF_EPS:
                vector_values.append(cursor)
                cursor += step
    initial_focus = float(start)
    if vector_values:
        initial_focus = float(vector_values[0])
    system = cls(parameter, initial_focus, **network_kwargs)
    current_initial = initial_guess
    prefetched_solutions: Optional[np.ndarray] = None
    prefix_limit = 0
    if vector_values and system.use_jax:
        # Batch-solve the whole grid on the JAX path; results are used only up
        # to (not including) the first input whose solve failed.
        seq_result = system.solve_sequence(vector_values, initial_guess=current_initial)
        if seq_result is not None:
            prefetched_solutions, success_flags = seq_result
            failure = np.flatnonzero(~success_flags)
            prefix_limit = int(failure[0]) if failure.size else len(vector_values)
    if vector_values is not None:
        idx = 0
        while idx < len(vector_values):
            v_in = vector_values[idx]
            system.v_focus = float(v_in)
            use_prefetched = prefetched_solutions is not None and idx < prefix_limit
            if use_prefetched:
                solution = np.asarray(prefetched_solutions[idx], dtype=float)
                success = True
            else:
                solution, _, success = solve_with_fallback_initials(system, current_initial)
            if not success:
                prev_v = x_data[-1] if x_data else None
                if prev_v is not None:
                    gap = float(v_in) - float(prev_v)
                    if gap > adaptive_min_step + ERF_EPS:
                        # Retry from the midpoint; inserting shifts the grid, so the
                        # batch results no longer line up with the indices.
                        midpoint = float(prev_v) + 0.5 * gap
                        vector_values.insert(idx, midpoint)
                        prefetched_solutions = None
                        prefix_limit = 0
                        continue
                aborted = True
                break
            phi_values = system.phi_numpy(solution)
            x_data.append(system.v_focus)
            y_data.append(system.focus_output(phi_values))
            solves.append(solution)
            current_initial = solution
            idx += 1
        # An empty sweep (e.g. end < start) is not reported as completed.
        completed = (not aborted) and bool(x_data) and len(x_data) == len(vector_values)
    else:
        v_in = float(start)
        next_value = v_in
        while v_in <= end + ERF_EPS:
            system.v_focus = float(v_in)
            solution, _, success = solve_with_fallback_initials(system, current_initial)
            if not success:
                # retry_step is always set in this branch: skip ahead and
                # warm-start from the best non-converged candidate, if any.
                v_in += retry_step
                next_value = v_in
                if solution is not None:
                    current_initial = solution
                continue
            phi_values = system.phi_numpy(solution)
            x_data.append(system.v_focus)
            y_data.append(system.focus_output(phi_values))
            solves.append(solution)
            current_initial = solution
            if step == 0:
                next_value = float("inf")
                break
            v_in = system.v_focus + step
            next_value = v_in
        # Reaching ``end`` only counts as completed if at least one input converged.
        completed = bool(x_data) and (step == 0 or next_value > end + ERF_EPS)
    return ERFResult(x_data=x_data, y_data=y_data, solves=solves, completed=completed)
Compute the ERF for parameter.
When retry_step is provided, the solver retries inputs separated by
that value whenever convergence fails. If the ERF cannot be completed
the completed flag is False and no serialization should happen.
Parameters
- start (float, optional): Lower bound of the input range for the ERF (default: 0.02).
- end (float, optional): Upper bound of the input range for the ERF (default: 1.0).
- step_number (int, optional): Number of steps between start and end (default: 20).
- ...
Expected output
Returns an ERFResult whose x_data stores the driven focus
inputs and whose y_data stores the corresponding mean-field outputs.
@classmethod
def compute_fixpoints(
    cls,
    sweep_entry: Sequence,
    *,
    tol: float = 1e-3,
    interpolation_steps: int = 10_000,
    **network_kwargs,
) -> Dict[float, Dict[str, Any]]:
    """
    Compute fixed points from an ERF sweep.
    Fixed points with slope larger than 1 at the intersection with the identity line are considered unstable in the 1D map approximation.

    Parameters
    ----------
    sweep_entry : sequence
        Tuple (x_data, y_data, solves, parameter) as returned by generate_erf_curve.
    tol : float, optional
        Tolerance for detecting crossings of the identity line (x = y) and for
        merging nearby crossings. Crossings where |x - y| <= tol are treated as
        fixed points, and crossings within tol of each other are merged. Default
        is 1e-3.
    interpolation_steps : int, optional
        Number of interpolation points used to refine the ERF before searching
        for crossings. Larger values increase accuracy but also cost. Default
        is 10_000.
    **network_kwargs
        Forwarded to ``cls`` when re-solving at each crossing.

    Expected output
    ---------------
    Returns a dictionary keyed by fixed-point location. Each value contains
    at least `stability`, `rates`, `residual_norm`, and `solver_success`.
    A crossing passing the slope test whose re-solve converges is
    classified by the real parts of the Jacobian eigenvalues: it is stable
    only if all real parts are negative.
    """
    # ERF slope > 1 at the crossing => unstable; slopes up to 1 (including
    # negative slopes) defer to the Jacobian eigenvalue test.
    SLOPE_STABILITY_THRESHOLD = 1.0
    x_data, y_data, solves, parameter = sweep_entry
    x_interp, y_interp = interpolate_curve(x_data, y_data, steps=interpolation_steps)
    if x_interp.size == 0 or y_interp.size == 0:
        print("Skipping fixpoint analysis: empty ERF data.")
        return {}
    diff = x_interp - y_interp
    crossings = []
    prev_diff = diff[0]
    for idx in range(1, len(diff)):
        curr_diff = diff[idx]
        cross_val = None
        if np.abs(curr_diff) <= tol:
            cross_val = y_interp[idx]
        elif np.abs(prev_diff) <= tol:
            cross_val = y_interp[idx - 1]
        elif prev_diff * curr_diff < 0:
            # Sign change between samples: linearly interpolate the crossing.
            weight = prev_diff / (prev_diff - curr_diff)
            cross_val = y_interp[idx - 1] + weight * (y_interp[idx] - y_interp[idx - 1])
        if cross_val is not None:
            if crossings and np.abs(cross_val - crossings[-1][0]) <= tol:
                # Merge with the previous crossing, keeping the latest sample.
                crossings[-1] = (float(cross_val), idx)
            else:
                crossings.append((float(cross_val), idx))
        prev_diff = curr_diff
    fixpoints: Dict[float, Dict[str, Any]] = {}
    if not crossings:
        return fixpoints
    solves_array = [np.asarray(s, dtype=float) for s in solves]
    v_out_old = np.asarray(y_data, dtype=float)
    for cross_point, idx in crossings:
        print(f"Cross-Point: {cross_point}")
        if idx <= 0:
            slope = np.inf
        else:
            slope = (y_interp[idx] - y_interp[idx - 1]) / (x_interp[idx] - x_interp[idx - 1])
        entry: Dict[str, Any] = {
            "stability": "unstable",
            "rates": None,
            "residual_norm": float("inf"),
            "solver_success": False,
            "slope": float(slope) if np.isfinite(slope) else float("inf"),
            "included": False,
        }
        slope_unstable = not np.isfinite(slope) or slope > SLOPE_STABILITY_THRESHOLD
        if len(solves_array) == 0:
            entry["reason"] = "missing_erf_solution"
            fixpoints[cross_point] = entry
            continue
        # Warm-start from the stored ERF solution whose output is closest to the crossing.
        closest_idx = int(np.argmin(np.abs(v_out_old - cross_point)))
        closest_idx = min(max(closest_idx, 0), len(solves_array) - 1)
        initial = solves_array[closest_idx]
        system = cls(parameter, cross_point, **network_kwargs)
        try:
            solve, residual, success = system.solve(initial)
        except SolverConvergenceError as exc:
            entry["solver_success"] = False
            entry["residual_norm"] = float("inf")
            entry["rates"] = None
            entry["stability"] = "unstable"
            entry["reason"] = "solver_failed"
            entry["error"] = str(exc)
            fixpoints[cross_point] = entry
            continue
        residual = np.asarray(residual, dtype=float)
        residual_finite = bool(np.isfinite(residual).all())
        residual_norm = float(np.linalg.norm(residual)) if residual.size else 0.0
        if not residual_finite:
            residual_norm = float("inf")
        entry["solver_success"] = bool(success)
        entry["residual_norm"] = residual_norm
        if success and residual_finite:
            entry["rates"] = system.full_rates_numpy(solve)
        else:
            entry["rates"] = None
        if slope_unstable or not success or not residual_finite:
            if not success or not residual_finite:
                print("Warning: convergence problems near cross-point")
            stability = "unstable"
        else:
            jacobian = system.jacobian_numpy(solve)
            if not np.isfinite(jacobian).all():
                stability = "unstable"
            else:
                try:
                    eigval = np.linalg.eigvals(jacobian)
                except np.linalg.LinAlgError:
                    stability = "unstable"
                else:
                    # eigvals of a non-symmetric real matrix may be complex; linear
                    # stability depends on the real parts only (complex ordering
                    # would compare lexicographically).
                    stability = "stable" if np.all(np.real(eigval) < 0) else "unstable"
        entry["stability"] = stability
        if "reason" not in entry:
            entry["reason"] = "slope" if slope_unstable else None
        fixpoints[cross_point] = entry
    return fixpoints
Compute fixed points from an ERF sweep. Fixed points with slope larger than 1 at the intersection with the identity line are considered unstable in the 1D map approximation.
Parameters
- sweep_entry (sequence): Tuple (x_data, y_data, solves, parameter) as returned by generate_erf_curve.
- tol (float, optional): Tolerance for detecting crossings of the identity line (x = y) and for merging nearby crossings. Crossings where |x - y| <= tol are treated as fixed points, and crossings within tol of each other are merged. Default is 1e-3.
- interpolation_steps (int, optional): Number of interpolation points used to refine the ERF before searching for crossings. Larger values increase accuracy but also cost. Default is 10_000.
Expected output
Returns a dictionary keyed by fixed-point location. Each value contains
at least stability, rates, residual_norm, and solver_success.
@dataclass
class ERFResult:
    """Outcome of one event-rate-function (ERF) sweep.

    Attributes
    ----------
    x_data : list of float
        Driven focus inputs at which a solution was recorded.
    y_data : list of float
        Mean-field focus output for each entry of ``x_data``.
    solves : list of np.ndarray
        Reduced solver vectors, one per recorded input.
    completed : bool
        Whether the sweep finished; incomplete sweeps should not be serialized.

    Examples
    --------
    >>> result = ERFResult(x_data=[0.1, 0.2], y_data=[0.12, 0.22], solves=[], completed=True)
    >>> result.completed
    True
    """

    x_data: List[float]  # focus inputs
    y_data: List[float]  # focus outputs, aligned with x_data
    solves: List[np.ndarray]  # reduced solutions, aligned with x_data
    completed: bool  # True only when the whole sweep succeeded
Container for an event-rate-function sweep.
Examples
>>> result = ERFResult(x_data=[0.1, 0.2], y_data=[0.12, 0.22], solves=[], completed=True)
>>> result.completed
True
class EIClusterNetwork(RateSystem):
    """Specialized mean-field system for the clustered E/I network.

    Populations ``0 .. Q-1`` are the excitatory clusters and ``Q .. 2Q-1``
    the inhibitory clusters.

    Examples
    --------
    >>> parameter = {
    ...     "Q": 2,
    ...     "N_E": 4000,
    ...     "N_I": 1000,
    ...     "V_th": 1.0,
    ...     "g": 1.0,
    ...     "p0_ee": 0.2,
    ...     "p0_ei": 0.2,
    ...     "p0_ie": 0.2,
    ...     "p0_ii": 0.2,
    ...     "R_Eplus": 1.0,
    ...     "R_j": 0.8,
    ...     "m_X": 0.1,
    ...     "tau_e": 1.0,
    ...     "tau_i": 2.0,
    ... }
    >>> system = EIClusterNetwork(parameter, v_focus=0.2, prefer_jax=False)
    >>> system.population_count
    4
    """

    def __init__(
        self,
        parameter: Dict,
        v_focus: float,
        *,
        kappa: Optional[float] = None,
        connection_type: Optional[str] = None,
        use_temporal_variance: bool = True,
        use_quadratic_variance: bool = True,
        focus_population=None,
        prefer_jax: bool = True,
        max_steps: int = 256,
    ) -> None:
        """Store the cluster settings and delegate construction to ``RateSystem``.

        Parameters
        ----------
        parameter : dict
            Network parameters; ``parameter["Q"]`` is the number of clusters.
        v_focus : float
            Input applied to the focus populations.
        kappa, connection_type : optional
            Override ``parameter["kappa"]`` / ``parameter["connection_type"]``.
        use_temporal_variance, use_quadratic_variance : bool, optional
            Variance model switches forwarded to ``_build_dynamics``.
        focus_population, prefer_jax : optional
            Forwarded to ``RateSystem``.
        max_steps : int, optional
            Step budget of the nonlinear solver, forwarded as ``max_newton_steps``.
        """
        self.Q = int(parameter["Q"])
        self._explicit_kappa = kappa
        self._explicit_connection = connection_type
        self.collapse_types = bool(parameter.get("collapse_types", True))
        super().__init__(
            parameter,
            v_focus,
            focus_population=focus_population,
            prefer_jax=prefer_jax,
            # RateSystem reads the solver budget from ``max_newton_steps``; pass it
            # under that name so it is not treated as a network keyword.
            max_newton_steps=max_steps,
            kappa=kappa,
            connection_type=connection_type,
            use_temporal_variance=use_temporal_variance,
            use_quadratic_variance=use_quadratic_variance,
        )

    def _build_dynamics(
        self, parameter: Dict, **network_kwargs
    ) -> Union[
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    ]:
        """Return ``(A, B, C, bias, tau)``, or ``(A, B, bias, tau)`` when temporal variance is off.

        ``kappa`` and ``connection_type`` come from ``network_kwargs`` when
        set, otherwise from ``parameter`` (defaults ``0.0`` and
        ``"bernoulli"``). ``tau`` holds ``tau_e`` for the first ``Q``
        populations and ``tau_i`` for the rest.
        """
        kappa_value = network_kwargs.get("kappa")
        if kappa_value is None:
            kappa_value = parameter.get("kappa", 0.0)
        conn_kind = network_kwargs.get("connection_type") or parameter.get("connection_type", "bernoulli")
        use_temporal_variance = bool(network_kwargs.get("use_temporal_variance", True))
        use_quadratic_variance = bool(network_kwargs.get("use_quadratic_variance", True))

        connectivity = self._build_connectivity(
            parameter, float(kappa_value), str(conn_kind), use_temporal_variance, use_quadratic_variance
        )
        tau = np.ones(2 * self.Q, dtype=float)
        tau[: self.Q] *= parameter["tau_e"]
        tau[self.Q :] *= parameter["tau_i"]
        if use_temporal_variance:
            A, B, C, bias = connectivity
            return A, B, C, bias, tau
        A, B, bias = connectivity
        return A, B, bias, tau

    def _build_population_groups(self, focus: np.ndarray) -> List[np.ndarray]:
        """Partition populations into groups, using population 0 when ``focus`` is empty."""
        focus_set = set(int(idx) for idx in focus.tolist())
        if not focus_set:
            focus_set = {0}
        if not self.collapse_types:
            return self._build_full_focus_groups(focus_set)
        return self._build_collapsed_groups(focus_set)

    def _build_full_focus_groups(self, focus_set: set[int]) -> List[np.ndarray]:
        """Groups: focus set, inhibitory partners of focused E clusters, then singletons.

        Every remaining excitatory and then inhibitory population forms its own
        one-element group.
        """
        groups: List[np.ndarray] = [np.array(sorted(focus_set), dtype=int)]
        excit_focus = sorted(idx for idx in focus_set if 0 <= idx < self.Q)
        paired_inhib = sorted(
            {
                idx + self.Q
                for idx in excit_focus
                if idx + self.Q < 2 * self.Q and (idx + self.Q) not in focus_set
            }
        )
        if paired_inhib:
            groups.append(np.array(paired_inhib, dtype=int))
        remaining_excit = [idx for idx in range(self.Q) if idx not in focus_set]
        for idx in remaining_excit:
            groups.append(np.array([idx], dtype=int))
        remaining_inhib = [
            idx for idx in range(self.Q, 2 * self.Q) if idx not in focus_set and idx not in paired_inhib
        ]
        for idx in remaining_inhib:
            groups.append(np.array([idx], dtype=int))
        return groups

    def _build_collapsed_groups(self, focus_set: set[int]) -> List[np.ndarray]:
        """Groups: focus set, other E, paired I, other I, paired E (empty groups omitted)."""
        groups: List[np.ndarray] = [np.array(sorted(focus_set), dtype=int)]
        paired_inhib = sorted(
            idx + self.Q
            for idx in focus_set
            if idx < self.Q and (idx + self.Q) not in focus_set
        )
        paired_excit = sorted(
            idx - self.Q
            for idx in focus_set
            if idx >= self.Q and (idx - self.Q) not in focus_set
        )
        other_excit = sorted(idx for idx in range(self.Q) if idx not in focus_set and idx not in paired_excit)
        other_inhib = sorted(
            idx for idx in range(self.Q, 2 * self.Q) if idx not in focus_set and idx not in paired_inhib
        )
        if other_excit:
            groups.append(np.array(other_excit, dtype=int))
        if paired_inhib:
            groups.append(np.array(paired_inhib, dtype=int))
        if other_inhib:
            groups.append(np.array(other_inhib, dtype=int))
        if paired_excit:
            groups.append(np.array(paired_excit, dtype=int))
        return groups

    def _build_connectivity(
        self,
        parameter: Dict,
        kappa: float,
        connection_type: str,
        use_temporal_variance: bool,
        use_quadratic_variance: bool,
    ) -> Union[
        Tuple[np.ndarray, np.ndarray, np.ndarray],
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    ]:
        """Build mean matrix ``A``, variance matrices ``B``/``C`` and the input ``bias``.

        Returns ``(A, B, C, bias)`` when ``use_temporal_variance`` is true,
        otherwise ``(A, B, bias)``. ``bias`` is the external drive minus the
        threshold of each population.
        """
        N_E = parameter["N_E"]
        N_I = parameter["N_I"]
        N = _total_population(parameter)
        V_th = parameter["V_th"]
        g = parameter["g"]
        p0_ee = parameter["p0_ee"]
        p0_ie = parameter["p0_ie"]
        p0_ei = parameter["p0_ei"]
        p0_ii = parameter["p0_ii"]
        m_X = parameter["m_X"]
        R_Eplus = parameter["R_Eplus"]
        R_j = parameter["R_j"]

        n_er = N_E / N
        n_ir = N_I / N
        n_e = N_E / self.Q
        n_i = N_I / self.Q

        theta_E = V_th
        theta_I = V_th
        V_th_vec = np.array([theta_E] * self.Q + [theta_I] * self.Q, dtype=float)

        R_Iplus = 1 + R_j * (R_Eplus - 1)

        j_EE = theta_E / math.sqrt(p0_ee * n_er)
        j_IE = theta_I / math.sqrt(p0_ie * n_er)
        j_EI = -g * j_EE * p0_ee * n_er / (p0_ei * n_ir)
        j_II = -j_IE * p0_ie * n_er / (p0_ii * n_ir)

        scale = 1.0 / math.sqrt(N)
        j_EE *= scale
        j_IE *= scale
        j_EI *= scale
        j_II *= scale

        def mix_scales(R_plus: float) -> tuple[float, float, float, float]:
            """
            Compute scaling factors for in-cluster and out-of-cluster connections.

            When Q > 1: Distributes connection strengths between in-cluster (prob_in, weight_in)
            and out-of-cluster (prob_out, weight_out) based on the clustering strength R_plus.

            When Q = 1: Only one cluster exists, so there are no out-of-cluster connections.
            In this case, prob_out and weight_out are set equal to prob_in and weight_in,
            which ensures that all connection parameters use the same scaling without
            division by zero. The structured matrix builder will handle the single-cluster
            case appropriately.
            """
            prob_in = R_plus ** (1.0 - kappa)
            weight_in = R_plus ** kappa
            if self.Q == 1:
                # With a single cluster, all connections are in-cluster
                prob_out = prob_in
                weight_out = weight_in
            else:
                prob_out = (self.Q - prob_in) / (self.Q - 1)
                weight_out = (self.Q - weight_in) / (self.Q - 1)
            return prob_in, prob_out, weight_in, weight_out

        P_scale_in_E, P_scale_out_E, J_scale_in_E, J_scale_out_E = mix_scales(R_Eplus)
        P_scale_in_I, P_scale_out_I, J_scale_in_I, J_scale_out_I = mix_scales(R_Iplus)

        P_EE = p0_ee * P_scale_in_E
        p_ee = p0_ee * P_scale_out_E
        P_IE = p0_ie * P_scale_in_I
        p_ie = p0_ie * P_scale_out_I
        P_EI = p0_ei * P_scale_in_I
        p_ei = p0_ei * P_scale_out_I
        P_II = p0_ii * P_scale_in_I
        p_ii = p0_ii * P_scale_out_I

        J_EE = j_EE * J_scale_in_E
        j_ee = j_EE * J_scale_out_E
        J_IE = j_IE * J_scale_in_I
        j_ie = j_IE * J_scale_out_I
        J_EI = j_EI * J_scale_in_I
        j_ei = j_EI * J_scale_out_I
        J_II = j_II * J_scale_in_I
        j_ii = j_II * J_scale_out_I

        EE_IN = J_EE * P_EE * n_e
        EE_OUT = j_ee * p_ee * n_e
        IE_IN = J_IE * P_IE * n_e
        IE_OUT = j_ie * p_ie * n_e
        EI_IN = J_EI * P_EI * n_i
        EI_OUT = j_ei * p_ei * n_i
        II_IN = J_II * P_II * n_i
        II_OUT = j_ii * p_ii * n_i

        mean_values = dict(EE_IN=EE_IN, EE_OUT=EE_OUT, IE_IN=IE_IN, IE_OUT=IE_OUT, EI_IN=EI_IN, EI_OUT=EI_OUT,
                           II_IN=II_IN, II_OUT=II_OUT)
        A = self._structured_matrix(mean_values)

        var_coeffs = dict(
            EE_IN=self._variance_coeffs(P_EE, J_EE, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
            EE_OUT=self._variance_coeffs(p_ee, j_ee, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
            IE_IN=self._variance_coeffs(P_IE, J_IE, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
            IE_OUT=self._variance_coeffs(p_ie, j_ie, n_e, connection_type, use_temporal_variance, use_quadratic_variance),
            EI_IN=self._variance_coeffs(P_EI, J_EI, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
            EI_OUT=self._variance_coeffs(p_ei, j_ei, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
            II_IN=self._variance_coeffs(P_II, J_II, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
            II_OUT=self._variance_coeffs(p_ii, j_ii, n_i, connection_type, use_temporal_variance, use_quadratic_variance),
        )
        b_values = {key: coeffs[0] for key, coeffs in var_coeffs.items()}
        B = self._structured_matrix(b_values)
        if use_temporal_variance:
            # C encodes the temporal m(1-m) gating term.
            c_values = {key: coeffs[1] for key, coeffs in var_coeffs.items()}
            C = self._structured_matrix(c_values)

        J_EX = math.sqrt(p0_ee * N_E)
        J_IX = 0.8 * J_EX  # external weight onto I is fixed at 0.8 times the weight onto E
        u_extE = J_EX * m_X
        u_extI = J_IX * m_X
        u_ext = np.array([u_extE] * self.Q + [u_extI] * self.Q, dtype=float)
        bias = u_ext - V_th_vec
        if use_temporal_variance:
            return A, B, C, bias
        return A, B, bias

    def _structured_matrix(self, values: Dict[str, float]) -> np.ndarray:
        """Fill a ``2Q x 2Q`` matrix from ``values`` keyed by ``"<tgt><src>_IN"`` / ``"_OUT"``.

        Indices below ``Q`` are type ``E``, the rest type ``I``; ``_IN`` is
        used when target and source share the same cluster index (mod ``Q``).
        """
        size = 2 * self.Q
        matrix = np.zeros((size, size), dtype=float)
        for target in range(size):
            tgt_type = "E" if target < self.Q else "I"
            tgt_cluster = target % self.Q
            for source in range(size):
                src_type = "E" if source < self.Q else "I"
                src_cluster = source % self.Q
                suffix = "_IN" if tgt_cluster == src_cluster else "_OUT"
                key = f"{tgt_type}{src_type}{suffix}"
                matrix[target, source] = values[key]
        return matrix

    @staticmethod
    def _variance(prob: float, weight: float, population: float, connection_type: str) -> float:
        """Return a single variance coefficient for one connection class.

        Not referenced by the other methods of this class.
        """
        conn_kind = connection_type.lower()
        if conn_kind == "poisson":
            return prob * weight ** 2 * population
        if conn_kind == "fixed-indegree":
            return prob * (1 - (1 / population)) * weight ** 2 * population
        # else: Bernoulli
        return prob * (1 - prob) * weight ** 2 * population

    @staticmethod
    def _variance_coeffs(
        prob: float,
        weight: float,
        population: float,
        connection_type: str,
        use_temporal_variance: bool,
        use_quadratic_variance: bool,
    ) -> tuple[float, float]:
        """Return the ``(b, c)`` coefficients stored in the ``B`` and ``C`` matrices.

        ``use_quadratic_variance`` only affects the Poisson case (adds the
        ``prob**2`` term to ``b``). ``use_temporal_variance`` is accepted for
        signature symmetry but not used here. Any connection type other than
        ``"poisson"`` and ``"fixed-indegree"`` is treated as Bernoulli.
        """
        conn_kind = connection_type.lower()
        if conn_kind == "poisson":
            if use_quadratic_variance:
                b = (prob + prob ** 2) * weight ** 2 * population
            else:
                b = prob * weight ** 2 * population
            c = -(prob ** 2) * weight ** 2 * population
            return b, c
        if conn_kind == "fixed-indegree":
            k_eff = (prob * population) * (1 - (1 / population))  # finite-size correction for fixed indegree
            b = k_eff * weight ** 2
            c = -k_eff * weight ** 2
            return b, c
        # else: Bernoulli
        b = prob * weight ** 2 * population
        c = -(prob ** 2) * weight ** 2 * population
        return b, c
Specialized mean-field system for the clustered E/I network.
Examples
>>> parameter = {
... "Q": 2,
... "N_E": 4000,
... "N_I": 1000,
... "V_th": 1.0,
... "g": 1.0,
... "p0_ee": 0.2,
... "p0_ei": 0.2,
... "p0_ie": 0.2,
... "p0_ii": 0.2,
... "R_Eplus": 1.0,
... "R_j": 0.8,
... "m_X": 0.1,
... "tau_e": 1.0,
... "tau_i": 2.0,
... }
>>> system = EIClusterNetwork(parameter, v_focus=0.2, prefer_jax=False)
>>> system.population_count
4
def __init__(
    self,
    parameter: Dict,
    v_focus: float,
    *,
    kappa: Optional[float] = None,
    connection_type: Optional[str] = None,
    use_temporal_variance: bool = True,
    use_quadratic_variance: bool = True,
    focus_population=None,
    prefer_jax: bool = True,
    max_steps: int = 256,
) -> None:
    """Store the cluster settings and delegate construction to ``RateSystem``.

    Parameters
    ----------
    parameter : dict
        Network parameters; ``parameter["Q"]`` is the number of clusters.
    v_focus : float
        Input applied to the focus populations.
    kappa, connection_type : optional
        Override ``parameter["kappa"]`` / ``parameter["connection_type"]``.
    use_temporal_variance, use_quadratic_variance : bool, optional
        Variance model switches forwarded to ``_build_dynamics``.
    focus_population, prefer_jax : optional
        Forwarded to ``RateSystem``.
    max_steps : int, optional
        Step budget of the nonlinear solver, forwarded as ``max_newton_steps``.
    """
    self.Q = int(parameter["Q"])
    self._explicit_kappa = kappa
    self._explicit_connection = connection_type
    self.collapse_types = bool(parameter.get("collapse_types", True))
    super().__init__(
        parameter,
        v_focus,
        focus_population=focus_population,
        prefer_jax=prefer_jax,
        # RateSystem reads the solver budget from ``max_newton_steps``; pass it
        # under that name so it is not treated as a network keyword.
        max_newton_steps=max_steps,
        kappa=kappa,
        connection_type=connection_type,
        use_temporal_variance=use_temporal_variance,
        use_quadratic_variance=use_quadratic_variance,
    )
def ensure_output_folder(parameter: Dict, *, tag: Optional[str] = None) -> str:
    """Return ``data/<Connection>/<RjXX_XX>/<tag>``, creating the folder if needed.

    The connection label is ``parameter["connection_type"]`` (default
    ``"bernoulli"``), stripped and capitalized. ``R_j`` is formatted as
    ``Rj%05.2f`` with the decimal point replaced by ``_``. When ``tag`` is
    omitted it is derived from every parameter except ``R_Eplus``, so
    parameter sets differing only in ``R_Eplus`` share one folder.

    Examples
    --------
    >>> folder = ensure_output_folder({"connection_type": "bernoulli", "R_j": 0.8, "Q": 2})
    >>> folder.startswith("data/Bernoulli/Rj00_80/")
    True
    """
    connection_label = str(parameter.get("connection_type", "bernoulli")).strip().capitalize()
    rj_value = float(parameter.get("R_j", 0.0))
    rj_label = f"Rj{rj_value:05.2f}".replace(".", "_")
    if tag is None:
        tag = sim_tag_from_cfg({key: value for key, value in parameter.items() if key != "R_Eplus"})
    folder = os.path.join("data", connection_label, rj_label, tag)
    os.makedirs(folder, exist_ok=True)
    return folder
Create and return the cache folder for a mean-field parameter set.
Examples
>>> folder = ensure_output_folder({"connection_type": "bernoulli", "R_j": 0.8, "Q": 2})
>>> folder.startswith("data/Bernoulli/Rj00_80/")
True
def serialize_erf(
    file_path: str,
    parameter: Dict,
    result: "ERFResult",
    *,
    focus_count: Optional[int] = None,
) -> Optional[str]:
    """Serialize a completed ERF sweep to a pickle file.

    The file holds a one-entry dict keyed ``"<R_Eplus>_focus<n>"``, where
    ``R_Eplus`` is formatted with ``%.12g`` and ``n`` is ``focus_count``
    (or ``parameter["focus_count"]``, defaulting to 1). The value is
    ``[x_data, y_data, solves, parameter]``. Missing parent directories are
    created; a bare filename is written to the current directory.

    Expected output
    ---------------
    Returns `file_path` on success and `None` when `result.completed` is
    `False`.
    """
    import pickle

    if not result.completed:
        return None
    r_value = float(parameter["R_Eplus"])
    focus_value = focus_count if focus_count is not None else parameter.get("focus_count", 1)
    focus_value = 1 if focus_value is None else int(focus_value)
    key = f"{r_value:.12g}_focus{focus_value}"
    payload = {key: [result.x_data, result.y_data, result.solves, parameter]}
    directory = os.path.dirname(file_path)
    if directory:
        # os.makedirs("") raises FileNotFoundError, so only create a real parent.
        os.makedirs(directory, exist_ok=True)
    with open(file_path, "wb") as file:
        pickle.dump(payload, file)
    return file_path
Serialize a completed ERF sweep to a pickle file.
Expected output
Returns file_path on success and None when result.completed is
False.
def aggregate_data(folder: str) -> str:
    """Merge individual ERF pickle files into one combined pickle.

    Every ``*.pkl`` file directly inside ``folder`` (except a previous
    ``all_data_P_Eplus.pkl``) is loaded in sorted path order. The loaded
    dicts are merged; later files win on duplicate keys.

    Raises
    ------
    FileNotFoundError
        If ``folder`` contains no input pickle files.

    Expected output
    ---------------
    Returns the path to `all_data_P_Eplus.pkl` inside `folder`.
    """
    import glob
    import pickle

    combined_name = "all_data_P_Eplus.pkl"
    # Escape the folder so names containing glob metacharacters ("[", "*", "?")
    # are matched literally instead of silently matching nothing.
    pattern = os.path.join(glob.escape(folder), "*.pkl")
    sources = sorted(path for path in glob.glob(pattern) if os.path.basename(path) != combined_name)
    if not sources:
        raise FileNotFoundError(f"No .pkl files found in {folder}")
    combined: dict = {}
    for path in sources:
        # NOTE(review): pickle.load executes arbitrary code; only use on trusted cache files.
        with open(path, "rb") as handle:
            combined.update(pickle.load(handle))
    path_sum = os.path.join(folder, combined_name)
    with open(path_sum, "wb") as handle:
        pickle.dump(combined, handle)
    return path_sum
Merge individual ERF pickle files into one combined pickle.
Expected output
Returns the path to all_data_P_Eplus.pkl inside folder.