(** * Module ConsSet *)

Require Import String.
Require Import List.
Require Export NumC.
Require Export ProgVar.
Require Export CstrC.
Require Export CertC.
Require Import Debugging.
Require Import OptionMonad.
Require Import CoqAddOn.

Module Type CsSig (C: CstrSig).
(** A parameterized interface for sets of linear relations between program
    variables. The representation of individual constraints is given by the
    functor argument [C]. *)

  (** Abstract type of constraint sets. *)
  Parameter t: Type.

  (** Semantics of a constraint set: [sat l m] holds when memory [m] is a
      model of [l]. *)
  Parameter sat: t -> Mem.t QNum.t -> Prop.

  (** [mayDependOn l x]: the semantics of [l] may depend on variable [x].
      [sat_mdo] ties [sat] to [mayDependOn] through [mdoExt]; presumably this
      means [sat] only observes variables flagged by [mayDependOn]. *)
  Parameter mayDependOn: t -> PVar.t -> Prop.
  Parameter sat_mdo: mdoExt mayDependOn sat Logic.eq.

  (** Semantic inclusion: every model of [l1] is a model of [l2]. *)
  Definition Incl (l1 l2: t): Prop
    := forall m, sat l1 m -> sat l2 m.

  (** The constraint set satisfied by every memory. *)
  Parameter top: t.
  Parameter top_correct: forall m, sat top m.

  (** Derives a single constraint from [l], guided by a certificate fragment.
      The derived constraint holds in every model of [l]. *)
  Parameter combine: t -> Cert.frag_t -> C.t.
  Parameter combine_correct: forall (l: t) (f: Cert.frag_t) m, sat l m -> C.sat (combine l f) m.

  (** Certificate-based inclusion check. A [true] answer implies [Incl l1 l2]. *)
  Parameter chkIncl: t -> t -> Cert.t -> bool.
  Parameter chkIncl_correct: forall l1 l2 f, If chkIncl l1 l2 f THEN Incl l1 l2.

  (** Certificate-based emptiness check. A [true] answer implies that [l] has
      no model. *)
  Parameter chkEmpty: t -> Cert.t -> bool.
  Parameter chkEmpty_correct: forall l cert, If chkEmpty l cert THEN forall m,~(sat l m).

  (** Adds constraint [c], labeled [i], to [l]. Soundness: a memory satisfying
      both [l] and [c] satisfies the result. *)
  Parameter add: t -> Cert.id_t -> C.t -> t.
  Parameter Add_correct: forall l i c m, sat l m -> C.sat c m -> sat (add l i c) m.

  (** Variant of [add] driven by a certificate (presumably used to discard
      redundant constraints; TODO confirm). It has the same soundness
      statement as [add]. *)
  Parameter redAdd: t -> Cert.id_t -> C.t -> Cert.t -> t.
  Parameter redAdd_correct: forall l i c cert m, sat l m -> C.sat c m -> sat (redAdd l i c cert) m.

  (** Certificate-based join: the result includes both arguments. *)
  Parameter buildJoin: t -> t -> Cert.joinT -> t.
  Parameter buildJoin_correct1: forall l1 l2 cert, Incl l1 (buildJoin l1 l2 cert).
  Parameter buildJoin_correct2: forall l1 l2 cert, Incl l2 (buildJoin l1 l2 cert).

  (** [isFree x cs]: a [true] answer guarantees that [cs] does not depend
      on [x]. *)
  Parameter isFree: PVar.t -> t -> bool.
  Parameter isFree_correct: forall x cs, If (isFree x cs) THEN ~(mayDependOn cs x).

  (** [rename x y cs]: the semantics is characterized by [rename_correct].
      Evaluating the result in [m] is the same as evaluating [cs] in [m]
      where [x] is assigned the value of [y]. *)
  Parameter rename: PVar.t -> PVar.t -> t -> t.
  Parameter rename_correct: forall (x y:PVar.t) (cs:t) m,
    (sat (rename x y cs) m)=(sat cs (Mem.assign x (m y) m)).

  (** Pretty-printer. *)
  Parameter pr: t -> String.string.

End CsSig.

Module CsImpl (C: CstrSig) <: CsSig C.
(** A straightforward implementation of [CsSig] as lists of constraints, each
    constraint being labeled with a unique identifier. *)

(** Labeled constraints. *)
  Record cT: Type
    := mk {
      id: Cert.id_t;
      cstr: C.t
    }.

  (** A constraint set is a list of labeled constraints, read as a
      conjunction (see [sat]). *)
  Definition t: Type := list cT.

  (** Pretty-printer for a single labeled constraint. *)
  Definition cTpr: cT -> string
    := fun c => CoqAddOn.sprintf "{id = %s; cstr = %s}"
      ((Cert.idPr (id c))::(C.pr (cstr c))::nil).

  (** Pretty-printer for a constraint set. Labeled constraints are separated
      by [CoqAddOn.nl]. *)
  Definition pr: t -> string
    := fun cs => CoqAddOn.concat CoqAddOn.nl (List.map cTpr cs).

  (** [sat l m]: [m] satisfies every constraint of [l]. The empty list is
      satisfied by every memory. *)
  Fixpoint sat (l: t) m: Prop :=
    match l with
      | nil => True
      | c::l' => C.sat (cstr c) m /\ sat l' m
    end.

  (** [mayDependOn l x]: at least one constraint of [l] may depend on [x]. *)
  Fixpoint mayDependOn (l:t) (x: PVar.t) : Prop :=
    match l with
      | nil => False
      | c::l' => C.mayDependOn (cstr c) x \/ mayDependOn l' x
    end.

  Lemma sat_mdo: mdoExt mayDependOn sat Logic.eq.
  Proof.
    unfold mdoExt, bExt; induction e; simpl; try tauto.
    intros m1 m2 H; erewrite IHe; eauto.
    erewrite C.sat_mdo; intuition eauto.
  Qed.

  (** [InclCstr l c]: every model of [l] satisfies [c]. *)
  Definition InclCstr (l: t) (c: C.t): Prop
    := forall m, sat l m -> C.sat c m.

  Lemma satCons l c m: sat l m -> C.sat c m -> forall i, sat (mk i c::l) m.
  Proof.
    simpl; intuition.
  Qed.
  Hint Resolve satCons: pedraQ.

  (** Every member of a satisfied list is itself satisfied. *)
  Lemma InclCstrIn: forall (l: t) (c: cT) m, In c l -> sat l m -> C.sat (cstr c) m.
  Proof.
    induction l; simpl; intuition.
    subst; auto.
  Qed.
  Local Hint Resolve InclCstrIn: pedraQ.

  Lemma satApp: forall (l1 l2: t) m,
    sat l1 m -> sat l2 m -> sat (l1 ++ l2) m.
  Proof.
    induction l1; simpl; intuition.
  Qed.

  (** Semantic inclusion: every model of [l1] is a model of [l2]. *)
  Definition Incl (l1 l2: t): Prop
    := forall m, sat l1 m -> sat l2 m.
  Local Hint Unfold Incl: pedraQ.

  Lemma InclTrans: forall l1 l2 l3, Incl l1 l2 -> Incl l2 l3 -> Incl l1 l3.
  Proof.
    unfold Incl; eauto.
  Qed.

  Lemma InclFold l1 l2 m: Incl l1 l2 -> sat l1 m -> sat l2 m.
  Proof.
    unfold Incl; auto.
  Qed.
  Hint Resolve InclFold: pedraQ.


  (** The empty constraint set, satisfied by every memory. *)
  Definition top: t
    := nil.

  Lemma top_correct: forall m, sat top m.
  Proof.
    simpl; auto.
  Qed.
  Hint Resolve top_correct: pedraQ.

  (** Prepends constraint [c], labeled [i], to [l]. *)
  Definition add (l: t) (i: Cert.id_t) (c: C.t): t := (mk i c)::l.

  Lemma Add_correct: forall l i c m, sat l m -> C.sat c m -> sat (add l i c) m.
  Proof.
    simpl; auto.
  Qed.
  Hint Resolve Add_correct: pedraQ.
  Global Opaque add.

  (** [get l n]: the first labeled constraint of [l] whose identifier matches
      [n] according to [Cert.eqId]. Returns [None] when there is none. *)
  Fixpoint get (l: t) (n: Cert.id_t) {struct l}: option cT :=
    match l with
      | nil => None
      | c1::l1 =>
        if Cert.eqId n (id c1) then
          Some c1
        else
          get l1 n
    end.

  Lemma get_correct l n: WHEN c <- get l n THEN In c l.
  Proof.
    induction l; simpl; PedraQsimplify; auto.
  Qed.
  Local Hint Resolve get_correct: pedraQ.

  (** [acc c1 c2 n]: the constraint obtained by adding [n] times
      [cstr c2] to [c1], using [C.mul] and [C.add]. *)
  Definition acc (c1: C.t) (c2: cT) (n: QNum.t): C.t
    := C.add c1 (C.mul n (cstr c2)).

  Lemma acc_correct l c1 c2 n m: (sat l m) -> (C.sat c1 m) -> (In c2 l) -> C.sat (acc c1 c2 n) m.
  Proof.
    unfold acc; intros; assert (C.sat (cstr c2) m); eauto with pedraQ.
  Qed.
  Local Hint Resolve acc_correct: pedraQ.
  Extraction Inline acc.

  (** [combineRec l f c]: for each pair [(n, coef)] of fragment [f], in
      order, accumulates [coef] times the constraint of [l] labeled [n]
      into [c].
      An identifier missing from [l] is reported through [failwith]. The
      fallback value given to [failwith] is the unchanged accumulator. *)
  Fixpoint combineRec l (f: Cert.frag_t) (c: C.t) {struct f}: C.t :=
    match f with
      | nil => c
      | (n, coef)::f1 =>
        combineRec l f1
          (TRY cn <- get l n IN acc c cn coef
           CATCH failwith ("ConsSet.combineRec:" ++ (Cert.idPr n) ++ "?" ++ pr l) c)
    end.

  Lemma combineRec_correct l f m: forall c, sat l m -> C.sat c m -> C.sat (combineRec l f c) m.
  Proof.
    induction f; simpl; auto with pedraQ.
    PedraQsimplify.
  Qed.
  Local Hint Resolve combineRec_correct: pedraQ.

  (** [combine l f]: the linear combination of constraints of [l] described
      by fragment [f].
      - The first term, scaled with [C.mul], seeds the accumulation done by
        [combineRec].
      - An empty fragment yields [C.top].
      - A missing first identifier goes through [failwith], with fallback
        [C.top]. *)
  Fixpoint combine (l: t) (f: Cert.frag_t) {struct f}: C.t :=
    match f with
      | nil => C.top
      | (n, coef)::f1 =>
        TRY cn <- get l n IN combineRec l f1 (C.mul coef (cstr cn))
        CATCH failwith ("ConsSet.combine:" ++ (Cert.idPr n) ++ "?" ++ pr l) C.top
    end.

  Lemma combine_correct l f m: sat l m -> C.sat (combine l f) m.
  Proof.
    induction f; simpl; auto with pedraQ.
    PedraQsimplify.
  Qed.
  Hint Resolve combine_correct: pedraQ.

  (** Builds the constraint described by one certificate.
      - [Cert.Direct]: a single combination.
      - [Cert.SplitEq]: two combinations merged with [C.merge].
      Presumably the two combinations of [Cert.SplitEq] are the two halves
      of an equality; TODO confirm. *)
  Definition build1 (l: t) (cert: Cert.cstrT): C.t :=
    match cert with
      | Cert.Direct f => combine l f
      | Cert.SplitEq f1 f2 => C.merge (combine l f1) (combine l f2)
    end.
  Extraction Inline build1.

  Lemma build1_correct l cert m: sat l m -> C.sat (build1 l cert) m.
  Proof.
    unfold build1. destruct cert; eauto with pedraQ.
  Qed.
  Local Hint Resolve build1_correct: pedraQ.

  (** [build l cert]: a fresh constraint set derived from [l]. Each element
      is labeled with the identifier paired with its certificate in [cert].
      Only the derived constraints are returned; [l] itself is not
      included. *)
  Fixpoint build (l: t) (cert: list (Cert.id_t * Cert.cstrT)): t :=
    match cert with
      | nil => nil
      | (id, c)::cert0 => (mk id (build1 l c))::(build l cert0)
    end.

  Lemma build_correct l cert: Incl l (build l cert).
  Proof.
    unfold Incl; induction cert; simpl; intuition (simpl; auto with pedraQ).
  Qed.
  Local Hint Resolve build_correct: pedraQ.

  Hint Resolve InclTrans: pedraQ.

  (** Syntactic inclusion check. For each constraint [c2] of [l2], [l1]
      must contain a constraint with the same identifier that [C.isIncl]
      accepts as included in [c2]. A missing identifier makes the check
      fail. *)
  Fixpoint chkSyn (l1 l2: t): bool :=
    match l2 with
      | nil => true
      | c2::l2' =>
        TRY c1 <- get l1 (id c2) IN
        C.isIncl (cstr c1) (cstr c2)
        &&& chkSyn l1 l2'
        CATCH false
    end.

  Lemma chkSyn_correct l1 l2: If chkSyn l1 l2 THEN Incl l1 l2.
  Proof.
    unfold Incl; induction l2; simpl; auto.
    PedraQsimplify.
  Qed.
  Local Hint Resolve chkSyn_correct: pedraQ.

  (** Inclusion check. First derives constraints from [l1] with [cert],
      then compares them syntactically against [l2]. *)
  Definition chk (l1 l2:t) (cert: list (Cert.id_t * Cert.cstrT)) : bool :=
    chkSyn (build l1 cert) l2.

  Lemma chk_correct l1 l2 cert: If chk l1 l2 cert THEN Incl l1 l2.
  Proof.
    unfold chk. PedraQsimplify.
  Qed.

  (** [checkEmpty l f]: [true] when [C.isContrad] recognizes the combination
      [combine l f] as a contradiction. *)
  Definition checkEmpty (l: t) (f: Cert.frag_t): bool :=
    C.isContrad (combine l f).

  Lemma checkEmpty_correct l f: If checkEmpty l f THEN forall m,~(sat l m).
  Proof.
    unfold checkEmpty. PedraQsimplify.
  Qed.
  Local Hint Resolve checkEmpty_correct: pedraQ.
  Global Opaque checkEmpty.

  (** Extends [l] with the constraint built from [cert], labeled [i]. *)
  Definition bind1 (l: t) (i: Cert.id_t) (cert: Cert.cstrT): t :=
    add l i (build1 l cert).
  Extraction Inline bind1.

  Lemma bind1_correct l i cert: Incl l (bind1 l i cert).
  Proof.
    unfold bind1; eauto with pedraQ.
  Qed.
  Local Hint Resolve bind1_correct: pedraQ.

  (** Processes the [Cert.Bind] steps of a certificate in order, each one
      adding a derived constraint to the set. Returns the extended set
      together with the remaining [Cert.Final] certificate. *)
  Fixpoint bind (l: t) (cert: Cert.t): t * Cert.finalT :=
    match cert with
      | Cert.Final cert1 => (l, cert1)
      | Cert.Bind i cert1 cert2 => bind (bind1 l i cert1) cert2
    end.

  Lemma bind_correct cert: forall l, Incl l (fst (bind l cert)).
  Proof.
    induction cert; simpl; intros; eauto with pedraQ.
  Qed.

  (** Checks [Incl l1 l2]. After binding, the final certificate either:
      - proves every constraint of [l2] ([Cert.Implies]), or
      - proves the bound set empty ([Cert.Empty]), in which case inclusion
        holds vacuously. *)
  Definition chkIncl (l1 l2: t) (cert: Cert.t): bool :=
    let (nl1,cert1) := bind l1 cert in
      match cert1 with
        | Cert.Implies cert2 => chk nl1 l2 cert2
        | Cert.Empty cert2 => checkEmpty nl1 cert2
      end.

  Lemma chkIncl_correct l1 l2 cert: If chkIncl l1 l2 cert THEN Incl l1 l2.
  Proof.
    unfold chkIncl. generalize (bind_correct cert l1).
    destruct (bind l1 cert); simpl. destruct f; simpl; PedraQsimplify.
    unfold not, Incl in * |- *. intros X m X0; case (X m); auto.
  Qed.
  Hint Resolve chkIncl_correct: pedraQ.
  Global Opaque chkIncl.

  (** Checks that [l] has no model. Only a final [Cert.Empty] certificate
      is accepted. *)
  Definition chkEmpty (l: t) (cert: Cert.t): bool :=
    let (nl1,cert1) := (bind l cert) in
    match cert1 with
      | Cert.Empty cert2 => checkEmpty nl1 cert2
      | _ => false
    end.

  Lemma chkEmpty_correct l cert: If chkEmpty l cert THEN forall m,~(sat l m).
  Proof.
    unfold chkEmpty. generalize (bind_correct cert l).
    destruct (bind l cert); simpl;
    destruct f; simpl; PedraQsimplify; auto with pedraQ.
  Qed.
  Hint Resolve chkEmpty_correct: pedraQ.
  Global Opaque chkEmpty.

  (** [approx l cert]: a constraint set implied by [l], rebuilt from the
      [Cert.Implies] part of [cert]. Any other final certificate is reported
      through [failwith], with fallback [top]. *)
  Definition approx (l: t) (cert: Cert.t): t :=
    let (nl1,cert1) := bind l cert in
      match cert1 with
        | Cert.Implies cert2 => build nl1 cert2
        | _ => failwith "ConsSet.approx" top
      end.
  Extraction Inline approx.

  Lemma approx_correct l cert m: sat l m -> sat (approx l cert) m.
  Proof.
    generalize (bind_correct cert l).
    unfold approx.
    destruct (bind l cert); simpl.
    destruct f; simpl; eauto with pedraQ.
  Qed.
  Hint Resolve approx_correct: pedraQ.

  (** Adds [c] (labeled [i]) to [l], then rebuilds the result with [approx]
      according to [cert]. *)
  Definition redAdd (l: t) (i: Cert.id_t) (c: C.t) (cert: Cert.t): t :=
    approx (add l i c) cert.

  Lemma redAdd_correct l i c cert m: sat l m -> C.sat c m -> sat (redAdd l i c cert) m.
  Proof.
    unfold redAdd; simpl; auto with pedraQ.
  Qed.
  Hint Resolve redAdd_correct: pedraQ.

(** Precise addition, which checks that the result is as small as it can be.
    The forward certificate [fwd] builds the result with [redAdd]. The
    backward certificate [bwd] must prove that the result is included in
    [mk i c :: l].
    XXX: returning an [option t] defeats the purpose of having [failwith].
    The problem is that, when [addp] returns through [failwith], nothing can
    be proved about precision, so the error case must be distinguished. *)
  Definition addp: t -> Cert.id_t -> C.t -> Cert.t -> Cert.t -> option t
    := fun l i c fwd bwd =>
         let l' := redAdd l i c fwd in
         if chkIncl l' (mk i c :: l) bwd
         then Some l'
         else failwith "ConsSet.CsImpl.addp" None.

  Lemma addp_correct: forall l l' i c fwd bwd m, addp l i c fwd bwd = Some l' ->
                                                 sat l m -> C.sat c m -> sat l' m.
  Proof.
    intros l l' i c fwd bwd m hadd hl hc.
    unfold addp in hadd.
    destruct (chkIncl (redAdd l i c fwd) (mk i c :: l) bwd).
    - inversion hadd as [hadd'].
      apply redAdd_correct; assumption.
    - discriminate hadd.
  Qed.

  Lemma addp_precise: forall l l' i c fwd bwd m, addp l i c fwd bwd = Some l' ->
                                                  sat l' m -> sat (mk i c :: l) m.
  Proof.
    intros l l' i c fwd bwd m hadd hl'.
    unfold addp in hadd.
    assert (hincl := chkIncl_correct (redAdd l i c fwd) (mk i c :: l) bwd).
    destruct (chkIncl (redAdd l i c fwd) (mk i c :: l) bwd).
    - inversion hadd as [hadd'].
      apply hincl.
      rewrite hadd'.
      assumption.
    - discriminate hadd.
  Qed.

  Lemma addp_precise_c: forall l l' i c fwd bwd m, addp l i c fwd bwd = Some l' ->
                                                  sat l' m -> C.sat c m.
  Proof.
    intros l l' i c fwd bwd m hadd hl'.
    assert (pf := addp_precise l l' i c fwd bwd m hadd hl').
    replace c with (cstr (mk i c)) by reflexivity.
    apply InclCstrIn with (l := mk i c :: l).
    - compute.
      left.
      reflexivity.
    - assumption.
  Qed.

  Lemma addp_precise_l: forall l l' i c fwd bwd m, addp l i c fwd bwd = Some l' ->
                                                  sat l' m -> sat l m.
  Proof.
    intros l l' i c fwd bwd m hadd hl'.
    assert (pf := addp_precise l l' i c fwd bwd m hadd hl').
    inversion pf.
    assumption.
  Qed.
  Global Opaque addp redAdd.

  (** Relabels one constraint with [Cert.shift]. The constraint itself is
      unchanged. *)
  Definition shift1: Cert.id_t -> cT -> cT
    := fun sh c => mk (Cert.shift (id c) sh) (cstr c).

  (** Relabels every constraint of a set. Presumably this keeps identifiers
      distinct when two sets are concatenated, as in [meet]. *)
  Definition shift: Cert.id_t -> t -> t
    := fun sh l => List.map (shift1 sh) l.

  Lemma Shiftsat l m sh: sat l m -> sat (shift sh l) m.
  Proof.
    unfold sat, shift, shift1.
    induction l; simpl; intuition.
  Qed.

  (** Conjunction of [l1] and [l2]. Concatenates [l1] with the relabeled
      [l2], then rebuilds the result with [approx]. *)
  Definition meet (l1 l2:t) (cert:Cert.meetT): t :=
    let l2' := shift (Cert.sh cert) l2 in
      approx (l1 ++ l2') (Cert.meetCert cert).

  Local Hint Resolve satApp Shiftsat: pedraQ.

  Lemma meet_correct l1 l2 cert m: sat l1 m -> sat l2 m -> sat (meet l1 l2 cert) m.
  Proof.
    unfold meet.
    intros; eapply approx_correct; auto with pedraQ.
  Qed.

  (** Checks that [l1] and [l2] have no common model. *)
  Definition meetEmpty (l1 l2:t) (cert:Cert.meetT): bool
    := let l2' := shift (Cert.sh cert) l2 in
      chkEmpty (l1 ++ l2') (Cert.meetCert cert).

  Lemma meetEmpty_correct l1 l2 cert: If meetEmpty l1 l2 cert THEN forall m,(sat l1 m -> sat l2 m -> False).
  Proof.
    unfold meetEmpty; PedraQsimplify.
    intuition eauto with pedraQ.
  Qed.
  Hint Resolve meet_correct meetEmpty_correct: pedraQ.
  Global Opaque meetEmpty.

  (** Returns whichever of [c1] and [c2] includes the other, as decided by
      [C.isIncl].
      - [c2] is tried first.
      - [None] is returned when neither test succeeds. *)
  Definition choose (c1 c2: C.t): option C.t :=
    if C.isIncl c1 c2 then
      Some c2
    else if C.isIncl c2 c1 then
      Some c1
    else None.

  Lemma choose_correct c1 c2: WHEN c <- choose c1 c2 THEN C.Incl c1 c /\ C.Incl c2 c.
  Proof.
    unfold choose.
    OptionMonad.xsimplify ltac:(intuition eauto with pedraQ).
  Qed.
  Local Hint Resolve Forall_monotone choose_correct: pedraQ.

  (** Join candidates for [c]: [c] itself, followed by the two constraints
      returned by [C.split] when it succeeds. *)
  Definition jSplit (c: C.t) : list C.t :=
    TRY p <- C.split c IN let (c1,c2):=p in c::c1::c2::nil
    CATCH (c::nil).

  Lemma jSplit_correct c: Forall (fun d => C.Incl c d) (jSplit c).
  Proof.
    unfold jSplit.
    OptionMonad.xsimplify ltac:(intuition eauto with pedraQ).
  Qed.
  Local Hint Resolve jSplit_correct: pedraQ.

  (** The first successful [choose c c'] for [c'] ranging over [l] in
      order. *)
  Fixpoint jTry1 (c: C.t) (l:list C.t): option C.t
    := match l with
         | nil => None
         | c'::l' =>
           TRY c0 <- choose c c' IN (Some c0)
           CATCH jTry1 c l'
       end.

  (** [CListIncl l c]: every memory satisfying all constraints of [l]
      satisfies [c]. *)
  Definition CListIncl (l: list C.t) (c: C.t): Prop
    := forall m, Forall (fun c => C.sat c m) l -> C.sat c m.

  Lemma jTry1_correct c l: WHEN c' <- jTry1 c l THEN (C.Incl c c' /\ CListIncl l c').
  Proof.
    unfold CListIncl; induction l; simpl; auto.
    PedraQsimplify.
    - simpl; intuition.
    - simpl; intuition.
    (* XXX: strange behavior of "simpl" here w.r.t. ";"
       we can not replace the above code by
         "PedraQsimplify; simpl; intuition"
    *)
  Qed.
  Local Hint Resolve jTry1_correct: pedraQ.

  (** The first successful [jTry1 c' l2] for [c'] ranging over [l1] in
      order. *)
  Fixpoint jTryIneqs (l1 l2:list C.t) {struct l1}: option C.t :=
    match l1 with
      | nil => None
      | c'::l' =>
        match jTry1 c' l2 with
          | None => jTryIneqs l' l2
          | c0 => c0
        end
    end.

  Lemma jTryIneqs_correct l1 l2:
    WHEN c <- jTryIneqs l1 l2 THEN (CListIncl l1 c) /\ (CListIncl l2 c).
  Proof.
    induction l1; simpl; auto.
    PedraQsimplify; unfold CListIncl; simpl; intuition.
  Qed.
  Local Hint Resolve jTryIneqs_correct: pedraQ.

  (** Join of two constraints, searched among their [jSplit] candidates.
      When no candidate is found, the failure is reported through
      [failwith], with fallback [C.top]. *)
  Definition tryIneq (c1 c2: C.t): C.t :=
    TRY c <- jTryIneqs (jSplit c1) (jSplit c2) IN c
    CATCH failwith "ConsSet.tryIneq" C.top.

  Lemma tryIneq_correct1 c1 c2: C.Incl c1 (tryIneq c1 c2).
  Proof.
    unfold tryIneq; PedraQsimplify.
    unfold CListIncl; simpl; intuition eauto with pedraQ.
  Qed.
  Lemma tryIneq_correct2 c1 c2: C.Incl c2 (tryIneq c1 c2).
  Proof.
    unfold tryIneq; PedraQsimplify.
    unfold CListIncl; simpl; intuition eauto with pedraQ.
  Qed.

  Lemma CInclFold c1 c2 m: C.Incl c1 c2 -> C.sat c1 m -> C.sat c2 m.
  Proof.
    unfold C.Incl; auto.
  Qed.
  Local Hint Resolve CInclFold tryIneq_correct1 tryIneq_correct2: pedraQ.


  (** Builds one constraint of a join. A constraint is derived from each
      side, then the one including the other is kept if any; otherwise the
      result falls back on [tryIneq]. *)
  Definition buildJoin1 (l1 l2: t) (cert: Cert.joinEltT): C.t :=
    let c1 := build1 l1 (Cert.arg1 cert) in
      let c2 := build1 l2 (Cert.arg2 cert) in
        TRY c <- choose c1 c2 IN c
        CATCH tryIneq c1 c2.

  Lemma buildJoin1_correct1 l1 l2 cert m: sat l1 m -> C.sat (buildJoin1 l1 l2 cert) m.
  Proof.
    unfold buildJoin1. PedraQsimplify.
  Qed.

  Lemma buildJoin1_correct2 l1 l2 cert m: sat l2 m -> C.sat (buildJoin1 l1 l2 cert) m.
  Proof.
    unfold buildJoin1. PedraQsimplify.
  Qed.
  Local Hint Resolve buildJoin1_correct1 buildJoin1_correct2: pedraQ.

  (** Builds a join of [l1] and [l2]: one labeled constraint per element of
      the certificate, each satisfied by the models of either side. *)
  Fixpoint buildJoin (l1 l2: t) (cert: Cert.joinT): t :=
    match cert with
      | nil => nil
      | h::cert1 =>
        (mk (Cert.id h) (buildJoin1 l1 l2 h))::(buildJoin l1 l2 cert1)
    end.

  Lemma buildJoin_correct1 l1 l2 cert m: sat l1 m -> sat (buildJoin l1 l2 cert) m.
  Proof.
    induction cert; simpl; intuition eauto with pedraQ.
  Qed.

  Lemma buildJoin_correct2 l1 l2 cert m: sat l2 m -> sat (buildJoin l1 l2 cert) m.
  Proof.
    induction cert; simpl; intuition eauto with pedraQ.
  Qed.
  Hint Resolve buildJoin_correct1 buildJoin_correct2: pedraQ.

  (** [isFree x cs]: [C.isFree x] holds for every constraint of [cs]. *)
  Fixpoint isFree (x: PVar.t) (cs: t): bool :=
    match cs with
      | nil => true
      | c::cs' => C.isFree x (cstr c) &&& isFree x cs'
    end.

  Lemma isFree_correct x cs: If (isFree x cs) THEN ~(mayDependOn cs x).
  Proof.
    induction cs; simpl; PedraQsimplify. auto.
  Qed.
  Hint Resolve isFree_correct: pedraQ.

  (** Applies [C.rename x y] to every constraint. Identifiers are kept. *)
  Fixpoint rename (x y: PVar.t) (cs: t): t :=
    match cs with
      | nil => nil
      | c::cs' => {| id:=id c ; cstr := C.rename x y (cstr c)|}::(rename x y cs')
    end.

  Lemma rename_correct (x y:PVar.t) (cs:t) m:
    (sat (rename x y cs) m)=(sat cs (Mem.assign x (m y) m)).
  Proof.
    induction cs; simpl; auto.
    rewrite C.rename_correct. congruence.
  Qed.

End CsImpl.

(** Concrete constraint sets: [CsImpl] instantiated with the constraints of
    [Cstr]. [Include] copies the implementation into [Cs]; the [<:] ascription
    checks it against [CsSig Cstr] without sealing it, so definitions stay
    transparent. *)
Module Cs <: CsSig Cstr.
  Include CsImpl Cstr.
End Cs.