1:- module(egraph, [add_term//2, union//2, saturate//1, saturate//2,
2 extract//2, lookup/2]).
27:- use_module(library(dcg/high_order)). 28:- use_module(library(ordsets)). 29:- use_module(library(rbtrees)). 30:- use_module(library(heaps)). 31
32:- use_module(egraph/compile). 33
%!  cost:attr_unify_hook(+AttrValue, +Other)
%
%   Unification hook for the `cost` attribute. It accepts every
%   unification and performs no bookkeeping.
cost:attr_unify_hook(_AttrValue, _Other).
%!  const:attr_unify_hook(+Const, +Other)
%
%   Unification hook for the `const` attribute, which add_node/4 puts on
%   the class id of a number node.
%
%     - If Other also carries a `const` attribute, the two constants must
%       be numerically equal; otherwise a domain_error is raised.
%     - If Other is bound to a non-variable term, nothing is checked.
%     - Otherwise Other is a variable without the attribute and Const is
%       copied onto it.
const:attr_unify_hook(Const, Other) :-
    (   get_attr(Other, const, OtherConst)
    ->  (   Const =:= OtherConst
        ->  true
        ;   domain_error(Const, OtherConst)
        )
    ;   nonvar(Other)
    ->  true
    ;   put_attr(Other, const, Const)
    ).
%!  lookup(+Pair, +SortedPairs) is semidet.
%
%   Pair is Key-Value. Succeeds when SortedPairs, a list of Key-Value
%   pairs ordered by key in the standard order of terms, contains an
%   entry whose key is identical (==) to Key; Value is then unified with
%   the stored value. The list is scanned four (then two) entries at a
%   time, comparing against the last key of each chunk first so whole
%   chunks can be skipped.
lookup(Key-Value, [K1-V1, K2-V2, K3-V3, K4-V4|Rest]) :-
    !,
    compare(Order, Key, K4),
    lookup_in_quad(Order, Key, Value, K1-V1, K2-V2, K3-V3, V4, Rest).
lookup(Key-Value, [K1-V1, K2-V2|Rest]) :-
    !,
    compare(Order, Key, K2),
    lookup_in_pair(Order, Key, Value, K1-V1, V2, Rest).
lookup(Key-Value, [K1-V1]) :-
    Key == K1,
    Value = V1.

% Dispatch on how Key compares with the fourth key of a chunk.
lookup_in_quad(>, Key, Value, _, _, _, _, Rest) :-
    lookup(Key-Value, Rest).
lookup_in_quad(=, _, Value, _, _, _, V4, _) :-
    Value = V4.
lookup_in_quad(<, Key, Value, First, K2-V2, Third, _, _) :-
    % Key lies within the first three entries: bisect on the second one,
    % treating the third entry as the one-element remainder.
    compare(Order, Key, K2),
    lookup_in_pair(Order, Key, Value, First, V2, [Third]).

% Dispatch on how Key compares with the second key of a two-entry chunk.
lookup_in_pair(>, Key, Value, _, _, Rest) :-
    lookup(Key-Value, Rest).
lookup_in_pair(=, _, Value, _, V2, _) :-
    Value = V2.
lookup_in_pair(<, Key, Value, K1-V1, _, _) :-
    Key == K1,
    Value = V1.
%!  add_term(+Term, -Id)//
%
%   Adds Term to the e-graph threaded through the DCG state and relates
%   it to the class Id.
%
%     - An unbound variable is stored as the leaf node '$VAR'(Var).
%     - For a dict, every value is added recursively and the stored node
%       is a dict with the same tag and keys whose values are class ids.
%     - For a compound, every argument is added recursively and the
%       stored node has the same functor with class ids as arguments.
%     - Anything else is stored as is.
add_term(Var, Class), var(Var) ==>
    add_node('$VAR'(Var), Class).
add_term(Dict, Class), is_dict(Dict) ==>
    {   dict_pairs(Dict, Tag, Pairs),
        pairs_keys_values(Pairs, Keys, Values)
    },
    foldl(add_term, Values, ValueClasses),
    {   pairs_keys_values(ClassPairs, Keys, ValueClasses),
        dict_create(Node, Tag, ClassPairs)
    },
    add_node(Node, Class).
add_term(Compound, Class), compound(Compound) ==>
    { Compound =.. [Functor|Args] },
    foldl(add_term, Args, ArgClasses),
    { Node =.. [Functor|ArgClasses] },
    add_node(Node, Class).
add_term(Atomic, Class) ==>
    add_node(Atomic, Class).
113
%!  add_node(+NodeId, +In, -Out)
%
%   Convenience form of add_node/4 taking the node and its class id as a
%   single Node-Id pair.
add_node(Node-Id, In, Out) :-
    add_node(Node, Id, In, Out).

%!  add_node(+Node, ?Id, +In, -Out)
%
%   Ensures Node is present in the ordered node list In. If Node is
%   already stored, Id is unified with its class id and Out is In.
%   Otherwise Node-node(Id, 1) is inserted (initial cost 1) and, when
%   Node is a number, Id receives the `const` attribute holding it.
add_node(Node, Id, In, Out) :-
    lookup(Node-node(Id, _Cost), In),
    !,
    Out = In.
add_node(Node, Id, In, Out) :-
    ord_add_element(In, Node-node(Id, 1), Out),
    (   number(Node)
    ->  put_attr(Id, const, Node)
    ;   true
    ).
125
130
131:- dynamic rules//5. 132:- non_terminal(user:egraph:rules/7). 133
%!  chain_rule(+DefaultModule, ?Pat, ?Id, ?Index, +Rule, -Goal, ?UnifsIn, ?UnifsOut)
%
%   Builds the module-qualified call for one rule. Rule may itself be
%   module-qualified; otherwise DefaultModule is used. The resulting goal
%   is Module:Name(Pat, Id, Index, UnifsIn, UnifsOut). Used as a foldl/6
%   step so that the unification accumulator is chained from one rule to
%   the next.
chain_rule(DefaultModule, Pat, Id, Index, Rule, Module:Goal, UnifsIn, UnifsOut) :-
    strip_module(DefaultModule:Rule, Module, RuleName),
    Goal =.. [RuleName, Pat, Id, Index, UnifsIn, UnifsOut].
137
138:- meta_predicate compile_rules(:, -). 139
%!  compile_rules(:Rules, -RulesId)
%
%   Compiles the list of rule names Rules into a single clause of
%   egraph:rules//5 and returns its key RulesId, the term_hash/2 of the
%   module-qualified list. If a clause with that key already exists,
%   nothing is compiled.
%
%   The generated clause calls every rule in order on the same pattern,
%   class id and index, chaining the unification accumulator through
%   them (see chain_rule/8). For an empty rule list the clause does
%   nothing and returns the accumulator unchanged.
compile_rules(M:Rules, RulesId) :-
    term_hash(M:Rules, RulesId),
    (   clause(rules(RulesId, _, _, _, _, _, _), _)
    ->  true
    ;   (   Rules = [_|_]
        ->  foldl(chain_rule(M, Pat, Id, Index), Rules, Calls, UnifsIn, UnifsOut),
            comma_list(Body, Calls)
        ;   % `[]` is the empty DCG body. A bare `true` would be translated
            % as a call to the non-terminal true//0, which does not exist.
            Body = [],
            UnifsOut = UnifsIn
        ),
        Head = egraph:rules(RulesId, Index, Pat-node(Id, _), UnifsIn, UnifsOut),
        expand_term((Head --> Body), Clause),
        asserta(Clause)
    ).
154
155
%!  make_index(+Nodes, -Index)
%
%   Builds a red-black tree mapping each class id to the list of nodes
%   belonging to that class. Nodes is a list of Node-node(Id, Cost)
%   entries.
make_index(Nodes, Index) :-
    index_pairs(Nodes, ByIdUnsorted),
    keysort(ByIdUnsorted, ById),
    group_pairs_by_key(ById, ClassGroups),
    ord_list_to_rbtree(ClassGroups, Index).

%!  index_pairs(+Nodes, -Pairs)
%
%   Turns every Node-node(Id, Cost) entry into the pair Id-Node, dropping
%   the cost.
index_pairs(Nodes, Pairs) :-
    maplist(node_index_pair, Nodes, Pairs).

% One entry of index_pairs/2.
node_index_pair(Node-node(Id, _Cost), Id-Node).
165
%!  match(+Nodes, +RulesId, +Index, ?UnifsIn, ?UnifsOut)//
%
%   Runs the compiled rule set RulesId (see rules//5) on every node of
%   Nodes in order. The unification accumulator is threaded from UnifsIn
%   to UnifsOut and the DCG state is threaded through each rule call.
match([], _RulesId, _Index, Unifs, Unifs) -->
    [].
match([Node|Nodes], RulesId, Index, Unifs0, Unifs) -->
    rules(RulesId, Index, Node, Unifs0, Unifs1),
    match(Nodes, RulesId, Index, Unifs1, Unifs).
%!  union(?A, ?B)//
%
%   Merges two classes by unifying their ids, then runs merge_nodes/2 on
%   the e-graph state so that entries that now coincide are merged.
union(Class, Class, In, Out) :-
    merge_nodes(In, Out).
180
%!  merge_nodes(+In, -Out)
%
%   Sorts the node list, groups entries by node and merges every group
%   into a single entry (see merge_groups/4). Merging unifies the class
%   ids of entries that share a node, which can make further nodes
%   coincide, so the pass is repeated until one pass merges nothing.
%   Out is then the sorted list of that last pass.
merge_nodes(In, Out) :-
    sort(In, Sorted),
    group_pairs_by_key(Sorted, Groups),
    merge_groups(Groups, Merged, false, Changed),
    merge_nodes_next(Changed, Merged, Sorted, Out).

% Another pass is needed only when the last one merged something.
merge_nodes_next(true, Merged, _Sorted, Out) :-
    merge_nodes(Merged, Out).
merge_nodes_next(false, _Merged, Sorted, Sorted).
189
%!  merge_groups(+Groups, -Merged, +ChangedIn, -ChangedOut)
%
%   Groups is a list of Node-Entries pairs as produced by
%   group_pairs_by_key/2. Every group is collapsed into a single
%   Node-Entry pair by merge_group/3. ChangedOut is `true` as soon as one
%   group held more than one entry; otherwise it is ChangedIn.
merge_groups([Node-[First|Others]|Groups], [Node-Entry|Merged], Changed0, Changed) :-
    merge_group(Others, First, Entry),
    merged_flag(Others, Changed0, Changed1),
    merge_groups(Groups, Merged, Changed1, Changed).
merge_groups([], [], Changed, Changed).

% A group with extra entries means something was merged.
merged_flag([], Changed, Changed).
merged_flag([_|_], _, true).
198
%!  merge_group(+Entries, +Entry0, -Entry)
%
%   Folds the node(Id, Cost) terms of Entries into Entry0: all class ids
%   are unified with each other and the smallest cost is kept.
merge_group(Entries, Entry0, Entry) :-
    foldl(merge_entry, Entries, Entry0, Entry).

% Merge one entry into the accumulator; the shared Id forces unification.
merge_entry(node(Id, Cost), node(Id, BestSoFar), node(Id, Best)) :-
    Best is min(Cost, BestSoFar).
203
%!  apply_unifs(+Unifs)
%
%   Unifs is a list of A=B terms; the two sides of every element are
%   unified, in order. Fails if any pair does not unify.
apply_unifs(Unifs) :-
    maplist(apply_unif, Unifs).

% Unify the two sides of a single A=B term.
apply_unif(Side = Side).
207
%!  rebuild(+Matches, +Unifs, -Out)
%
%   Restores the e-graph after a matching pass: first performs all
%   pending unifications in Unifs (apply_unifs/1), then merges the
%   resulting node list Matches with merge_nodes/2.
rebuild(Nodes, PendingUnifs, Rebuilt) :-
    apply_unifs(PendingUnifs),
    merge_nodes(Nodes, Rebuilt).
211
212:- meta_predicate saturate(:, ?, ?). 213:- meta_predicate saturate(:, +, ?, ?).
%!  saturate(:Rules)//
%
%   Applies Rules to the e-graph state with no iteration limit. Same as
%   saturate//2 with the limit `inf`.
saturate(M:Rules, In, Out) :-
    saturate(M:Rules, inf, In, Out).
%!  saturate(:Rules, +N)//
%
%   Compiles Rules (compile_rules/2) and applies them to the e-graph
%   state for at most N iterations; N may be `inf`.
saturate(M:Rules, N, In, Out) :-
    compile_rules(M:Rules, RulesId),
    saturate_(RulesId, N, In, Out).
236
%!  saturate_(+RulesId, +N, +In, -Out)
%
%   Iteration loop behind saturate//2. While the budget N is positive,
%   one step indexes the node list, runs the compiled rules on every
%   node and rebuilds the graph. The loop stops when a step leaves the
%   number of entries unchanged or the budget is used up.
%
%   NOTE(review): the fixed-point test compares list lengths only, so a
%   step that adds and merges the same number of entries is taken to be
%   a fixed point - confirm this is intended.
saturate_(RulesId, N, In, Out) :-
    N > 0,
    !,
    make_index(In, Index),
    % Unifs is collected as a list ending in []; the matched nodes are
    % prepended to the current node list In.
    match(In, RulesId, Index, Unifs, [], Matches, In),
    rebuild(Matches, Unifs, Next),
    length(In, SizeBefore),
    length(Next, SizeAfter),
    (   SizeBefore == SizeAfter
    ->  Out = Next
    ;   remaining_budget(N, N1),
        saturate_(RulesId, N1, Next, Out)
    ).
saturate_(_RulesId, _N, Graph, Graph).

% The unlimited budget `inf` is never decremented.
remaining_budget(N, N1) :-
    (   N == inf
    ->  N1 = N
    ;   N1 is N - 1
    ).
%!  extract(+Target, -Extracted)//
%
%   Extracts a term for the class Target from the e-graph state, which
%   is left unchanged. Node costs are computed by dijkstra/3 and the
%   term is rebuilt from the selected nodes by extract_class/3.
%
%   The float_overflow flag is set to `infinity` for the duration of the
%   computation and restored to its previous value afterwards.
extract(Target, Extracted, EGraph, EGraph) :-
    current_prolog_flag(float_overflow, Saved),
    setup_call_cleanup(
        set_prolog_flag(float_overflow, infinity),
        extract_cheapest(Target, EGraph, Extracted),
        set_prolog_flag(float_overflow, Saved)
    ).

% Cost computation followed by term reconstruction.
extract_cheapest(Target, EGraph, Extracted) :-
    dijkstra(Target, EGraph, Costs),
    extract_class(Costs, Target, Extracted).
%!  extract_class(+Costs, +Target, -Extracted)
%
%   Costs is a red-black tree mapping a class id to Cost-Node, the node
%   selected for that class. Extracted is the term obtained by taking
%   the node selected for Target and recursively extracting its child
%   classes. Fails if Target has no entry in Costs.
extract_class(Costs, Target, Extracted) :-
    rb_lookup(Target, _-Node, Costs),
    extract_node(Costs, Node, Extracted).

%!  extract_node(+Costs, +Node, -Term)
%
%   Rebuilds a term from an e-graph node:
%
%     - '$VAR'(Var) yields the variable Var itself;
%     - a dict keeps its tag and keys, each value (a class id) being
%       replaced by the term extracted for that class;
%     - a compound keeps its name, each argument (a class id) being
%       replaced by the term extracted for that class;
%     - anything else is returned unchanged.
extract_node(_, '$VAR'(Var), R) =>
    R = Var.
extract_node(Costs, Dict, R), is_dict(Dict) =>
    dict_pairs(Dict, Tag, Pairs),
    pairs_keys_values(Pairs, Keys, Classes),
    pairs_keys_values(NewPairs, Keys, Values),
    dict_pairs(R, Tag, NewPairs),
    maplist(extract_class(Costs), Classes, Values).
extract_node(Costs, Compound, R), compound(Compound) =>
    compound_name_arguments(Compound, Name, Classes),
    same_length(Classes, Values),
    compound_name_arguments(R, Name, Values),
    maplist(extract_class(Costs), Classes, Values).
extract_node(_, Atomic, R) =>
    R = Atomic.
289
%!  dijkstra(+Target, +EGraph, -Costs)
%
%   Computes, for the classes of EGraph, the cheapest known cost and the
%   node achieving it. Costs is a red-black tree mapping a class id to
%   Cost-Node. setup/6 seeds the cost table and the priority heap with
%   the leaf nodes and collects Child-Parent pairs, which are grouped
%   into a tree from class id to the entries that use that class as a
%   child.
dijkstra(Target, EGraph, Costs) :-
    rb_new(Costs0),
    empty_heap(Heap0),
    setup(EGraph, ParentPairs, Costs0, Costs1, Heap0, Heap1),
    keysort(ParentPairs, SortedPairs),
    group_pairs_by_key(SortedPairs, GroupedPairs),
    ord_list_to_rbtree(GroupedPairs, Parents),
    dijkstra(Target, Parents, Heap1, Costs1, Costs).

%!  dijkstra(+Target, +Parents, +Heap, +CostsIn, -CostsOut)
%
%   Main loop: repeatedly takes the cheapest class from Heap. The search
%   stops when Target is taken or the heap is empty.
dijkstra(Target, Parents, Heap0, Costs0, Costs) :-
    (   get_from_heap(Heap0, PoppedCost, Class, Heap1)
    ->  dijkstra_visit(Class, PoppedCost, Target, Parents, Heap1, Costs0, Costs)
    ;   Costs = Costs0
    ).

% Handle one class taken from the heap.
dijkstra_visit(Class, _PoppedCost, Target, _Parents, _Heap, Costs0, Costs) :-
    Class == Target,
    !,
    Costs = Costs0.
dijkstra_visit(Class, PoppedCost, Target, Parents, Heap1, Costs0, Costs) :-
    rb_lookup(Class, BestCost-_, Costs0),
    (   PoppedCost > BestCost
    ->  % Stale heap entry: a cheaper cost was recorded after it was queued.
        dijkstra(Target, Parents, Heap1, Costs0, Costs)
    ;   (   rb_lookup(Class, ClassParents, Parents)
        ->  true
        ;   ClassParents = []
        ),
        update_parents(ClassParents, Costs0, Costs1, Heap1, Heap2),
        dijkstra(Target, Parents, Heap2, Costs1, Costs)
    ).
%!  update_parents(+Parents, +CostsIn, -CostsOut, +HeapIn, -HeapOut)
%
%   Parents is a list of Node-node(Class, NodeCost) entries. For each
%   one, the total cost (NodeCost plus the known costs of its child
%   classes, see compute_cost/4) is computed. When it is strictly lower
%   than the cost recorded for Class (`inf` if none), the cost table is
%   updated with Total-Node and Class is queued on the heap with that
%   priority.
update_parents([], Costs, Costs, Heap, Heap).
update_parents([Node-node(Class, NodeCost)|Rest], Costs0, Costs, Heap0, Heap) :-
    parent_child_classes(Node, Children),
    compute_cost(Children, Costs0, NodeCost, Total),
    (   rb_lookup(Class, Known-_, Costs0)
    ->  true
    ;   Known = inf
    ),
    (   Total < Known
    ->  rb_insert(Costs0, Class, Total-Node, Costs1),
        add_to_heap(Heap0, Total, Class, Heap1)
    ;   Costs1 = Costs0,
        Heap1 = Heap0
    ),
    update_parents(Rest, Costs1, Costs, Heap1, Heap).

% Child class ids of a node: dict values, or compound arguments unless
% the node is a '$VAR'(_) leaf; no children otherwise.
parent_child_classes(Node, Children) :-
    (   is_dict(Node)
    ->  dict_pairs(Node, _, KeysValues),
        pairs_values(KeysValues, Children)
    ;   compound(Node), Node \= '$VAR'(_)
    ->  compound_name_arguments(Node, _, Children)
    ;   Children = []
    ).
335
336
%!  compute_cost(+Children, +Costs, +CostIn, -CostOut)
%
%   CostOut is CostIn plus the cost recorded in Costs for every class in
%   Children. A class without an entry counts as `inf`.
compute_cost(Children, Costs, CostIn, CostOut) :-
    foldl(add_child_cost(Costs), Children, CostIn, CostOut).

% Add the recorded cost of one child class to the running total.
add_child_cost(Costs, Child, Total0, Total) :-
    (   rb_lookup(Child, ChildCost-_, Costs)
    ->  true
    ;   ChildCost = inf
    ),
    Total is Total0 + ChildCost.
345
%!  setup(+Nodes, -ParentPairs, +CostIn, -CostOut, +HeapIn, -HeapOut)
%
%   Initialises the shortest-path search over the node list Nodes, whose
%   entries are Node-node(ClassId, NodeCost).
%
%     - A node without child classes (atomic, '$VAR'(_), or no
%       arguments) is a leaf. It is recorded in the cost table and
%       queued on the heap only when its cost is strictly lower than the
%       cost already recorded for its class.
%     - Any other node contributes one ChildClass-Entry pair per child
%       class to ParentPairs.
setup([], [], Cost, Cost, Heap, Heap).
setup([Node-node(ClassId, NodeCost) | Nodes], ParentsIn, CostIn, CostOut, HeapIn, HeapOut) :-
    (   is_dict(Node)
    ->  dict_pairs(Node, _, KeysValues),
        pairs_values(KeysValues, ChildClasses)
    ;   compound(Node), Node \= '$VAR'(_)
    ->  compound_name_arguments(Node, _, ChildClasses)
    ;   ChildClasses = []
    ),
    (   ChildClasses == []
    ->  ParentsOut = ParentsIn,
        % Commit to the recorded cost when there is one. With a plain
        % disjunction, a failed comparison would backtrack into
        % `CurCost = inf` and let a more expensive leaf overwrite the
        % cheaper entry.
        (   rb_lookup(ClassId, CurCost-_, CostIn)
        ->  true
        ;   CurCost = inf
        ),
        (   NodeCost < CurCost
        ->  rb_insert(CostIn, ClassId, NodeCost-Node, CostTmp),
            add_to_heap(HeapIn, NodeCost, ClassId, HeapTmp)
        ;   CostTmp = CostIn,
            HeapTmp = HeapIn
        )
    ;   insert_parent(ChildClasses, Node-node(ClassId, NodeCost), ParentsIn, ParentsOut),
        CostTmp = CostIn,
        HeapTmp = HeapIn
    ),
    setup(Nodes, ParentsOut, CostTmp, CostOut, HeapTmp, HeapOut).
366
%!  insert_parent(+ChildClasses, +Entry, -Pairs, ?Tail)
%
%   Emits one ChildClass-Entry pair for every class in ChildClasses.
%   Pairs is the resulting list, ending in Tail.
insert_parent([], _, Parents, Parents).
insert_parent([ChildClass | ChildClasses], Node, [ChildClass-Node | ParentsTmp], ParentsOut) :-
    insert_parent(ChildClasses, Node, ParentsTmp, ParentsOut).
/** <module> E-graph implementation for term rewriting and saturation

This module implements an E-graph (Equivalence Graph) data structure, commonly used for efficient term rewriting, congruence closure, and e-matching. The E-graph state is typically threaded through operations using DCG notation.
Rewrite rules are automatically compiled into efficient DCG predicates via term expansion. See the
`egraph_compile` module for full details. The supported rule declarations are:

  * rewrite(Name, Lhs, Rhs)
  * rewrite(Name, Lhs, Rhs, RhsOptions)
  * rewrite(Name, Lhs, LhsOptions, Rhs, RhsOptions)
  * rewrite(Name, Lhs, LhsOptions, Rhs, RhsOptions) :- Body

Main predicates:
*/