~ymherklotz/vericert

a1657466c7d8af0ed405723bf3aa83bafedf9816 — Yann Herklotz 2 months ago 4ba9568
Add new if-conversion pass with top-level fold
3 files changed, 83 insertions(+), 65 deletions(-)

M src/Compiler.v
M src/hls/IfConversion.v
M src/hls/IfConversionproof.v
M src/Compiler.v => src/Compiler.v +16 -1
@@ 139,6 139,21 @@ Definition match_if {A: Type} (flag: unit -> bool) (R: A -> A -> Prop)
  : A -> A -> Prop :=
  if flag tt then R else eq.

Definition match_rep {A: Type} (R: A -> A -> Prop): A -> A -> Prop :=
  Relation_Operators.clos_refl_trans A R.

Global Instance TransfIfLink {A: Type} {LA: Linker A}
                      (transf: A -> A -> Prop) (TL: TransfLink transf)
                      : TransfLink (match_rep transf).
Admitted.

Lemma total_rep_match:
  forall (A B: Type) (n: list B) (f: A -> B -> A)
         (rel: A -> A -> Prop) (prog: A),
    (forall b p, rel p (f p b)) ->
  match_rep rel prog (fold_left f n prog).
Proof. Admitted.

Lemma total_if_match:
  forall (A: Type) (flag: unit -> bool) (f: A -> A)
         (rel: A -> A -> Prop) (prog: A),


@@ 264,7 279,7 @@ Definition transf_hls_temp (p : Csyntax.program) : res Verilog.program :=
   @@ print (print_GibleSeq 0)
   @@ total_if HLSOpts.optim_if_conversion CondElim.transf_program
   @@ print (print_GibleSeq 1)
   @@ total_if HLSOpts.optim_if_conversion IfConversion.transf_program
   @@ total_if HLSOpts.optim_if_conversion (fold_left (fun a b => IfConversion.transf_program b a) (Maps.PTree.empty _ :: Maps.PTree.empty _ :: nil))
   @@ print (print_GibleSeq 2)
  @@@ DeadBlocks.transf_program
   @@ print (print_GibleSeq 3)

M src/hls/IfConversion.v => src/hls/IfConversion.v +12 -16
@@ 35,10 35,9 @@ Require Import vericert.hls.GibleSeq.
Require Import vericert.hls.Predicate.
Require Import vericert.bourdoncle.Bourdoncle.

Definition if_conv_t : Type := PTree.t (list (node * node)).
Definition if_conv_t : Type := list (node * node).

Parameter build_bourdoncle : function -> (bourdoncle * PMap.t N).
Parameter get_if_conv_t : program -> list if_conv_t.

#[local] Open Scope positive.



@@ 225,7 224,7 @@ Definition ifconv_list (headers: list node) (c: code) :=
Definition if_convert_code (c: code) iflist :=
  fold_left (fun s n => if_convert c s (fst n) (snd n)) iflist c.

Definition transf_function (l: if_conv_t) (i: ident) (f: function) : function :=
Definition transf_function (l: option if_conv_t) (f: function) : function :=
  let (b, _) := build_bourdoncle f in
  let b' := get_loops b in
  let iflist := ifconv_list b' f.(fn_code) in


@@ 235,28 234,25 @@ Definition transf_function (l: if_conv_t) (i: ident) (f: function) : function :=

Section TRANSF_PROGRAM.

  Context {A B V: Type}.
  Variable transf: ident -> A -> B.
  Context {A B V DATA: Type}.
  Variable transf: option DATA -> A -> B.

  Definition transform_program_globdef' (idg: ident * globdef A V) : ident * globdef B V :=
  Definition transform_program_globdef' (data: PTree.t DATA) (idg: ident * globdef A V) : ident * globdef B V :=
    match idg with
    | (id, Gfun f) => (id, Gfun (transf id f))
    | (id, Gfun f) => (id, Gfun (transf (data!id) f))
    | (id, Gvar v) => (id, Gvar v)
    end.

  Definition transform_program' (p: AST.program A V) : AST.program B V :=
  Definition transform_program_data (data: PTree.t DATA) (p: AST.program A V) : AST.program B V :=
    mkprogram
      (List.map transform_program_globdef' p.(prog_defs))
      (List.map (transform_program_globdef' data) p.(prog_defs))
      p.(prog_public)
          p.(prog_main).

End TRANSF_PROGRAM.

Definition transf_fundef (l: if_conv_t) (i: ident) (fd: fundef) : fundef :=
  transf_fundef (transf_function l i) fd.
Definition transf_fundef (l: option if_conv_t) (fd: fundef) : fundef :=
  transf_fundef (transf_function l) fd.

Definition transf_program_rec (p: program) (l: if_conv_t) : program :=
  transform_program' (transf_fundef l) p.

Definition transf_program (p: program) : program :=
  fold_left transf_program_rec (get_if_conv_t p) p.
Definition transf_program (l: PTree.t if_conv_t) (p: program) : program :=
  transform_program_data transf_fundef l p.

M src/hls/IfConversionproof.v => src/hls/IfConversionproof.v +55 -48
@@ 50,8 50,8 @@ Require Import vericert.hls.Predicate.

Variant match_stackframe : stackframe -> stackframe -> Prop :=
  | match_stackframe_init :
    forall res f tf sp pc rs p l i
           (TF: transf_function l i f = tf),
    forall res f tf sp pc rs p l
           (TF: transf_function l f = tf),
      match_stackframe (Stackframe res f sp pc rs p) (Stackframe res tf sp pc rs p).

Definition bool_order (b: bool): nat := if b then 1 else 0.


@@ 231,8 231,8 @@ Proof.
Qed.

Lemma transf_spec_correct :
  forall f pc l i,
    if_conv_spec f.(fn_code) (transf_function l i f).(fn_code) pc.
  forall f pc l,
    if_conv_spec f.(fn_code) (transf_function l f).(fn_code) pc.
Proof.
  intros; unfold transf_function; destruct_match; cbn.
  unfold if_convert_code.


@@ 306,8 306,8 @@ Section CORRECTNESS.

  Variant match_states : option SeqBB.t -> state -> state -> Prop :=
    | match_state_some :
      forall stk stk' f tf sp pc rs p m b pc0 rs0 p0 m0 l i
             (TF: transf_function l i f = tf)
      forall stk stk' f tf sp pc rs p m b pc0 rs0 p0 m0 l
             (TF: transf_function l f = tf)
             (STK: Forall2 match_stackframe stk stk')
             (* This can be improved with a recursive relation for a more general structure of the
                if-conversion proof. *)


@@ 319,13 319,13 @@ Section CORRECTNESS.
             (SIM: step ge (State stk f sp pc0 rs0 p0 m0) E0 (State stk f sp pc rs p m)),
        match_states (Some b) (State stk f sp pc rs p m) (State stk' tf sp pc0 rs0 p0 m0)
    | match_state_none :
      forall stk stk' f tf sp pc rs p m l i
             (TF: transf_function l i f = tf)
      forall stk stk' f tf sp pc rs p m l
             (TF: transf_function l f = tf)
             (STK: Forall2 match_stackframe stk stk'),
        match_states None (State stk f sp pc rs p m) (State stk' tf sp pc rs p m)
    | match_callstate :
      forall cs cs' f tf args m l i
             (TF: transf_fundef l i f = tf)
      forall cs cs' f tf args m l
             (TF: transf_fundef l f = tf)
             (STK: Forall2 match_stackframe cs cs'),
        match_states None (Callstate cs f args m) (Callstate cs' tf args m)
    | match_returnstate :


@@ 334,46 334,51 @@ Section CORRECTNESS.
        match_states None (Returnstate cs v m) (Returnstate cs' v m).

  Definition match_prog (p: program) (tp: program) :=
    Linking.match_program (fun cu f tf => forall l i, tf = transf_fundef l i f) eq p tp.
    match_program (fun _ f tf => exists l, transf_fundef l f = tf) eq p tp.

  Context (TRANSL : match_prog prog tprog).

  Lemma symbols_preserved:
    forall (s: AST.ident), Genv.find_symbol tge s = Genv.find_symbol ge s.
  Proof using TRANSL. intros. eapply (Genv.find_symbol_match TRANSL). Qed.
  Proof using TRANSL.
    intros. eapply (Genv.find_symbol_match TRANSL).
  Qed.

  Lemma senv_preserved:
    Senv.equiv (Genv.to_senv ge) (Genv.to_senv tge).
  Proof using TRANSL.
    Admitted.
    (*intros; eapply (Genv.senv_transf TRANSL). Qed.*)
    intros; eapply (Genv.senv_match TRANSL).
  Qed.

  Lemma function_ptr_translated:
    forall b f l i,
    forall b f l,
      Genv.find_funct_ptr ge b = Some f ->
      Genv.find_funct_ptr tge b = Some (transf_fundef l i f).
  Proof. Admitted.
      Genv.find_funct_ptr tge b = Some (transf_fundef l f).
  Proof.
    intros. exploit (Genv.find_funct_ptr_match TRANSL); eauto.
    crush.
  Qed.

  Lemma sig_transf_function:
    forall (f tf: fundef) l i,
      funsig (transf_fundef l i f) = funsig f.
    forall (f tf: fundef) l,
      funsig (transf_fundef l f) = funsig f.
  Proof using.
    unfold transf_fundef. unfold AST.transf_fundef; intros. destruct f.
    unfold transf_function. destruct_match. auto. auto.
  Qed.

  Lemma functions_translated:
    forall (v: Values.val) (f: GibleSeq.fundef) l i,
    forall (v: Values.val) (f: GibleSeq.fundef) l,
      Genv.find_funct ge v = Some f ->
      Genv.find_funct tge v = Some (transf_fundef l i f).
      Genv.find_funct tge v = Some (transf_fundef l f).
  Proof using TRANSL.
    intros. exploit (Genv.find_funct_match TRANSL); eauto. simplify. eauto.
    Admitted.
  Qed.

  Lemma find_function_translated:
    forall ros rs f l i,
    forall ros rs f l,
      find_function ge ros rs = Some f ->
      find_function tge ros rs = Some (transf_fundef l i f).
      find_function tge ros rs = Some (transf_fundef l f).
  Proof using TRANSL.
    Ltac ffts := match goal with
                 | [ H: forall _, Val.lessdef _ _, r: Registers.reg |- _ ] =>


@@ 398,11 403,12 @@ Section CORRECTNESS.
    induction 1.
    exploit function_ptr_translated; eauto; intros.
    do 2 econstructor; simplify. econstructor.
    (*apply (Genv.init_mem_transf TRANSL); eauto.
    apply (Genv.init_mem_match TRANSL); eauto.
    replace (prog_main tprog) with (prog_main prog). rewrite symbols_preserved; eauto.
    symmetry; eapply Linking.match_program_main; eauto. eauto.
    erewrite sig_transf_function; eauto. constructor. auto. auto.
  Qed.*) Admitted.
    erewrite sig_transf_function; eauto. econstructor. auto. auto.
    Unshelve. exact None.
  Qed.

  Lemma transf_final_states :
    forall s1 s2 r b,


@@ 498,8 504,8 @@ Section CORRECTNESS.
  Qed.

  Lemma fn_all_eq :
    forall f tf l i,
      transf_function l i f = tf ->
    forall f tf l,
      transf_function l f = tf ->
      fn_stacksize f = fn_stacksize tf
      /\ fn_sig f = fn_sig tf
      /\ fn_params f = fn_params tf


@@ 512,16 518,16 @@ Section CORRECTNESS.

  Ltac func_info :=
    match goal with
      H: transf_function _ _ _ = _ |- _ =>
      H: transf_function _ _ = _ |- _ =>
        let H2 := fresh "ALL_EQ" in
        pose proof (fn_all_eq _ _ _ _ H) as H2; simplify
        pose proof (fn_all_eq _ _ _ H) as H2; simplify
    end.

  Lemma step_cf_eq :
    forall stk stk' f tf sp pc rs pr m cf s t pc' l i,
    forall stk stk' f tf sp pc rs pr m cf s t pc' l,
      step_cf_instr ge (State stk f sp pc rs pr m) cf t s ->
      Forall2 match_stackframe stk stk' ->
      transf_function l i f = tf ->
      transf_function l f = tf ->
      exists s', step_cf_instr tge (State stk' tf sp pc' rs pr m) cf t s'
                 /\ match_states None s s'.
  Proof.


@@ 534,7 540,8 @@ Section CORRECTNESS.
      rewrite H2 in *. rewrite H12. auto. econstructor; auto.
    - func_info. do 2 econstructor. econstructor; eauto. rewrite H2 in *.
      eauto. econstructor; auto.
  Admitted.
    Unshelve. all: exact None.
  Qed.

  Definition cf_dec :
    forall a pc, {a = RBgoto pc} + {a <> RBgoto pc}.


@@ 768,19 775,19 @@ Section CORRECTNESS.
  Qed.

  Lemma match_none_correct :
    forall t s1' stk f sp pc rs m pr rs' m' bb pr' cf stk' l i,
    forall t s1' stk f sp pc rs m pr rs' m' bb pr' cf stk' l,
      (fn_code f) ! pc = Some bb ->
      SeqBB.step ge sp (Iexec (mki rs pr m)) bb (Iterm (mki rs' pr' m') cf) ->
      step_cf_instr ge (State stk f sp pc rs' pr' m') cf t s1' ->
      Forall2 match_stackframe stk stk' ->
      exists b' s2',
        (plus step tge (State stk' (transf_function l i f) sp pc rs pr m) t s2' \/
           star step tge (State stk' (transf_function l i f) sp pc rs pr m) t s2'
        (plus step tge (State stk' (transf_function l f) sp pc rs pr m) t s2' \/
           star step tge (State stk' (transf_function l f) sp pc rs pr m) t s2'
           /\ ltof (option SeqBB.t) measure b' None) /\
          match_states b' s1' s2'.
  Proof.
    intros * H H0 H1 STK.
    pose proof (transf_spec_correct f pc l i) as X; inv X.
    pose proof (transf_spec_correct f pc l) as X; inv X.
    - apply sim_plus. rewrite H in H2. symmetry in H2.
      exploit step_cf_eq; eauto; simplify.
      do 3 econstructor. apply plus_one. econstructor; eauto.


@@ 790,7 797,7 @@ Section CORRECTNESS.
      destruct (cf_wf_dec x b' cf pc'); subst; simplify.
      + inv H1.
        exploit exec_if_conv; eauto; simplify.
        apply sim_star. exists (Some b'). exists (State stk' (transf_function l i f) sp pc rs pr m).
        apply sim_star. exists (Some b'). exists (State stk' (transf_function l f) sp pc rs pr m).
        simplify; try (unfold ltof; auto). apply star_refl.
        econstructor; auto.
        simplify. econstructor; eauto.


@@ 810,21 817,21 @@ Section CORRECTNESS.
  Qed.

  Lemma match_some_correct:
    forall t s1' s f sp pc rs m pr rs' m' bb pr' cf stk' b0 pc1 rs1 p0 m1 l i,
    forall t s1' s f sp pc rs m pr rs' m' bb pr' cf stk' b0 pc1 rs1 p0 m1 l,
      step ge (State s f sp pc rs pr m) t s1' ->
      (fn_code f) ! pc = Some bb ->
      SeqBB.step ge sp (Iexec (mki rs pr m)) bb (Iterm (mki rs' pr' m') cf) ->
      step_cf_instr ge (State s f sp pc rs' pr' m') cf t s1' ->
      Forall2 match_stackframe s stk' ->
      (fn_code f) ! pc = Some b0 ->
      sem_extrap (transf_function l i f) pc1 sp (Iexec (mki rs pr m)) (Iexec (mki rs1 p0 m1)) b0 ->
      sem_extrap (transf_function l f) pc1 sp (Iexec (mki rs pr m)) (Iexec (mki rs1 p0 m1)) b0 ->
      (forall b',
          f.(fn_code)!pc1 = Some b' ->
          exists tb, (transf_function l i f).(fn_code)!pc1 = Some tb /\ if_conv_replace pc b0 b' tb) ->
          exists tb, (transf_function l f).(fn_code)!pc1 = Some tb /\ if_conv_replace pc b0 b' tb) ->
      step ge (State s f sp pc1 rs1 p0 m1) E0 (State s f sp pc rs pr m) ->
      exists b' s2',
        (plus step tge (State stk' (transf_function l i f) sp pc1 rs1 p0 m1) t s2' \/
           star step tge (State stk' (transf_function l i f) sp pc1 rs1 p0 m1) t s2' /\
        (plus step tge (State stk' (transf_function l f) sp pc1 rs1 p0 m1) t s2' \/
           star step tge (State stk' (transf_function l f) sp pc1 rs1 p0 m1) t s2' /\
             ltof (option SeqBB.t) measure b' (Some b0)) /\ match_states b' s1' s2'.
  Proof.
    intros * H H0 H1 H2 STK IS_B EXTRAP IS_TB SIM.


@@ 853,17 860,17 @@ Section CORRECTNESS.
      match goal with H: context[match_states] |- _ => inv H end.
    - eauto using match_some_correct.
    - eauto using match_none_correct.
    - apply sim_plus. remember (transf_function l i f) as tf. symmetry in Heqtf. func_info.
    - apply sim_plus. remember (transf_function l f) as tf. symmetry in Heqtf. func_info.
      exists None. eexists. split.
      apply plus_one. constructor; eauto. rewrite <- H1. eassumption.
      rewrite <- H4. rewrite <- H2. econstructor; auto.
    - apply sim_plus. exists None. eexists. split.
      apply plus_one. constructor; eauto.
      constructor; auto.
(*    - apply sim_plus. remember (transf_function l i f) as tf. symmetry in Heqtf. func_info.
   - apply sim_plus. remember (transf_function None f) as tf. symmetry in Heqtf. func_info.
      exists None. inv STK. inv H7. eexists. split. apply plus_one. constructor.
      constructor; auto.
  Qed.*) Admitted.
      econstructor; auto.
  Qed.

  Theorem transf_program_correct:
    forward_simulation (semantics prog) (semantics tprog).