Source code for agentgit.core.rollback_protocol
"""Rollback protocol for tool operations in the LangGraph agent system."""
from typing import Dict, Any, Optional, List, Callable, Mapping
from dataclasses import dataclass
# Names of the checkpoint-management tools. Invocations of these tools are
# bookkeeping rather than reversible work, so rollback skips them.
CHECKPOINT_TOOL_NAMES = {
    "cleanup_auto_checkpoints",
    "create_checkpoint",
    "delete_checkpoint",
    "get_checkpoint_info",
    "list_checkpoints",
    "rollback_to_checkpoint",
}
[docs]
@dataclass
class ToolSpec:
    """Specification for a reversible tool.

    Attributes:
        name: Identifier under which the tool is registered and looked up.
        forward: Callable that performs the operation. It receives the
            argument mapping and returns the tool's result.
        reverse: Optional callable that undoes the operation. It receives
            the original argument mapping and the result that ``forward``
            produced. ``None`` means the tool has no reverse handler.
    """

    name: str
    forward: Callable[[Mapping[str, Any]], Any]
    reverse: Optional[Callable[[Mapping[str, Any], Any], Any]] = None
[docs]
@dataclass
class ToolInvocationRecord:
    """Record of a single tool invocation.

    Attributes:
        tool_name: Name of the tool that was invoked.
        args: Arguments the tool was invoked with.
        result: Value the tool returned.
        success: Whether the invocation succeeded.
        error_message: Optional description of the failure; defaults to
            ``None``.
    """

    tool_name: str
    args: Dict[str, Any]
    result: Any
    success: bool
    error_message: Optional[str] = None
[docs]
@dataclass
class ReverseInvocationResult:
    """Result of a reverse (undo) tool operation.

    Attributes:
        tool_name: Name of the tool whose invocation was being reversed.
        reversed_successfully: Whether the reversal succeeded.
        error_message: Optional explanation when the reversal did not
            succeed; defaults to ``None``.
    """

    tool_name: str
    reversed_successfully: bool
    error_message: Optional[str] = None
[docs]
class ToolRollbackRegistry:
    """Registry for managing reversible tool operations.

    The registry keeps two pieces of state:

    * a mapping from tool name to its ``ToolSpec`` (forward/reverse handlers)
    * an ordered "track" of ``ToolInvocationRecord`` entries, one per
      recorded invocation, oldest first
    """

    def __init__(self):
        """Initialize the registry with no tools and an empty track."""
        self._tools: Dict[str, ToolSpec] = {}
        self._track: List[ToolInvocationRecord] = []

    def register_tool(self, spec: ToolSpec):
        """Register a tool with optional reverse handler.

        A spec registered under an already-used name replaces the earlier one.

        Args:
            spec: Tool specification with forward and optional reverse functions
        """
        self._tools[spec.name] = spec

    def get_tool(self, name: str) -> Optional[ToolSpec]:
        """Get a tool specification by name.

        Args:
            name: Tool name

        Returns:
            Tool specification or None if not found
        """
        return self._tools.get(name)

    def record_invocation(
        self,
        tool_name: str,
        args: Dict[str, Any],
        result: Any,
        success: bool = True,
        error_message: Optional[str] = None
    ):
        """Append a tool invocation to the end of the track.

        ``args`` is stored by reference, not copied.

        Args:
            tool_name: Name of the tool
            args: Arguments passed to the tool
            result: Result from the tool
            success: Whether the invocation succeeded
            error_message: Optional error message if failed
        """
        record = ToolInvocationRecord(
            tool_name=tool_name,
            args=args,
            result=result,
            success=success,
            error_message=error_message
        )
        self._track.append(record)

    def get_track(self) -> List[ToolInvocationRecord]:
        """Get the current tool invocation track.

        Returns:
            A shallow copy of the track; the list is new, the records in it
            are the same objects the registry holds.
        """
        return self._track.copy()

    def truncate_track(self, position: int):
        """Truncate the track to a specific position.

        Uses normal slice semantics (``track[:position]``): records at index
        ``position`` and later are dropped, and a negative ``position``
        counts from the end of the track.

        Args:
            position: Position to truncate to
        """
        self._track = self._track[:position]

    def rollback(self) -> List[ReverseInvocationResult]:
        """Rollback all recorded tool invocations.

        Records are processed newest-first. Checkpoint tools and invocations
        that did not succeed are skipped without producing a result entry.
        Every other record yields exactly one ``ReverseInvocationResult``.
        The track is emptied afterwards, whether or not every reversal
        succeeded.

        Returns:
            List of reverse invocation results
        """
        results = []
        # Undo in the opposite order to execution.
        for record in reversed(self._track):
            if record.tool_name in CHECKPOINT_TOOL_NAMES:
                continue
            if not record.success:
                # The forward call failed, so there is no completed operation
                # (and no meaningful result) to hand to a reverse handler.
                continue

            spec = self._tools.get(record.tool_name)
            if not spec or not spec.reverse:
                results.append(
                    ReverseInvocationResult(
                        tool_name=record.tool_name,
                        reversed_successfully=False,
                        error_message="No reverse handler registered"
                    )
                )
                continue

            try:
                spec.reverse(record.args, record.result)
            except Exception as e:
                # Best effort: one failing handler must not stop the
                # remaining reversals, so report it and carry on.
                results.append(
                    ReverseInvocationResult(
                        tool_name=record.tool_name,
                        reversed_successfully=False,
                        error_message=str(e)
                    )
                )
            else:
                results.append(
                    ReverseInvocationResult(
                        tool_name=record.tool_name,
                        reversed_successfully=True
                    )
                )

        # Clear track after rollback
        self._track.clear()
        return results

    def redo(self) -> List[ToolInvocationRecord]:
        """Re-execute forward handlers for recorded tools.

        The existing track is replaced by the records produced here. Each
        old record whose tool is still registered is re-run with its
        original ``args``; the outcome (success or failure) is recorded as a
        fresh invocation. Records whose tool is not registered are dropped
        and produce no new record.

        Returns:
            List of new invocation records
        """
        new_records = []
        old_track = self._track.copy()
        self._track.clear()

        for record in old_track:
            spec = self._tools.get(record.tool_name)
            if spec is None or not spec.forward:
                continue

            try:
                result = spec.forward(record.args)
            except Exception as e:
                self.record_invocation(
                    record.tool_name,
                    record.args,
                    None,
                    success=False,
                    error_message=str(e)
                )
            else:
                self.record_invocation(
                    record.tool_name,
                    record.args,
                    result,
                    success=True
                )
            # record_invocation appended the new record to the track.
            new_records.append(self._track[-1])

        return new_records