1:- module(egraph, [add_term//2, union//2, saturate//1, saturate//2,
    2                   extract//2, lookup/2]).

/** <module> E-graph implementation for term rewriting and saturation

This module implements an E-graph (Equivalence Graph) data structure, commonly used for efficient term rewriting, congruence closure, and e-matching. The E-graph state is typically threaded through operations using DCG notation.

Rewrite rules are automatically compiled into efficient DCG predicates via term expansion. See the egraph_compile module for full details, including the list of supported rule declarations.

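Main predicates (as exported by this module): add_term//2, union//2, saturate//1, saturate//2, extract//2 and lookup/2.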

*/

   27:- use_module(library(dcg/high_order)).   28:- use_module(library(ordsets)).   29:- use_module(library(rbtrees)).   30:- use_module(library(heaps)).   31
   32:- use_module(egraph/compile).   33
   34cost:attr_unify_hook(_, _) :-
   35   true.
   36const:attr_unify_hook(XConst, Y) :-
   37   (  get_attr(Y, const, YConst)
   38   -> (  XConst =:= YConst
   39      -> true
   40      ;  domain_error(XConst, YConst)
   41      )
   42   ;  var(Y)
   43   -> put_attr(Y, const, XConst)
   44   ;  true
   45   ).
%!  lookup(+Pair, +SortedPairs) is semidet.
%
%   Retrieve a value from a list of Key-Value pairs sorted by Key, using
%   standard order of terms. The scan is unrolled: it inspects four pairs
%   per step while at least four remain, then two, then one. Adapted from
%   ord_memberchk/2.
%
%   @arg Pair         A Key-Value pair: Key is the key to find (compared
%                     with compare/3 and ==/2, never unified) and Value
%                     is unified with the value stored under that key.
%   @arg SortedPairs  A list of Key-Value pairs sorted by Key.
lookup(Key-Value, [K1-V1, K2-V2, K3-V3, K4-V4|Rest]) :-
   !,
   compare(Order4, Key, K4),
   (  Order4 == (=)
   -> Value = V4
   ;  Order4 == (>)
   -> % Key sorts after the whole window: continue with the remainder.
      lookup(Key-Value, Rest)
   ;  % Key sorts before K4: it can only be one of the first three keys.
      compare(Order2, Key, K2),
      (  Order2 == (=)
      -> Value = V2
      ;  Order2 == (>)
      -> Key == K3,
         Value = V3
      ;  Key == K1,
         Value = V1
      )
   ).
lookup(Key-Value, [K1-V1, K2-V2|Rest]) :-
   !,
   compare(Order2, Key, K2),
   (  Order2 == (=)
   -> Value = V2
   ;  Order2 == (>)
   -> lookup(Key-Value, Rest)
   ;  Key == K1,
      Value = V1
   ).
lookup(Key-Value, [K1-V1]) :-
   Key == K1,
   Value = V1.
%!  add_term(+Term, -Id)// is det.
%
%   Add Term to the E-graph threaded through the DCG state and return the
%   e-class id of its root node.
%
%     - An unbound Term is stored as the node '$VAR'(Term).
%     - A dict is stored as a dict with the same tag and keys whose values
%       are the e-class ids of the original values.
%     - Any other compound is stored as a compound with the same name whose
%       arguments are the e-class ids of the original arguments. The term
%       is taken apart and rebuilt with compound_name_arguments/3 (as
%       extract_node/3 and setup/6 do) so that zero-argument compounds such
%       as foo() are handled, which =../2 cannot do.
%     - Everything else (atomic terms) is stored as is.
%
%   Sub-terms are added before the node that refers to them.
%
%   @arg Term  The term to be added.
%   @arg Id    The e-class id representing the added term.
add_term(Term, Id), var(Term) ==>
   add_node('$VAR'(Term), Id).
add_term(Term, Id), is_dict(Term) ==>
   % rework this with dict_same_keys
   {  dict_pairs(Term, Tag, Pairs),
      pairs_keys_values(Pairs, Keys, Values)
   },
   foldl(add_term, Values, Ids),
   {  pairs_keys_values(Data, Keys, Ids),
      dict_create(Node, Tag, Data)
   },
   add_node(Node, Id).
add_term(Term, Id), compound(Term) ==>
   { compound_name_arguments(Term, Name, Args) },
   foldl(add_term, Args, Ids),
   { compound_name_arguments(Node, Name, Ids) },
   add_node(Node, Id).
add_term(Term, Id) ==>
   add_node(Term, Id).
  113
  114add_node(Node-Id, In, Out) :-
  115   add_node(Node, Id, In, Out).
  116add_node(Node, Id, In, Out) :-
  117   (  lookup(Node-node(Id, _Cost), In)
  118   -> Out = In
  119   ;  ord_add_element(In, Node-node(Id, 1), Out),
  120      (  number(Node)
  121      -> put_attr(Id, const, Node)
  122      ;  true
  123      )
  124   ).
  125
  126% rules([Rule | Rules], Index, Pat-node(Id, Cost), UnifsIn, UnifsOut) -->
  127%    call(Rule, Pat, Id, Index, UnifsIn, UnifsTmp),
  128%    rules(Rules, Index, Pat-node(Id, Cost), UnifsTmp, UnifsOut).
  129% rules([], _, _, Unifs, Unifs) --> [].
  130
  131:- dynamic rules//5.  132:- non_terminal(user:egraph:rules/7).  133
  134chain_rule(M, Pat, Id, Index, Rule, Mod:Call, UnifsIn, UnifsOut) :-
  135   strip_module(M:Rule, Mod, Name),
  136   Call =.. [Name, Pat, Id, Index, UnifsIn, UnifsOut].
  137
  138:- meta_predicate compile_rules(:, -).  139
  140compile_rules(M:Rules, RulesId) :-
  141   term_hash(M:Rules, RulesId),
  142   (  clause(rules(RulesId, _, _, _, _, _, _), _)
  143   -> true
  144   ;  (  Rules = [_ | _]
  145      -> foldl(chain_rule(M, Pat, Id, Index), Rules, Calls, UnifsIn, UnifsOut),
  146         comma_list(Body, Calls)
  147      ;  Body = true, UnifsOut = UnifsIn
  148      ),
  149      Head = egraph:rules(RulesId, Index, Pat-node(Id, _), UnifsIn, UnifsOut),
  150      Clause = (Head --> Body),
  151      expand_term(Clause, Term),
  152      asserta(Term)
  153   ).
  154
  155
  156make_index(In, Index) :-
  157   index_pairs(In, UnsortedPairs),
  158   keysort(UnsortedPairs, IdPairs),
  159   group_pairs_by_key(IdPairs, Groups),
  160   ord_list_to_rbtree(Groups, Index).
  161
  162index_pairs([], []).
  163index_pairs([Node-node(Id, _Cost)|T0], [Id-Node|T1]) :-
  164   index_pairs(T0, T1).
  165
  166match([], _, _, Unifs, Unifs) --> [].
  167match([Node | Rest], Rules, Index, UnifsIn, UnifsOut) -->
  168   rules(Rules, Index, Node, UnifsIn, UnifsTmp),
  169   match(Rest, Rules, Index, UnifsTmp, UnifsOut).
%!  union(+Id1, +Id2)// is det.
%
%   Merge two e-classes: the two ids are unified and the node list held in
%   the DCG state is then normalised with merge_nodes/2.
%
%   @arg Id1  The first e-class id.
%   @arg Id2  The second e-class id.
union(Id1, Id2) -->
   { Id1 = Id2 },
   merge_nodes.
  180
  181merge_nodes(In, Out) :-
  182   sort(In, Sort),
  183   group_pairs_by_key(Sort, Groups),
  184   merge_groups(Groups, Tmp, false, Merged),
  185   (  Merged == true
  186   -> merge_nodes(Tmp, Out)
  187   ;  Out = Sort
  188   ).
  189
  190merge_groups([Sig-[H | T] | Nodes], [Sig-Node | Worklist], In, Out) :-
  191   merge_group(T, H, Node),
  192   (  T == []
  193   -> Tmp = In
  194   ;  Tmp = true
  195   ),
  196   merge_groups(Nodes, Worklist, Tmp, Out).
  197merge_groups([], [], In, In).
  198
  199merge_group([], Node, Node).
  200merge_group([node(Id, Cost) | T], node(Id, PrevCost), Out) :-
  201   MinCost is min(Cost, PrevCost),
  202   merge_group(T, node(Id, MinCost), Out).
  203
  204apply_unifs([]).
  205apply_unifs([A=A | L]) :-
  206   apply_unifs(L).
  207
  208rebuild(Matches, Unifs, Out) :-
  209   apply_unifs(Unifs),
  210   merge_nodes(Matches, Out).
  211
  212:- meta_predicate saturate(:, ?, ?).  213:- meta_predicate saturate(:, +, ?, ?).
%!  saturate(+Rules)// is det.
%
%   Apply a list of compiled rewrite rules to the E-graph until saturation
%   is reached. Same as saturate//2 with `inf` as the iteration limit.
%
%   @arg Rules  A list of compiled rewrite rule names to apply.
saturate(M:Rules, In, Out) :-
   saturate(M:Rules, inf, In, Out).
%!  saturate(+Rules, +N)// is det.
%
%   Apply a list of compiled rewrite rules to the E-graph up to N times or
%   until saturation is reached. The rule list is first turned into a
%   rules//5 key by compile_rules/2; the iteration itself is done by
%   saturate_/4.
%
%   @arg Rules  A list of compiled rewrite rule names to apply.
%   @arg N      The maximum number of iterations (or `inf` for no limit).
saturate(M:Rules, N, In, Out) :-
   compile_rules(M:Rules, RulesId),
   saturate_(RulesId, N, In, Out).
  236
  237saturate_(Rules, N, In, Out) :-
  238   (  N > 0
  239   -> make_index(In, Index),
  240      match(In, Rules, Index, Unifs, [], Matches, In),
  241      rebuild(Matches, Unifs, Tmp),
  242      length(In, Len1),
  243      length(Tmp, Len2),
  244      (  Len1 \== Len2
  245      -> (  N == inf
  246         -> N1 = N
  247         ;  N1 is N - 1
  248         ),
  249         saturate_(Rules, N1, Tmp, Out)
  250      ;  Out = Tmp
  251      )
  252   ;  Out = In
  253   ).
%!  extract(Id, Extracted)// is det.
%
%   Extract the optimal term from the E-graph based on term costs. The
%   E-graph state is left unchanged. While the costs are computed and the
%   term is rebuilt, the Prolog flag `float_overflow` is set to `infinity`;
%   its previous value is restored afterwards.
%
%   @arg Id         The e-class id to be extracted, as returned by
%                   add_term//2.
%   @arg Extracted  The extracted term.
extract(Target, Extracted, EGraph, EGraph) :-
   current_prolog_flag(float_overflow, SavedFlag),
   setup_call_cleanup(
      set_prolog_flag(float_overflow, infinity),
      extract_(Target, EGraph, Extracted),
      set_prolog_flag(float_overflow, SavedFlag)
   ).

% Compute the best cost per class, then rebuild the term rooted at Target.
extract_(Target, EGraph, Extracted) :-
   dijkstra(Target, EGraph, Costs),
   extract_class(Costs, Target, Extracted).
  271extract_class(Costs, Target, Extracted) :-
  272   rb_lookup(Target, _-Node, Costs),
  273   extract_node(Costs, Node, Extracted).
  274extract_node(_, '$VAR'(Var), R) =>
  275   R = Var.
  276extract_node(Costs, Dict, R), is_dict(Dict) =>
  277   dict_pairs(Dict, Tag, Pairs),
  278   pairs_keys_values(Pairs, Keys, Classes),
  279   pairs_keys_values(NewPairs, Keys, Values),
  280   dict_pairs(R, Tag, NewPairs),
  281   maplist(extract_class(Costs), Classes, Values).
  282extract_node(Costs, Compound, R), compound(Compound) =>
  283   compound_name_arguments(Compound, Name, Classes),
  284   same_length(Classes, Values),
  285   compound_name_arguments(R, Name, Values),
  286   maplist(extract_class(Costs), Classes, Values).
  287extract_node(_, Atomic, R) =>
  288   R = Atomic.
  289
  290dijkstra(Target, EGraph, CostsOut) :-
  291   empty_heap(HeapIn),
  292   rb_new(EmptyCosts),
  293   setup(EGraph, ParentPairs, EmptyCosts, CostsIn, HeapIn, HeapOut),
  294   keysort(ParentPairs, SortedParentPairs),
  295   group_pairs_by_key(SortedParentPairs, GroupedParentPairs),
  296   ord_list_to_rbtree(GroupedParentPairs, Parents),
  297   dijkstra(Target, Parents, HeapOut, CostsIn, CostsOut).
  298dijkstra(Target, Parents, HeapIn, CostsIn, CostsOut) :-
  299   (  get_from_heap(HeapIn, CurrentCost, Class, HeapTmp)
  300   -> (  Class == Target
  301      -> CostsOut = CostsIn
  302      ;  rb_lookup(Class, ClassCost-_, CostsIn),
  303         (  CurrentCost > ClassCost
  304         -> dijkstra(Target, Parents, HeapTmp, CostsIn, CostsOut)
  305         ;  (  rb_lookup(Class, ClassParents, Parents)
  306            -> true
  307            ;  ClassParents = []
  308            ),
  309            update_parents(ClassParents, CostsIn, CostsTmp, HeapTmp, HeapOut),
  310            dijkstra(Target, Parents, HeapOut, CostsTmp, CostsOut)
  311         )
  312      )
  313   ;  CostsOut = CostsIn
  314   ).
  315update_parents([], Costs, Costs, Heap, Heap).
  316update_parents([ParentNode-node(ParentClass, ParentCost) | Parents], CostsIn, CostsOut, HeapIn, HeapOut) :-
  317   (  is_dict(ParentNode)
  318   -> dict_pairs(ParentNode, _, KeysValues),
  319      pairs_values(KeysValues, ChildClasses)
  320   ;  compound(ParentNode), ParentNode \= '$VAR'(_)
  321   -> compound_name_arguments(ParentNode, _, ChildClasses)
  322   ;  ChildClasses = []
  323   ),
  324   compute_cost(ChildClasses, CostsIn, ParentCost, Cost),
  325   (  rb_lookup(ParentClass, CurrentCost-_, CostsIn)
  326   -> true
  327   ;  CurrentCost = inf
  328   ),
  329   (  Cost < CurrentCost
  330   -> rb_insert(CostsIn, ParentClass, Cost-ParentNode, CostsTmp),
  331      add_to_heap(HeapIn, Cost, ParentClass, HeapTmp)
  332   ;  CostsTmp = CostsIn, HeapTmp = HeapIn
  333   ),
  334   update_parents(Parents, CostsTmp, CostsOut, HeapTmp, HeapOut).
  335
  336   
  337compute_cost([], _, Cost, Cost).
  338compute_cost([Child | Childs], Costs, CostIn, CostOut) :-
  339   (  rb_lookup(Child, ChildCost-_, Costs)
  340   -> true
  341   ;  ChildCost = inf
  342   ),
  343   CostTmp is CostIn + ChildCost,
  344   compute_cost(Childs, Costs, CostTmp, CostOut).
  345
  346setup([], [], Cost, Cost, Heap, Heap).
  347setup([Node-node(ClassId, NodeCost) | Nodes], ParentsIn, CostIn, CostOut, HeapIn, HeapOut) :-
  348   (  is_dict(Node)
  349   -> dict_pairs(Node, _, KeysValues),
  350      pairs_values(KeysValues, ChildClasses)
  351   ;  compound(Node), Node \= '$VAR'(_)
  352   -> compound_name_arguments(Node, _, ChildClasses)
  353   ;  ChildClasses = []
  354   ),
  355   (  ChildClasses == []
  356   -> ParentsOut = ParentsIn,
  357      (  (rb_lookup(ClassId, CurCost-_, CostIn) ; CurCost = inf), NodeCost < CurCost
  358      -> rb_insert(CostIn, ClassId, NodeCost-Node, CostTmp),
  359         add_to_heap(HeapIn, NodeCost, ClassId, HeapTmp)
  360      ;  CostTmp = CostIn, HeapTmp = HeapIn
  361      )
  362   ;  insert_parent(ChildClasses, Node-node(ClassId, NodeCost), ParentsIn, ParentsOut),
  363      CostTmp = CostIn, HeapTmp = HeapIn
  364   ),
  365   setup(Nodes, ParentsOut, CostTmp, CostOut, HeapTmp, HeapOut).
  366
  367insert_parent([], _, Parents, Parents).
  368insert_parent([ChildClass | ChildClasses], Node, [ChildClass-Node | ParentsTmp], ParentsOut) :-
  369   insert_parent(ChildClasses, Node, ParentsTmp, ParentsOut)