Source file pure_literal_elim.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
open Logtk
open Libzipperposition
(** Debug section of this module, registered under the name ["ple"] as a
    child of [Const.section]; used by the [Util.debugf] calls below. *)
let section = Util.Section.make ~parent:Const.section "ple"
(** A symbol paired with a sign. The occurrence counters in this file are
    keyed by such pairs: the boolean is the sign carried by an
    [SLiteral.Atom] whose head is the symbol. *)
type id_sgn = ID.t * bool
(** Maps keyed by signed symbols ([id_sgn]). Keys are ordered
    lexicographically: first by symbol, then by sign. *)
module IDMap = CCMap.Make(struct
    type t = id_sgn

    let compare (id1, sgn1) (id2, sgn2) =
      match ID.compare id1 id2 with
      | 0 -> CCBool.compare sgn1 sgn2 (* same symbol: the sign decides *)
      | c -> c
  end)
(** Sets of integers; used in this file for sets of clause indices. *)
module IntSet = Util.Int_set
(** Short alias for [TypedSTerm]. *)
module TST = TypedSTerm
(** Raised by [compute_occurence_map] when [TST.head_exn] fails with
    [Invalid_argument] on a literal (presumably because the head is an applied
    variable -- confirm against [TypedSTerm]). The extension's modifier catches
    it and leaves the problem untouched. *)
exception AppVarFound
(** Whether the preprocessing pass is enabled; set by the
    [--pure-literal-preprocessing] command-line option. *)
let _enabled = ref false
(** Statistic: incremented once for every clause inspected by
    [remove_pure_clauses]. *)
let total_clauses = Util.mk_stat "total_clauses"
(** Statistic: incremented once for every clause dropped by
    [remove_pure_clauses]. *)
let removed_clauses = Util.mk_stat "removed_clauses"
(** Printer for [(symbol, sign)] keys. *)
let pp_key = CCPair.pp ID.pp CCBool.pp
(** [cl_syms lits] is the set of all symbols occurring in any term of any
    literal of the clause [lits]. *)
let cl_syms lits =
  (* add the symbols of one term to the running set *)
  let add_term_syms seen t =
    TST.Seq.symbols t
    |> ID.Set.of_iter
    |> ID.Set.union seen
  in
  (* one accumulator is threaded through every term of every literal *)
  List.fold_left
    (fun seen lit -> SLiteral.fold add_term_syms seen lit)
    ID.Set.empty
    lits
(** [compute_occurence_map seq] scans the statements of [seq] once and returns
    a triple [(forbidden, ids2clauses, clauses)]:

    - [forbidden]: every symbol occurring in a [Data], [Lemma], [Def] or
      [Rewrite] statement;
    - [ids2clauses]: maps each symbol counted as a literal head to the set of
      indices of the clauses in which it is counted;
    - [clauses]: one occurrence map per clause (from [(head, sign)] to a
      count), most recently seen clause first. Clause indices are assigned in
      order of appearance starting at 0, so the clause of index [i] sits at
      position [i] of [List.rev clauses].

    Clauses come from [Assert] statements (one clause each) and from
    [NegatedGoal] statements (one clause per literal list). [TyDecl]
    statements are ignored.

    @raise AppVarFound if [TST.head_exn] raises [Invalid_argument] on a
    counted literal.
    @raise Failure on a [Goal] statement. *)
let compute_occurence_map (seq : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t Iter.t) =
  (* bump the counter of [(hd, sign)] *)
  let inc_map hd sign map =
    let prev = IDMap.get_or (hd, sign) map ~default:0 in
    IDMap.add (hd, sign) (prev + 1) map
  in
  (* count one occurrence of [hd] under each sign *)
  let inc_both_signs hd map = inc_map hd false (inc_map hd true map) in
  (* record the heads of one literal *)
  let count_literal map = function
    | SLiteral.Atom (pred, sign) ->
      inc_map (TST.head_exn pred) sign map
    | SLiteral.Neq (lhs, rhs)
    | SLiteral.Eq (lhs, rhs) when TST.Ty.is_prop (TST.ty_exn lhs) ->
      (* (dis)equation between propositions: both sides are counted under
         both signs *)
      let l_hd, r_hd = CCPair.map_same TST.head_exn (lhs, rhs) in
      map |> inc_both_signs l_hd |> inc_both_signs r_hd
    | _ -> map
  in
  (* Occurrence map of clause number [cl_idx], and [ids2clauses] extended with
     the heads counted in that clause. *)
  let process_literals ids2clauses (cl_idx, lits) =
    let all_symbol_occurences =
      (* a single handler around an iterative fold: no handler is stacked per
         literal and the stack depth does not depend on the clause length *)
      try List.fold_left count_literal IDMap.empty lits
      with Invalid_argument _ -> raise AppVarFound
    in
    let id_map' =
      IDMap.keys all_symbol_occurences
      |> Iter.map fst
      |> Iter.fold (fun acc sym ->
          let prev = ID.Map.get_or sym acc ~default:IntSet.empty in
          ID.Map.add sym (IntSet.add cl_idx prev) acc
        ) ids2clauses
    in
    id_map', all_symbol_occurences
  in
  (* Register one clause. [n_clauses] is the number of clauses registered so
     far, i.e. the index of the new clause; carrying it along avoids calling
     [List.length] on the accumulator for every clause. *)
  let add_clause (ids2clauses, clauses, n_clauses) lits =
    let ids2clauses', new_cl = process_literals ids2clauses (n_clauses, lits) in
    ids2clauses', new_cl :: clauses, n_clauses + 1
  in
  let forbidden, ids2clauses, clauses, _n_clauses =
    Iter.fold
      (fun (forbidden, ids2clauses, clauses, n_clauses)
        (stmt : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t) ->
        match Statement.view stmt with
        | Statement.TyDecl _ ->
          forbidden, ids2clauses, clauses, n_clauses
        | Statement.Data _
        | Statement.Lemma _
        | Statement.Def _
        | Statement.Rewrite _ ->
          (* every symbol of such a statement is forbidden *)
          let f_syms =
            Statement.Seq.forms stmt
            |> Iter.fold (fun acc lits -> ID.Set.union acc (cl_syms lits)) ID.Set.empty
          in
          ID.Set.union f_syms forbidden, ids2clauses, clauses, n_clauses
        | Statement.Assert lits ->
          let ids2clauses', clauses', n_clauses' =
            add_clause (ids2clauses, clauses, n_clauses) lits
          in
          forbidden, ids2clauses', clauses', n_clauses'
        | Statement.NegatedGoal (_, list_of_lits) ->
          let ids2clauses', clauses', n_clauses' =
            List.fold_left add_clause (ids2clauses, clauses, n_clauses) list_of_lits
          in
          forbidden, ids2clauses', clauses', n_clauses'
        | Statement.Goal _ ->
          failwith "Not implemented: Goal"
      ) (ID.Set.empty, ID.Map.empty, [], 0) seq
  in
  forbidden, ids2clauses, clauses
(** [get_pure_symbols forbidden ids2clauses clauses] computes, as a fixpoint,
    the set of symbols that can be eliminated as pure.

    The arguments are the three components returned by
    [compute_occurence_map]: the forbidden symbols, the map from symbols to
    the indices of the clauses that count them, and the per-clause occurrence
    maps (most recent clause first).

    A symbol is initially pure when it is not forbidden and one of its two
    signs has no counted occurrence at all. Eliminating a pure symbol marks
    every clause that counts it as removed and subtracts the occurrences of
    those clauses from the global counters; any symbol for which some signed
    counter drops to 0 is then queued in turn.

    @return the set of symbols that were eliminated; it contains no forbidden
    symbol.
    @raise Not_found if a queued symbol is missing from [ids2clauses], or a
    clause key is missing from the global counters. *)
let get_pure_symbols forbidden ids2clauses clauses =
  let calculate_pure init_pure cl_status all_clauses =
    (* [clauses] is stored most-recent-first; reversing it puts the clause of
       index [i] at position [i] *)
    let clauses = CCArray.of_list (List.rev clauses) in
    let rec aux processed symbol_occurences = function
      | [] -> processed
      | (sym :: syms) as all_syms ->
        if ID.Set.mem sym processed || ID.Set.mem sym forbidden then
          aux processed symbol_occurences syms
        else (
          (* clauses counting [sym] that are still alive; they are marked as
             removed in [cl_status] while being collected *)
          let clauses_to_remove =
            CCList.filter_map (fun idx ->
                if not (CCBV.get cl_status idx) then (
                  CCBV.set cl_status idx;
                  Some (clauses.(idx))
                )
                else None
              ) (IntSet.to_list @@ ID.Map.find sym ids2clauses)
          in
          let symbol_occurences', next_to_process =
            List.fold_left (fun (sym_occs, next_to_process) cl ->
                IDMap.fold (fun key occ (sym_occs, next_to_process) ->
                    let hd = fst key in
                    let prev = IDMap.find key sym_occs in
                    let new_ = prev - occ in
                    let sym_occs' = IDMap.add key new_ sym_occs in
                    (* The guard must be about [hd], the symbol whose counter
                       just changed. It used to test [sym], which is known to
                       be neither processed nor forbidden in this branch, so
                       forbidden and already-processed symbols were queued and
                       logged as "became pure". The result is the same either
                       way, since [aux] skips such symbols when it dequeues
                       them. *)
                    if new_ = 0
                    && not (ID.Set.mem hd processed
                            || ID.Set.mem hd forbidden
                            || CCList.mem ~eq:ID.equal hd all_syms)
                    then sym_occs', hd :: next_to_process
                    else sym_occs', next_to_process
                  ) cl (sym_occs, next_to_process)
              ) (symbol_occurences, []) clauses_to_remove
          in
          Util.debugf ~section 1 "became pure: @[%a@]@." (fun k -> k (CCList.pp ID.pp) next_to_process);
          aux (ID.Set.add sym processed) symbol_occurences' (next_to_process @ syms)
        )
    in
    aux ID.Set.empty all_clauses (ID.Set.to_list init_pure)
  in
  (* global counters: sum of the occurrence maps of all clauses *)
  let all_clauses =
    List.fold_left (fun acc cl ->
        IDMap.union (fun _ a b -> Some (a + b)) acc cl
      ) IDMap.empty clauses
  in
  let all_symbols =
    IDMap.keys all_clauses
    |> Iter.map fst
    |> ID.Set.of_iter
  in
  let init_pure =
    ID.Set.filter (fun sym ->
        not (ID.Set.mem sym forbidden)
        (* structural [=]: [==] is physical equality and only happens to work
           on immediate integers *)
        && (IDMap.get_or ~default:0 (sym, true) all_clauses = 0
            || IDMap.get_or ~default:0 (sym, false) all_clauses = 0)
      ) all_symbols
  in
  Util.debugf ~section 1 "initially pure: @[%a@]@." (fun k -> k (ID.Set.pp ID.pp) init_pure);
  (* one bit per clause, set once the clause has been removed *)
  let clause_status = CCBV.create ~size:(List.length clauses) false in
  calculate_pure init_pure clause_status all_clauses
(** [remove_pure_clauses seq] drops from [seq] every clause that mentions a
    pure symbol, as computed by [compute_occurence_map] and
    [get_pure_symbols].

    - [TyDecl], [Data], [Lemma], [Def], [Rewrite] and [Goal] statements are
      kept as they are.
    - An [Assert] statement is dropped if its clause mentions a pure symbol.
    - A [NegatedGoal] statement is kept as is if none of its clauses is
      dropped, is dropped if all of them are, and is otherwise rebuilt from
      the remaining clauses with the same attributes, proof step and skolems.

    The analysis runs when this function is called; the filtering itself, and
    the updates of the [total_clauses] / [removed_clauses] statistics, happen
    when the returned iterator is consumed.

    @raise AppVarFound and [Failure], as [compute_occurence_map] does.
    @raise Not_found as [get_pure_symbols] does. *)
let remove_pure_clauses (seq : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t Iter.t) =
  let forbidden, ids2cls, cls = compute_occurence_map seq in
  let pure_syms =
    ID.Set.diff
      (get_pure_symbols forbidden ids2cls cls)
      forbidden
  in
  (* [Some stmt] if the clause [lits] mentions no pure symbol, [None]
     otherwise; updates the statistics on the way *)
  let filter_if_has_pure stmt lits =
    Util.incr_stat total_clauses;
    let ans =
      CCOpt.return_if
        (ID.Set.is_empty @@ ID.Set.inter pure_syms (cl_syms lits))
        stmt
    in
    if CCOpt.is_none ans then (
      Util.incr_stat removed_clauses;
      Util.debugf ~section 2 "removed: @[%a@]@."
        (fun k -> k (CCList.pp (SLiteral.pp TST.pp)) lits);
    );
    ans
  in
  Iter.filter_map (fun stmt ->
      match Statement.view stmt with
      | Statement.TyDecl _
      | Statement.Data _
      | Statement.Lemma _
      | Statement.Def _
      | Statement.Rewrite _
      | Statement.Goal _ ->
        Some stmt
      | Statement.Assert lits ->
        filter_if_has_pure stmt lits
      | Statement.NegatedGoal (skolems, list_of_lits) ->
        let new_cls =
          CCList.filter_map (fun x -> filter_if_has_pure x x) list_of_lits
        in
        if CCList.is_empty new_cls then None
        (* structural [=] on the lengths; [==] is physical equality *)
        else if List.length new_cls = List.length list_of_lits then Some stmt
        else
          Some
            (Statement.neg_goal
               ~attrs:(Statement.attrs stmt)
               ~proof:(Statement.proof_step stmt)
               ~skolems
               new_cls)
    ) seq
(** The extension record for pure literal elimination: it installs a post-CNF
    modifier that removes the clauses mentioning pure symbols, and an
    environment action that prints the two statistics. Both do nothing unless
    [_enabled] is set. *)
let extension =
  (* Best effort: when the analysis gives up ([AppVarFound]) or a lookup
     fails ([Not_found]), the problem is returned unchanged. *)
  let modifier (seq : (TST.t SLiteral.t list, TST.t, TST.t) Statement.t Iter.t) =
    if not !_enabled then seq
    else
      match remove_pure_clauses seq with
      | filtered -> filtered
      | exception (AppVarFound | Not_found) -> seq
  in
  let print_stats _ =
    if !_enabled then
      CCFormat.printf "%%%a@.%%%a@." Util.pp_stat removed_clauses Util.pp_stat total_clauses
  in
  let open Extensions in
  { default with
    name = "pure_literal_elimination";
    post_cnf_modifiers = [modifier];
    env_actions = [print_stats];
  }
(* Module initialisation: declare the command-line switch that turns the pass
   on or off, then register the extension. *)
let () =
  let set_enabled v = _enabled := v in
  let opts =
    [ "--pure-literal-preprocessing", Arg.Bool set_enabled,
      " remove all pure literals in fixpoint" ]
  in
  Options.add_opts opts;
  Extensions.register extension