# Source code for tangl.vm.runtime.ledger

# tangl/vm/runtime/ledger.py
"""Persistent session state for a single traversal.

The Ledger owns long-lived state that persists across player actions:
the graph, cursor position and history, return stack, and accumulated output.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Optional, Self
from uuid import UUID

from pydantic import Field, model_validator

from tangl.core import BaseFragment, BehaviorRegistry, Entity, Graph, OrderedRegistry, Selector
from tangl.type_hints import UnstructuredData
from tangl.vm.traversable import TraversableEdge, TraversableNode

from .causality import CausalityMode
from .frame import Frame, PhaseCtx, StepTrace
from ..replay import (
    CausalityTransitionRecord,
    CheckpointRecord,
    RollbackRecord,
    StepRecord,
    get_replay_engine,
)
from ..replay.contracts import ReplayDelta

if TYPE_CHECKING:
    from tangl.vm import Dependency


__all__ = [
    "Ledger",
]

# Module logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class Ledger(Entity):
    """Persistent traversal state across player actions.

    The ledger owns the state that must survive between choices:

    * the story ``graph``;
    * the cursor (``cursor_id`` plus ``cursor_history``);
    * the return stack (``call_stack_ids``);
    * the step counters;
    * the ``output_stream``.

    The output stream holds emitted fragments (read back by
    :meth:`get_journal`). It also holds the replay log that
    :meth:`rollback_to_step` consumes: deltas, :class:`StepRecord`,
    :class:`CheckpointRecord`, :class:`CausalityTransitionRecord` and
    :class:`RollbackRecord` entries.

    Each player choice runs on an ephemeral :class:`Frame` built by
    :meth:`get_frame`. :meth:`resolve_choice` then copies the frame's results
    back into the ledger.
    """

    # --- fields (declaration order matches the original model) ------------
    graph: Graph
    output_stream: OrderedRegistry = Field(default_factory=OrderedRegistry)
    # Ledger-scoped behaviors forwarded to frames and phase contexts;
    # excluded from pydantic serialization.
    local_behaviors: BehaviorRegistry = Field(
        default_factory=lambda: BehaviorRegistry(label="ledger.local.dispatch"),
        exclude=True,
    )

    cursor_id: UUID
    cursor_history: list[UUID] = Field(default_factory=list)
    # Selects the replay engine used for deltas and checkpoints.
    replay_algorithm_id: str = "diff_v1"
    # A checkpoint is saved whenever choice_steps is a multiple of this value.
    checkpoint_cadence: int = 1
    causality_mode: CausalityMode = CausalityMode.CLEAN
    causality_break_reason: str | None = None
    causality_break_step_id: str | None = None

    # Return stack stored as edge UIDs; resolved on demand by _call_stack().
    call_stack_ids: list[UUID] = Field(default_factory=list)
    last_redirect: dict | None = None
    redirect_trace: list[dict] = Field(default_factory=list)

    # -1 marks an unseeded ledger; _seed_counters() resets these to 0.
    reentrant_steps: int = -1
    cursor_steps: int = -1
    choice_steps: int = -1

    # Runtime collaborators forwarded to frames via _frame_meta(); the user
    # object and worker dispatcher are excluded from pydantic serialization.
    user: Optional[Entity] = Field(None, exclude=True)
    user_id: Optional[UUID] = None
    worker_dispatcher: Any = Field(default=None, exclude=True)

    # --- legacy input aliases ----------------------------------------------

    @model_validator(mode="before")
    @classmethod
    def _coerce_legacy_stream_aliases(cls, data: Any) -> Any:
        """Map the legacy ``records`` input key onto ``output_stream``.

        The legacy key is moved rather than copied, so pydantic never sees
        ``records`` as an unknown extra field. If ``output_stream`` is
        already present it wins. Non-dict input passes through untouched.
        """
        if not isinstance(data, dict):
            return data
        if "output_stream" not in data and "records" in data:
            payload = dict(data)  # never mutate the caller's mapping
            payload["output_stream"] = payload.pop("records")
            return payload
        return data

    # --- convenience properties --------------------------------------------

    @property
    def records(self) -> OrderedRegistry:
        """Legacy alias for :attr:`output_stream`."""
        return self.output_stream

    @records.setter
    def records(self, value: OrderedRegistry) -> None:
        self.output_stream = value

    @property
    def cursor(self) -> TraversableNode:
        """The current node, resolved from the graph via ``cursor_id``."""
        return self.graph.get(self.cursor_id)

    @cursor.setter
    def cursor(self, value: TraversableNode | None) -> None:
        # Assigning None is ignored; the cursor id is left unchanged.
        if value is not None:
            self.cursor_id = value.uid

    @property
    def turn(self) -> int:
        """Distinct position changes, ignoring self-loops."""
        from tangl.vm.traversal import count_turns

        return count_turns(self.cursor_history)

    @property
    def step(self) -> int:
        """Alias for ``cursor_steps``."""
        return self.cursor_steps

    @step.setter
    def step(self, value: int) -> None:
        self.cursor_steps = value

    # --- call stack ---------------------------------------------------------

    def _call_stack(self) -> list[TraversableEdge]:
        """Resolve call stack UIDs to edge objects (introspection only).

        :raises ValueError: If any stored id does not resolve in the graph.
        """
        call_stack: list[TraversableEdge] = []
        for edge_id in self.call_stack_ids:
            edge = self.graph.get(edge_id)
            if edge is None:
                raise ValueError(
                    f"Call stack contains unresolved edge id: {edge_id}"
                )
            call_stack.append(edge)
        return call_stack

    def push_call(self, edge: TraversableEdge) -> None:
        """Push a call edge onto the return stack.

        :raises ValueError: If ``edge.return_phase`` is ``None``.
        """
        if edge.return_phase is None:
            raise ValueError("Putting a call onto the stack requires a return phase/type")
        self.call_stack_ids.append(edge.uid)

    def pop_call(self) -> TraversableEdge:
        """Pop and return the most recent call edge.

        :raises IndexError: If the call stack is empty.
        """
        call_edge_id = self.call_stack_ids.pop()
        return self.graph.get(call_edge_id)

    # --- construction and initialization -----------------------------------

    @classmethod
    def from_graph(
        cls,
        graph: Graph,
        entry_id: UUID | None = None,
        *,
        uid: UUID | None = None,
    ) -> Self:
        """Construct and initialize a ledger at a graph entry node.

        :param graph: Graph to traverse.
        :param entry_id: Entry node id. Defaults to ``graph.initial_cursor_id``.
        :param uid: Optional explicit uid for the new ledger.
        :raises ValueError: If no valid entry id can be determined or the
            entry node is missing from the graph.
        """
        resolved_entry_id = cls._resolve_entry_id(graph, entry_id)
        payload: dict[str, Any] = {"graph": graph, "cursor_id": resolved_entry_id}
        if uid is not None:
            payload["uid"] = uid
        ledger = cls(**payload)
        ledger._seed_counters(entry_id=resolved_entry_id)
        ledger.initialize_entry()
        return ledger

    @staticmethod
    def _resolve_entry_id(graph: Graph, entry_id: UUID | None = None) -> UUID:
        """Return explicit entry id or the graph's default traversal entry.

        :raises ValueError: If neither is available, or if the graph's
            ``initial_cursor_id`` is present but not a UUID.
        """
        if entry_id is not None:
            return entry_id
        graph_entry_id = getattr(graph, "initial_cursor_id", None)
        if isinstance(graph_entry_id, UUID):
            return graph_entry_id
        if graph_entry_id is None:
            raise ValueError(
                "Entry id not provided and graph does not define initial_cursor_id",
            )
        raise ValueError("Graph initial_cursor_id must be a UUID when present")

    def _seed_counters(self, entry_id: UUID | None = None) -> None:
        """Initialize cursor and counters without dispatch/pipeline execution.

        This places the cursor on the entry node and restarts history from
        it. It zeroes all step counters and clears the call stack, the
        redirect bookkeeping and the causality state.

        :raises ValueError: If the entry node is not in the graph.
        """
        entry_id = entry_id or self.cursor_id
        entry_node = self.graph.get(entry_id)
        if entry_node is None:
            raise ValueError(f"Entry node not found: {entry_id}")
        self.cursor_id = entry_node.uid
        self.cursor_history = [entry_node.uid]
        self.reentrant_steps = 0
        self.cursor_steps = 0
        self.choice_steps = 0
        self.call_stack_ids = []
        self.last_redirect = None
        self.redirect_trace = []
        self.causality_mode = CausalityMode.CLEAN
        self.causality_break_reason = None
        self.causality_break_step_id = None

    def initialize_entry(self) -> None:
        """Finalize entry initialization and persist initial checkpoint.

        The call stack is re-read from a freshly built frame's return stack,
        and a checkpoint is then forced regardless of cadence.
        """
        frame = self.get_frame()
        self.call_stack_ids = [e.uid for e in frame.return_stack]
        self.save_snapshot(force=True)

    def initialize_ledger(self, entry_id: UUID | None = None) -> None:
        """Backward-compatible initializer for entry setup."""
        resolved_entry_id = self._resolve_entry_id(self.graph, entry_id)
        self._seed_counters(entry_id=resolved_entry_id)
        self.initialize_entry()

    def model_post_init(self, __context) -> None:
        """Initialize fresh ledgers; preserve structured ledgers as provided.

        A negative ``cursor_steps`` marks a ledger that has never been
        seeded, so counters and history are seeded from ``cursor_id``.
        Otherwise only an empty ``cursor_history`` is backfilled with the
        current cursor.
        """
        # Chain to the base hook so post-init work defined higher in the
        # hierarchy (Entity / pydantic BaseModel) is not silently skipped.
        super().model_post_init(__context)
        if self.cursor_steps < 0:
            self._seed_counters()
        elif not self.cursor_history and self.cursor_id is not None:
            self.cursor_history.append(self.cursor_id)

    # --- frame / phase context construction --------------------------------

    def get_frame(self) -> Frame:
        """Create an ephemeral frame for the next pipeline execution.

        :raises ValueError: If the stored call stack contains unresolved ids.
        """
        return Frame(
            self.graph,
            self.cursor,
            self.output_stream,
            self._call_stack(),
            ledger_local_behaviors=self.local_behaviors,
            step_base=self.cursor_steps,
            meta=self._frame_meta(),
            causality_mode=self.causality_mode,
            mark_soft_dirty_callback=self.mark_soft_dirty,
            escalate_to_hard_dirty_callback=self.escalate_to_hard_dirty,
        )

    def _frame_meta(self) -> dict[str, Any]:
        """Build the ``meta`` mapping handed to each frame.

        It always carries ``causality_mode`` (as its value) and a copy of
        ``cursor_history``. The optional collaborators are added only when
        set.
        """
        meta: dict[str, Any] = {"causality_mode": self.causality_mode.value}
        meta["cursor_history"] = list(self.cursor_history)
        if self.user is not None:
            meta["user"] = self.user
        if self.user_id is not None:
            meta["user_id"] = self.user_id
        if self.worker_dispatcher is not None:
            meta["worker_dispatcher"] = self.worker_dispatcher
        return meta

    def _local_authorities(self) -> list[BehaviorRegistry]:
        """Return ``[local_behaviors]`` if it is a non-empty registry, else ``[]``."""
        if not isinstance(self.local_behaviors, BehaviorRegistry):
            return []
        if not self.local_behaviors.members:
            return []
        return [self.local_behaviors]

    def _make_phase_ctx(self) -> PhaseCtx:
        """Build a phase context positioned at the current cursor and step."""
        return PhaseCtx(
            graph=self.graph,
            cursor_id=self.cursor_id,
            step=max(self.cursor_steps, 0),
            causality_mode=self.causality_mode,
            mark_soft_dirty_callback=self.mark_soft_dirty,
            escalate_to_hard_dirty_callback=self.escalate_to_hard_dirty,
            local_authorities=self._local_authorities(),
        )

    # --- causality tracking -------------------------------------------------

    def _record_causality_transition(
        self,
        *,
        from_mode: CausalityMode,
        to_mode: CausalityMode,
        reason: str,
        step_id: str | None = None,
    ) -> None:
        """Append a :class:`CausalityTransitionRecord` and log a warning."""
        step = max(self.cursor_steps, 0)
        self.output_stream.append(
            CausalityTransitionRecord(
                step=step,
                from_mode=from_mode.value,
                to_mode=to_mode.value,
                reason=reason,
                step_id=step_id,
                cursor_id=self.cursor_id,
            )
        )
        logger.warning(
            "Causality transition %s -> %s at step=%s cursor_id=%s reason=%s step_id=%s",
            from_mode.value,
            to_mode.value,
            step,
            self.cursor_id,
            reason,
            step_id,
        )

    def mark_soft_dirty(self, reason: str, step_id: str | None = None) -> bool:
        """Transition from CLEAN to SOFT_DIRTY and audit the change once.

        :returns: ``True`` if the mode changed, ``False`` if the ledger was
            not CLEAN.
        """
        if self.causality_mode is not CausalityMode.CLEAN:
            return False
        previous = self.causality_mode
        self.causality_mode = CausalityMode.SOFT_DIRTY
        self._record_causality_transition(
            from_mode=previous,
            to_mode=self.causality_mode,
            reason=reason,
            step_id=step_id,
        )
        return True

    def escalate_to_hard_dirty(self, reason: str, step_id: str | None = None) -> bool:
        """Escalate to HARD_DIRTY once; never downgrade within this session.

        The reason and step id of the break are also stored on the ledger.

        :returns: ``True`` if the mode changed, ``False`` if already HARD_DIRTY.
        """
        if self.causality_mode is CausalityMode.HARD_DIRTY:
            return False
        previous = self.causality_mode
        self.causality_mode = CausalityMode.HARD_DIRTY
        self.causality_break_reason = reason
        self.causality_break_step_id = step_id
        self._record_causality_transition(
            from_mode=previous,
            to_mode=self.causality_mode,
            reason=reason,
            step_id=step_id,
        )
        return True

    # --- choice resolution --------------------------------------------------

    def _record_step(self, trace: StepTrace) -> None:
        """Build and append replay records for one traced frame hop.

        If the replay engine produces a delta, the delta is appended first
        and the following :class:`StepRecord` references it by uid.
        """
        engine = get_replay_engine(self.replay_algorithm_id)
        delta = engine.build_delta(before_graph=trace.before_graph, after_graph=trace.after_graph)
        delta_id: UUID | None = None
        if delta is not None:
            self.output_stream.append(delta)
            delta_id = delta.uid

        self.output_stream.append(
            StepRecord(
                step=trace.step,
                edge_id=trace.edge_id,
                cursor_id=trace.cursor_id,
                entry_phase=trace.entry_phase.name if trace.entry_phase is not None else None,
                was_choice=trace.was_choice,
                delta_id=delta_id,
                state_hash=trace.state_hash,
                call_stack_ids=list(trace.call_stack_ids),
                algorithm_id=self.replay_algorithm_id,
            )
        )

    @staticmethod
    def _selection_destination_dependency(edge: TraversableEdge) -> Optional["Dependency"]:
        """Find an unsatisfied :class:`Dependency` leaving ``edge``.

        A dependency labelled ``"destination"`` is preferred; otherwise any
        unsatisfied dependency of the edge is returned. Returns ``None`` when
        the edge has no graph or no such dependency exists.
        """
        from tangl.vm import Dependency

        graph = getattr(edge, "graph", None)
        if graph is None:
            return None
        deps = graph.find_edges(
            Selector(
                has_kind=Dependency,
                predecessor=edge,
                label="destination",
                satisfied=False,
            )
        )
        dep = next(deps, None)
        if dep is not None:
            return dep
        deps = graph.find_edges(
            Selector(has_kind=Dependency, predecessor=edge, satisfied=False)
        )
        return next(deps, None)

    def _provision_selected_destination(self, edge: TraversableEdge) -> None:
        """Resolve and bind a destination for a choice edge lacking a successor.

        ``allow_stubs`` is passed to the resolver as ``True`` only while the
        session is HARD_DIRTY.
        """
        if edge.successor is not None:
            return
        dep = self._selection_destination_dependency(edge)
        if dep is None:
            return

        from tangl.vm import Resolver

        ctx = self._make_phase_ctx()
        resolved = Resolver.from_ctx(ctx).resolve_dependency(
            dep,
            allow_stubs=self.causality_mode is CausalityMode.HARD_DIRTY,
            _ctx=ctx,
        )
        if resolved and dep.successor is not None and edge.successor is None:
            edge.set_successor(dep.successor, _ctx=ctx)

    def _require_choice_edge(self, edge_id: UUID) -> TraversableEdge:
        """Look up a choice edge by id.

        :raises ValueError: If the edge is not found.
        """
        edge = self.graph.get(edge_id)
        if edge is None:
            raise ValueError(f"Choice edge not found: {edge_id}")
        return edge

    def _run_frame_choice(
        self,
        *,
        frame: Frame,
        edge: TraversableEdge,
        choice_payload: Any = None,
    ) -> None:
        """Run ``edge`` on ``frame``, recording each traced hop when supported."""
        if hasattr(frame, "step_observer"):
            frame.step_observer = self._record_step
        frame.resolve_choice(edge, choice_payload=choice_payload)

    @staticmethod
    def _validate_frame_return_stack(frame: Frame) -> None:
        """Reject a frame whose return stack contains ``None`` entries."""
        for call_edge in frame.return_stack:
            if call_edge is None:
                raise ValueError("Frame return stack contains a null edge")

    def _sync_reentrant_steps(self, *, frame: Frame) -> None:
        """Count consecutive repeats in the frame's cursor trace.

        The first trace entry is compared with the last committed history
        entry, so this must run before ``cursor_history`` is extended.
        """
        prev_id = self.cursor_history[-1] if self.cursor_history else None
        for node_id in frame.cursor_trace:
            if prev_id is not None and node_id == prev_id:
                self.reentrant_steps += 1
            prev_id = node_id

    def _sync_from_frame(self, *, frame: Frame) -> None:
        """Copy cursor, counters, call stack and redirects back from ``frame``."""
        self.choice_steps += 1
        self.cursor_steps += frame.cursor_steps
        # Must precede the history extend below (it reads cursor_history[-1]).
        self._sync_reentrant_steps(frame=frame)
        self.cursor_id = frame.cursor.uid
        self.cursor_history.extend(frame.cursor_trace)
        self.call_stack_ids = [edge.uid for edge in frame.return_stack]
        self.last_redirect = frame.last_redirect
        self.redirect_trace = list(frame.redirect_trace)

    def _commit_frame_choice(self, *, frame: Frame) -> None:
        """Validate, sync and checkpoint (per cadence) after a frame runs."""
        self._validate_frame_return_stack(frame)
        self._sync_from_frame(frame=frame)
        self.save_snapshot(cadence=self.checkpoint_cadence)

    def resolve_choice(self, edge_id: UUID, *, choice_payload: Any = None) -> None:
        """Resolve a player choice and sync frame results into ledger state.

        :raises ValueError: If ``edge_id`` is unknown.
        """
        edge = self._require_choice_edge(edge_id)
        self._provision_selected_destination(edge)
        frame = self.get_frame()
        self._run_frame_choice(frame=frame, edge=edge, choice_payload=choice_payload)
        self._commit_frame_choice(frame=frame)

    # --- journal ------------------------------------------------------------

    @staticmethod
    def _coerce_fragment_record(record: Any) -> BaseFragment | None:
        """Normalize mixed fragment record shapes into the canonical fragment base.

        Returns ``record`` unchanged if it is already a :class:`BaseFragment`.
        Returns ``None`` if it has no ``fragment_type``. Otherwise a new
        fragment is built from the known attributes it carries.
        """
        if isinstance(record, BaseFragment):
            return record
        fragment_type = getattr(record, "fragment_type", None)
        if fragment_type is None:
            return None
        raw_step = getattr(record, "step", -1)
        step = -1 if raw_step is None else int(raw_step)
        payload: dict[str, Any] = {
            "fragment_type": str(fragment_type),
            "step": step,
        }
        for key in ("content", "text", "source_id", "edge_id", "available", "unavailable_reason"):
            if hasattr(record, key):
                payload[key] = getattr(record, key)
        return BaseFragment(**payload)

    def get_journal(self, *, since_step: int = 0, limit: int = 0) -> list[BaseFragment]:
        """Return output fragments in output-stream order, optionally filtered.

        :param since_step: Keep fragments with ``step >= since_step``.
            Fragments with a negative (unset) step are always kept.
        :param limit: If positive, keep only the last ``limit`` fragments.
        """
        fragments: list[BaseFragment] = []
        for record in self.output_stream.values():
            fragment = self._coerce_fragment_record(record)
            if fragment is None:
                continue
            raw_step = getattr(fragment, "step", -1)
            step = -1 if raw_step is None else int(raw_step)
            if step >= since_step or step < 0:
                fragments.append(fragment)
        if limit > 0 and len(fragments) > limit:
            fragments = fragments[-limit:]
        return fragments

    # --- persistence ----------------------------------------------------------

    def unstructure(self) -> UnstructuredData:
        """Serialize ledger state to plain data for persistence.

        Mutable containers are copied, so changes to the returned payload do
        not leak back into the live ledger. ``user`` and ``local_behaviors``
        are not included.
        """
        return {
            "uid": self.uid,
            "label": self.label,
            "cursor_id": self.cursor_id,
            "cursor_history": list(self.cursor_history),
            "cursor_steps": self.cursor_steps,
            "choice_steps": self.choice_steps,
            "reentrant_steps": self.reentrant_steps,
            "call_stack_ids": list(self.call_stack_ids),
            "last_redirect": dict(self.last_redirect) if self.last_redirect is not None else None,
            "redirect_trace": list(self.redirect_trace),
            "causality_mode": self.causality_mode.value,
            "causality_break_reason": self.causality_break_reason,
            "causality_break_step_id": self.causality_break_step_id,
            "user_id": str(self.user_id) if self.user_id is not None else None,
            "replay_algorithm_id": self.replay_algorithm_id,
            "checkpoint_cadence": self.checkpoint_cadence,
            "graph": self.graph.unstructure(),
            "output_stream": self.output_stream.unstructure(),
        }

    @classmethod
    def structure(cls, data: UnstructuredData) -> Self:
        """Reconstruct a ledger from serialized data.

        UUID fields are accepted as :class:`UUID` or string. String ``kind``
        values anywhere in the graph/stream payloads are dereferenced to
        classes via :meth:`Entity.dereference_cls_name`; the string is kept
        when no class matches.
        """

        def _coerce_uuid(value: UUID | str) -> UUID:
            if isinstance(value, UUID):
                return value
            return UUID(str(value))

        def _coerce_kind_refs(value: Any) -> Any:
            # Recursively rewrite "kind": "<ClassName>" entries into classes.
            if isinstance(value, dict):
                normalized: dict[str, Any] = {}
                for key, item in value.items():
                    if key == "kind" and isinstance(item, str):
                        normalized[key] = Entity.dereference_cls_name(item) or item
                    else:
                        normalized[key] = _coerce_kind_refs(item)
                return normalized
            if isinstance(value, list):
                return [_coerce_kind_refs(item) for item in value]
            return value

        graph = Graph.structure(_coerce_kind_refs(data["graph"]))
        output_stream = OrderedRegistry.structure(_coerce_kind_refs(data.get("output_stream", {})))
        return cls(
            uid=_coerce_uuid(data["uid"]),
            label=data.get("label", ""),
            graph=graph,
            output_stream=output_stream,
            cursor_id=_coerce_uuid(data["cursor_id"]),
            cursor_history=[_coerce_uuid(uid) for uid in data.get("cursor_history", [])],
            cursor_steps=data.get("cursor_steps", -1),
            choice_steps=data.get("choice_steps", -1),
            reentrant_steps=data.get("reentrant_steps", -1),
            call_stack_ids=[_coerce_uuid(uid) for uid in data.get("call_stack_ids", [])],
            last_redirect=data.get("last_redirect"),
            redirect_trace=list(data.get("redirect_trace", [])),
            causality_mode=CausalityMode(data.get("causality_mode", CausalityMode.CLEAN.value)),
            causality_break_reason=data.get("causality_break_reason"),
            causality_break_step_id=data.get("causality_break_step_id"),
            user_id=_coerce_uuid(data["user_id"]) if data.get("user_id") else None,
            replay_algorithm_id=data.get("replay_algorithm_id", "diff_v1"),
            checkpoint_cadence=data.get("checkpoint_cadence", 1),
        )

    # --- checkpoints and replay ----------------------------------------------

    def save_snapshot(self, *, force: bool = False, cadence: int = 0) -> Optional[CheckpointRecord]:
        """Save a checkpoint if forced or cadence says one is due.

        :param force: Save regardless of cadence.
        :param cadence: Override for ``checkpoint_cadence``. Values ``<= 0``
            fall back to the field.
        :returns: The appended checkpoint, or ``None`` if none was due.
        """
        cadence = cadence if cadence > 0 else self.checkpoint_cadence
        should_save = force or (
            cadence > 0
            and self.choice_steps >= 0
            and (self.choice_steps % cadence) == 0
        )
        if not should_save:
            return None
        engine = get_replay_engine(self.replay_algorithm_id)
        checkpoint = engine.make_checkpoint(
            graph=self.graph,
            step=self.cursor_steps,
            cursor_id=self.cursor_id,
            # Copy: push_call/pop_call mutate call_stack_ids in place, which
            # must not alter an already-saved checkpoint.
            call_stack_ids=list(self.call_stack_ids),
        )
        self.output_stream.append(checkpoint)
        return checkpoint

    def push_snapshot(self) -> Optional[CheckpointRecord]:
        """Legacy alias for forcing a checkpoint save."""
        return self.save_snapshot(force=True)

    def _ordered_records(self) -> list[Entity]:
        """Return all stream records in append order."""
        # Preserve actual stream append order. HasOrder.seq is class-local and
        # cannot be used to globally sort mixed record kinds.
        return list(self.output_stream.values())

    def _step_records(self, *, upto_step: int | None = None) -> list[StepRecord]:
        """Step records for the active replay algorithm, sorted by (step, seq)."""
        selector = Selector(has_kind=StepRecord)
        records = [
            record
            for record in selector.filter(self.output_stream)
            if record.algorithm_id == self.replay_algorithm_id
        ]
        if upto_step is not None:
            records = [record for record in records if record.step <= upto_step]
        return sorted(records, key=lambda record: (record.step, record.seq))

    def _checkpoint_records(self) -> list[CheckpointRecord]:
        """Checkpoints for the active replay algorithm, sorted by (step, seq)."""
        selector = Selector(has_kind=CheckpointRecord)
        records = [
            record
            for record in selector.filter(self.output_stream)
            if record.algorithm_id == self.replay_algorithm_id
        ]
        return sorted(records, key=lambda record: (record.step, record.seq))

    def rollback_to_step(self, target_step: int, *, reason: str | None = None) -> None:
        """Restore ledger state to ``target_step`` with destructive truncation.

        Steps performed:

        1. The graph is restored from the latest checkpoint at or before
           ``target_step`` and rolled forward by replaying recorded deltas.
        2. Cursor, call stack, history and counters are recomputed from the
           surviving :class:`StepRecord` entries.
        3. The stream is cut after the last record whose integer ``step`` is
           ``<= target_step``.
        4. A :class:`RollbackRecord` is appended to the new stream.

        :param target_step: Step to resume from. It must satisfy
            ``0 <= target_step <= cursor_steps``. Passing the current step
            is a no-op.
        :param reason: Optional reason stored on the rollback record.
        :raises ValueError: If ``target_step`` is out of range.
        :raises RuntimeError: If no checkpoint is available, or a step's
            delta is missing or not a :class:`ReplayDelta`.
        """
        if target_step < 0:
            raise ValueError("target_step must be >= 0")
        if target_step > self.cursor_steps:
            raise ValueError(
                f"target_step {target_step} must be <= current step {self.cursor_steps}"
            )
        if target_step == self.cursor_steps:
            return

        prior_step = self.cursor_steps
        engine = get_replay_engine(self.replay_algorithm_id)

        # Rebuild the graph: nearest checkpoint at/before target, then deltas.
        checkpoints = self._checkpoint_records()
        checkpoint = next(
            (record for record in reversed(checkpoints) if record.step <= target_step),
            None,
        )
        if checkpoint is None:
            raise RuntimeError("No checkpoint available for rollback")

        graph = engine.restore_checkpoint(checkpoint)
        all_active_steps = self._step_records(upto_step=target_step)
        replay_steps = [
            record
            for record in all_active_steps
            if checkpoint.step < record.step <= target_step
        ]
        for record in replay_steps:
            if record.delta_id is None:
                continue  # no delta was recorded for this step
            delta = self.output_stream.get(record.delta_id)
            if delta is None:
                raise RuntimeError(f"Missing delta for StepRecord {record.uid}")
            if not isinstance(delta, ReplayDelta):
                raise RuntimeError(f"Invalid delta type for StepRecord {record.uid}")
            graph = engine.apply_delta(graph=graph, delta=delta)

        # Cursor and call stack: the newest surviving step record overrides
        # the checkpoint's values.
        final_cursor_id = checkpoint.cursor_id
        final_call_stack_ids = list(checkpoint.call_stack_ids)
        if all_active_steps:
            final_cursor_id = all_active_steps[-1].cursor_id
            final_call_stack_ids = list(all_active_steps[-1].call_stack_ids)

        # History restarts from the original entry (preferring a step-0
        # checkpoint's cursor) followed by each surviving step's cursor.
        history_start = self.cursor_history[0] if self.cursor_history else checkpoint.cursor_id
        if checkpoints and checkpoints[0].step == 0:
            history_start = checkpoints[0].cursor_id
        history: list[UUID] = [history_start]
        history.extend(record.cursor_id for record in all_active_steps)

        reentrant_steps = 0
        for index in range(1, len(history)):
            if history[index] == history[index - 1]:
                reentrant_steps += 1

        choice_steps = sum(1 for record in all_active_steps if record.was_choice)

        # Truncate the stream after the last record stamped with an integer
        # step <= target; records without a step survive only if earlier.
        ordered_records = self._ordered_records()
        cutoff_index = -1
        for index, record in enumerate(ordered_records):
            record_step = getattr(record, "step", None)
            if isinstance(record_step, int) and record_step <= target_step:
                cutoff_index = index
        kept_records = ordered_records[: cutoff_index + 1] if cutoff_index >= 0 else []
        truncated_record_count = len(ordered_records) - len(kept_records)
        truncated_step_count = sum(1 for record in self._step_records() if record.step > target_step)

        new_stream = OrderedRegistry()
        new_stream.extend(kept_records)
        new_stream.append(
            RollbackRecord(
                resumed_step=target_step,
                prior_step=prior_step,
                truncated_record_count=truncated_record_count,
                truncated_step_count=truncated_step_count,
                reason=reason,
            )
        )

        self.output_stream = new_stream
        self.graph = graph
        self.cursor_id = final_cursor_id
        self.call_stack_ids = final_call_stack_ids
        self.cursor_steps = target_step
        self.choice_steps = choice_steps
        self.cursor_history = history
        self.reentrant_steps = reentrant_steps
        self.last_redirect = None
        self.redirect_trace = []