(** * Module MemChunkTree: a [SHARETREE] keyed by typed memory chunks *)


Require Import
  Utf8 Coqlib
  AST Maps ShareTree Util
  Hash FastArith.

Set Implicit Arguments.

(** Memory chunks other than [Many32] and [Many64]. The proof component
    reduces to [True] for the eight admissible chunks and to [False] for
    the two excluded ones, so no inhabitant can carry a [Many*] chunk. *)
Definition typed_memory_chunk :=
  { c : memory_chunk |
    match c with
    | Many32 => False
    | Many64 => False
    | _ => True
    end }.

(** Boolean equality on typed memory chunks: compares the underlying
    [memory_chunk] of both arguments and ignores the proof components. *)
Definition chunk_beq (x y : typed_memory_chunk) : bool :=
  match proj1_sig x, proj1_sig y with
  | Mint8signed, Mint8signed => true
  | Mint8unsigned, Mint8unsigned => true
  | Mint16signed, Mint16signed => true
  | Mint16unsigned, Mint16unsigned => true
  | Mint32, Mint32 => true
  | Mint64, Mint64 => true
  | Mfloat32, Mfloat32 => true
  | Mfloat64, Mfloat64 => true
  | _, _ => false
  end.

(** [chunk_beq] reflects Leibniz equality on typed memory chunks.
    The proof enumerates every pair of chunks. Destructing each proof
    component with [[]] discharges the [Many32]/[Many64] cases, whose proof
    has type [False], and turns the others into [I]. *)
Lemma chunk_beq_correct :
  ∀ x y, chunk_beq x y = true ↔ x = y.
Proof.
  intros [[] []] [[] []]; simpl; split; try discriminate; reflexivity.
Qed.

Module MemChunkTree : SHARETREE with Definition elt := typed_memory_chunk.

  (** Keys: memory chunks other than [Many32] and [Many64]. *)
  Definition elt := typed_memory_chunk.
  (** Decidable key equality, obtained from [chunk_beq] and
      [chunk_beq_correct]. *)
  Definition elt_eq: forall (a b: elt), {a = b} + {a <> b} :=
    eq_dec_of_beq chunk_beq (chunk_beq_correct).

  (** A tree is a flat record with one optional slot per admissible chunk.
      Each field is named after the chunk it stores (e.g. [i8s] for
      [Mint8signed], [f64] for [Mfloat64]); see [get]. *)
  Record _t (A:Type) : Type := T
  { i8s: option A
  ; i8u: option A
  ; i16s: option A
  ; i16u: option A
  ; i32: option A
  ; i64: option A
  ; f32: option A
  ; f64: option A
  }.

  Definition t: Type -> Type := _t.

  (** Decidable equality on trees, given decidable equality on values. *)
  Lemma eq: forall (A: Type), (forall (x y: A), {x=y} + {x<>y}) ->
                forall (a b: t A), {a = b} + {a <> b}.
  Proof.
    repeat decide equality.
  Defined.

  (** The tree with all eight slots empty. *)
  Definition empty A : t A := T None None None None None None None None.

  (** [get k m] returns the slot of [m] selected by chunk [k]. The last
      branch is unreachable: for [Many32]/[Many64] the proof [H] has type
      [False] and is eliminated by an empty match. *)
  Definition get A (k:typed_memory_chunk) (m: t A) : option A :=
    match k with
      | exist Mint8signed _ => i8s
      | exist Mint8unsigned _ => i8u
      | exist Mint16signed _ => i16s
      | exist Mint16unsigned _ => i16u
      | exist Mint32 _ => i32
      | exist Mint64 _ => i64
      | exist Mfloat32 _ => f32
      | exist Mfloat64 _ => f64
      | exist _ H => match H with end
    end A m.

  (** [upd k v m] replaces the slot of [m] selected by chunk [k] with [v]
      and leaves the other seven slots unchanged. *)
  Definition upd A (k: elt) (v: option A) (m: t A) : t A.
  Proof.
    destruct m as [a b c d e f g h].
    exact
    (match k with
      | exist Mint8signed _ => T v b c d e f g h
      | exist Mint8unsigned _ => T a v c d e f g h
      | exist Mint16signed _ => T a b v d e f g h
      | exist Mint16unsigned _ => T a b c v e f g h
      | exist Mint32 _ => T a b c d v f g h
      | exist Mint64 _ => T a b c d e v g h
      | exist Mfloat32 _ => T a b c d e f v h
      | exist Mfloat64 _ => T a b c d e f g v
      | exist _ H => match H with end
     end).
  Defined.

  (** [set] stores [Some v] in the slot of [k]; [remove] clears that slot. *)
  Definition set A k v (m: t A) : t A := upd k (Some v) m.
  Definition remove A k m : t A := upd k None m.

  (** Specification of [get] against [empty], [set] and [remove]. Each proof
      is a case analysis on the record and on the chunk(s) involved. *)
  Lemma gempty: forall (A: Type) (i: elt), get i (empty A) = None.
  Proof.
    now destruct i as [[] []].
  Qed.
  Lemma gss: forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
  Proof.
    now destruct m; destruct i as [[] []].
  Qed.
  Lemma gso: forall (A: Type) (i j: elt) (x: A) (m: t A),
    i <> j -> get i (set j x m) = get i m.
  Proof.
    destruct m. destruct i as [[] []]; destruct j as [[] []]; auto; congruence.
  Qed.
  Lemma gsspec: forall (A: Type) (i j: elt) (x: A) (m: t A),
    get i (set j x m) = if elt_eq i j then Some x else get i m.
  Proof.
    destruct m. destruct i as [[] []]; destruct j as [[] []]; auto.
  Qed.
  Lemma gsident: forall (A: Type) (i: elt) (m: t A) (v: A),
    get i m = Some v -> set i v m = m.
  Proof.
    destruct m. destruct i as [[] []]; simpl; congruence.
  Qed.
  Lemma grs: forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
  Proof.
    destruct m. destruct i as [[] []]; auto.
  Qed.
  Lemma gro: forall (A: Type) (i j: elt) (m: t A),
    i <> j -> get i (remove j m) = get i m.
  Proof.
    destruct m. destruct i as [[] []]; destruct j as [[] []]; auto; congruence.
  Qed.
  Lemma grspec: forall (A: Type) (i j: elt) (m: t A),
    get i (remove j m) = if elt_eq i j then None else get i m.
  Proof.
    destruct m. destruct i as [[] []]; destruct j as [[] []]; auto.
  Qed.

  (** Extensional equality between trees. *)

  (** [opt_beq cmp] lifts [cmp] to options: two [None]s are equal, two
      [Some]s are compared with [cmp], and mixed pairs are unequal. *)
  Definition opt_beq A (cmp: A -> A -> bool) : option A -> option A -> bool :=
    fun x y =>
      match x, y with
        | Some a, Some b => cmp a b
        | None, None => true
        | _, _ => false
      end.

  (** [beq cmp m1 m2] compares the two trees slot by slot with [opt_beq]. *)
  Definition beq A (cmp: A -> A -> bool) (m1 m2: t A) : bool.
  Proof.
    destruct m1 as [a1 b1 c1 d1 e1 f1 g1 h1].
    destruct m2 as [a2 b2 c2 d2 e2 f2 g2 h2].
    exact (opt_beq cmp a1 a2 &&
           opt_beq cmp b1 b2 &&
           opt_beq cmp c1 c2 &&
           opt_beq cmp d1 d2 &&
           opt_beq cmp e1 e2 &&
           opt_beq cmp f1 f2 &&
           opt_beq cmp g1 g2 &&
           opt_beq cmp h1 h2).
  Defined.

  Lemma beq_correct:
    forall (A: Type) (eqA: A -> A -> bool) (t1 t2: t A),
    beq eqA t1 t2 = true <->
    (forall (x: elt),
     match get x t1, get x t2 with
     | None, None => True
     | Some y1, Some y2 => eqA y1 y2 = true
     | _, _ => False
    end).
  Proof.
    (* Name the slots explicitly. Auto-generated names would have to dodge
       the global projections [i32], [i64], [f32], [f64] (yielding [i33],
       [f65], ...), which ties the proof to Coq's fresh-name heuristics. *)
    destruct t1 as [a1 b1 c1 d1 e1 f1 g1 h1].
    destruct t2 as [a2 b2 c2 d2 e2 f2 g2 h2].
    simpl.
    repeat rewrite andb_true_iff in *. split.
    destruct x as [[] []]; simpl;
    repeat match goal with
      | |- context[match ?a with None => _ | Some _ => _ end] =>
        destruct a
    end; simpl in *; intuition.
    repeat split.
    specialize (H (exist _ Mint8signed I)); destruct a1, a2; auto.
    specialize (H (exist _ Mint8unsigned I)); destruct b1, b2; auto.
    specialize (H (exist _ Mint16signed I)); destruct c1, c2; auto.
    specialize (H (exist _ Mint16unsigned I)); destruct d1, d2; auto.
    specialize (H (exist _ Mint32 I)); destruct e1, e2; auto.
    specialize (H (exist _ Mint64 I)); destruct f1, f2; auto.
    specialize (H (exist _ Mfloat32 I)); destruct g1, g2; auto.
    specialize (H (exist _ Mfloat64 I)); destruct h1, h2; auto.
  Qed.

  (** [get_set i m] returns the current content of slot [i] together with
      a function that writes a new value into that slot. *)
  Definition get_set A i m : option A * (A -> t A) :=
    (get i m, fun v => set i v m).

  Lemma get_set_spec:
    forall (A: Type) (i:elt) (m:t A),
      fst (get_set i m) = get i m /\
      forall v, snd (get_set i m) v = set i v m.
  Proof.
    auto.
  Qed.

  (** Applying a function to all data of a tree. The function also receives
      the key of each slot. *)
  Definition map A B (fn: elt -> A -> B) (m: t A) : t B.
  Proof.
    destruct m as [a b c d e f g h].
    exact
      (T (option_map (fn (exist _ Mint8signed I)) a)
         (option_map (fn (exist _ Mint8unsigned I)) b)
         (option_map (fn (exist _ Mint16signed I)) c)
         (option_map (fn (exist _ Mint16unsigned I)) d)
         (option_map (fn (exist _ Mint32 I)) e)
         (option_map (fn (exist _ Mint64 I)) f)
         (option_map (fn (exist _ Mfloat32 I)) g)
         (option_map (fn (exist _ Mfloat64 I)) h)).
  Defined.
  Lemma gmap: forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
    get i (map f m) = option_map (f i) (get i m).
  Proof.
    destruct m. now destruct i as [[] []].
  Qed.

  (** Same as [map], but the function does not receive the key argument. *)
  Definition map1 A B (fn: A -> B) (m: t A) : t B.
  Proof.
    destruct m as [a b c d e f g h].
    exact
      (T (option_map fn a)
         (option_map fn b)
         (option_map fn c)
         (option_map fn d)
         (option_map fn e)
         (option_map fn f)
         (option_map fn g)
         (option_map fn h)).
  Defined.
  Lemma gmap1: forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
    get i (map1 f m) = option_map f (get i m).
  Proof.
    destruct m. now destruct i as [[] []].
  Qed.

  (** Applying a function pairwise to all data of two trees. *)
  Definition combine A B C (fn: option A -> option B -> option C) (m1: t A) (m2: t B) : t C.
  Proof.
    destruct m1 as [a1 b1 c1 d1 e1 f1 g1 h1].
    destruct m2 as [a2 b2 c2 d2 e2 f2 g2 h2].
    exact
      (T (fn a1 a2) (fn b1 b2) (fn c1 c2) (fn d1 d2) (fn e1 e2) (fn f1 f2) (fn g1 g2) (fn h1 h2)).
  Defined.
  Lemma gcombine:
    forall (A B C: Type) (f: option A -> option B -> option C),
    f None None = None ->
    forall (m1: t A) (m2: t B) (i: elt),
    get i (combine f m1 m2) = f (get i m1) (get i m2).
  Proof.
    now destruct m1; destruct m2; destruct i as [[] []].
  Qed.
  Lemma combine_commut:
    forall (A B: Type) (f g: option A -> option A -> option B),
    (forall (i j: option A), f i j = g j i) ->
    forall (m1 m2: t A),
    combine f m1 m2 = combine g m2 m1.
  Proof.
    destruct m1; destruct m2. simpl. congruence.
  Qed.

  (** Enumerating the bindings of a tree. *)

  (** [opt_cons e x tl] prepends the binding [(e, y)] to [tl] when [x] is
      [Some y], and returns [tl] unchanged when [x] is [None]. *)
  Definition opt_cons A (e: elt) (x: option A) (tl: list (elt * A)) : list (elt * A) :=
    match x with
      | None => tl
      | Some y => (e, y) :: tl
    end.

  Lemma in_opt_cons A (x: elt * A) e y tl :
    In x (opt_cons e y tl) ↔ fst x = e ∧ Some (snd x) = y ∨ In x tl.
  Proof.
    split.
    - destruct y. now intros [ <- | H ]; [ left | right ]. apply or_intror.
    - intros [ ( <- & <- ) | H ].
      + left. symmetry. apply surjective_pairing.
      + refine match y return In x (opt_cons e y tl) with Some _ => or_intror H | None => H end.
  Qed.

  (** [elements m] lists the bindings of [m] in record-field order, from
      [Mint8signed] to [Mfloat64]. Empty slots are skipped. *)
  Definition elements A (m: t A) : list (elt * A).
  Proof.
    (* The record has exactly eight fields, so bind exactly eight names. *)
    destruct m as [a b c d e f g h].
    exact (opt_cons (exist _ Mint8signed I) a
          (opt_cons (exist _ Mint8unsigned I) b
          (opt_cons (exist _ Mint16signed I) c
          (opt_cons (exist _ Mint16unsigned I) d
          (opt_cons (exist _ Mint32 I) e
          (opt_cons (exist _ Mint64 I) f
          (opt_cons (exist _ Mfloat32 I) g
          (opt_cons (exist _ Mfloat64 I) h nil)))))))).
  Defined.

  Lemma elements_correct:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    get i m = Some v -> In (i, v) (elements m).
  Proof.
    destruct m. destruct i as [[] []]; simpl; intros; subst; simpl; auto;
                      repeat (apply in_opt_cons; first [now left|right; intuition]).
  Qed.
  Lemma elements_complete:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    In (i, v) (elements m) -> get i m = Some v.
  Proof.
    destruct m. simpl. intros.
    rewrite !in_opt_cons in H.
    simpl in H.
    intuition subst; reflexivity.
  Qed.
  Lemma elements_keys_norepet:
    forall (A: Type) (m: t A),
    list_norepet (List.map (@fst elt A) (elements m)).
  Proof.
    assert (∀ {A} (e: elt) (x: option A) tl, ¬ In e (List.map fst tl) → list_norepet (List.map fst tl) → list_norepet (List.map fst (opt_cons e x tl))) as list_norepet_opt_cons.
    { intros A e [x|] tl. apply list_norepet_cons. intros _. exact id. }
    intros A m.
    destruct m.
    repeat
      (apply list_norepet_opt_cons; [ rewrite in_map_iff; intros ( ( ? & ? ) & ? & H); try rewrite !in_opt_cons in H; simpl in *; intuition subst; try easy | ]).
    constructor.
  Qed.

  (** Folding a function over all bindings of a tree, in the order produced
      by [elements]. *)
  Definition fold A B (f: B -> elt -> A -> B) (m: t A) (v: B) : B :=
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
  Lemma fold_spec:
    forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
    fold f m v =
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
  Proof.
    auto.
  Qed.

  (** Same as [fold], but the function does not receive the key argument. *)
  Definition fold1 A B (f: B -> A -> B) (m: t A) (v: B) : B :=
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
  Lemma fold1_spec:
    forall (A B: Type) (f: B -> A -> B) (v: B) (m: t A),
    fold1 f m v =
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
  Proof.
    auto.
  Qed.

  (** [share_option a a'] is provably equal to [a] ([share_option_eq]).
      When both arguments are [Some] and the [ifeq] test succeeds, it returns
      [a'] instead. Presumably [ifeq] is a physical-equality test used to
      preserve sharing in extracted code; confirm against [physEq]. *)
  Program Definition share_option {A:Type} (a a':option A) : option A :=
    match a, a' with
      | None, None | Some _, None | None, Some _ => a
      | Some v, Some v' => ifeq v == v' then a' else a
    end.

  Lemma share_option_eq:
    ∀ A (a a':option A), share_option a a' = a.
  Proof.
    unfold share_option. destruct a, a'; auto.
  Qed.

  (** [shcombine it Hit x y] combines [x] and [y] slot-wise with [it]. The
      hypothesis [Hit] states that [it] is the identity on equal arguments.
      If [ifeq x == y] succeeds, [x] is returned directly. Otherwise the
      result is compared slot-by-slot first with [y], then with [x]. When it
      matches one of them, that input is returned instead of a freshly
      built record. *)
  Program Definition shcombine {A:Type}
             (it:elt -> option A -> option A -> option A) (Hit:∀ x v, it x v v = v)
             (x y:t A) : t A :=
    ifeq x == y then x
    else match x, y with
      | T a1 b1 c1 d1 e1 f1 g1 h1,
        T a2 b2 c2 d2 e2 f2 g2 h2 =>
        let a := share_option (share_option (it Mint8signed a1 a2) a1) a2 in
        let b := share_option (share_option (it Mint8unsigned b1 b2) b1) b2 in
        let c := share_option (share_option (it Mint16signed c1 c2) c1) c2 in
        let d := share_option (share_option (it Mint16unsigned d1 d2) d1) d2 in
        let e := share_option (share_option (it Mint32 e1 e2) e1) e2 in
        let f := share_option (share_option (it Mint64 f1 f2) f1) f2 in
        let g := share_option (share_option (it Mfloat32 g1 g2) g1) g2 in
        let h := share_option (share_option (it Mfloat64 h1 h2) h1) h2 in
        let ret _ := T a b c d e f g h in
        let ret _ :=
        ifeq a == a1 then
          ifeq b == b1 then
            ifeq c == c1 then
              ifeq d == d1 then
                ifeq e == e1 then
                  ifeq f == f1 then
                    ifeq g == g1 then
                      ifeq h == h1 then x else ret tt
                    else ret tt
                  else ret tt
                else ret tt
              else ret tt
            else ret tt
          else ret tt
        else ret tt
        in
        ifeq a == a2 then
          ifeq b == b2 then
            ifeq c == c2 then
              ifeq d == d2 then
                ifeq e == e2 then
                  ifeq f == f2 then
                    ifeq g == g2 then
                      ifeq h == h2 then y else ret tt
                    else ret tt
                  else ret tt
                else ret tt
              else ret tt
            else ret tt
          else ret tt
        else ret tt
      end.
  Next Obligation.
    rewrite !share_option_eq in *. congruence.
  Qed.
  Next Obligation.
    unfold physEq. rewrite !share_option_eq in *. congruence.
  Qed.
  Next Obligation.
    destruct y. unfold physEq. rewrite !share_option_eq. f_equal; auto.
  Qed.

  Lemma gshcombine:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m1: t A) (m2: t A) (i: elt),
    get i (shcombine f Hf m1 m2) = f i (get i m1) (get i m2).
  Proof.
    intros. unfold shcombine, physEq. destruct m1, m2. rewrite !share_option_eq.
    destruct i as [[] []]; simpl; auto.
  Qed.

  Lemma shcombine_eq:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m: t A),
      shcombine f Hf m m = m.
  Proof.
    destruct m. unfold shcombine, physEq. rewrite !share_option_eq. f_equal; auto.
  Qed.

  (** [shforall2 it Hit x y] checks [it] on every pair of corresponding
      slots. It answers [true] immediately when [ifeq x == y] succeeds,
      which is sound because [Hit] makes [it] true on equal arguments. *)
  Program Definition shforall2 {A:Type}
             (it:elt -> option A -> option A -> bool) (Hit:∀ x v, it x v v = true)
             (x y:t A) : bool :=
    ifeq x == y then true
    else
      match x, y return _ with
        | T a1 b1 c1 d1 e1 f1 g1 h1,
          T a2 b2 c2 d2 e2 f2 g2 h2 =>
          it Mint8signed a1 a2 &&
          it Mint8unsigned b1 b2 &&
          it Mint16signed c1 c2 &&
          it Mint16unsigned d1 d2 &&
          it Mint32 e1 e2 &&
          it Mint64 f1 f2 &&
          it Mfloat32 g1 g2 &&
          it Mfloat64 h1 h2
      end.
  Next Obligation.
    destruct y. setoid_rewrite Hit. auto.
  Qed.

  Lemma shforall2_correct:
    forall (A: Type)
           (f: elt -> option A -> option A -> bool) (Hf:∀ x v, f x v v = true),
    forall (m1: t A) (m2: t A),
    shforall2 f Hf m1 m2 = true <->
    (forall x, f x (get x m1) (get x m2) = true).
  Proof.
    intros. unfold shforall2, physEq. destruct m1, m2.
    repeat setoid_rewrite Bool.andb_true_iff.
    split.
    - intros [[[[[[[? ?] ?] ?] ?] ?] ?] ?] [[] []]; auto.
    - repeat (split; try apply H).
  Qed.

  (** [shcombine_diff it Hit x y] applies [it] slot-wise to [x] and [y].
      It returns [empty] immediately when [ifeq x == y] succeeds, which is
      sound because [Hit] makes [it] return [None] on equal arguments. *)
  Program Definition shcombine_diff {A B:Type}
             (it:elt -> option A -> option A -> option B) (Hit:∀ x v, it x v v = None)
             (x y:t A) : t B :=
    ifeq x == y then empty _
    else match x, y with
      | T a1 b1 c1 d1 e1 f1 g1 h1,
        T a2 b2 c2 d2 e2 f2 g2 h2 =>
        let a := it Mint8signed a1 a2 in
        let b := it Mint8unsigned b1 b2 in
        let c := it Mint16signed c1 c2 in
        let d := it Mint16unsigned d1 d2 in
        let e := it Mint32 e1 e2 in
        let f := it Mint64 f1 f2 in
        let g := it Mfloat32 g1 g2 in
        let h := it Mfloat64 h1 h2 in
        T a b c d e f g h
    end.
  Next Obligation.
    unfold physEq, empty. destruct y. f_equal; auto.
  Qed.

  Lemma gshcombine_diff:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m1 m2: t A) (i: elt),
    get i (shcombine_diff f Hf m1 m2) = f i (get i m1) (get i m2).
  Proof.
    intros. unfold shcombine_diff, physEq. destruct m1, m2.
    destruct i as [[] []]; simpl; auto.
  Qed.

  Lemma shcombine_diff_eq:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m: t A),
      shcombine_diff f Hf m m = empty _.
  Proof.
    destruct m. unfold shcombine_diff, physEq, empty. f_equal; auto.
  Qed.

End MemChunkTree.

(** Hashing of typed memory chunks: each admissible chunk gets a distinct
    key in 0..7, which is mixed into the running hash [h]. The [Many32] and
    [Many64] branches are unreachable for a [typed_memory_chunk] (their proof
    component has type [False]). They only make the match exhaustive. *)
Instance hash_typed_memory_chunk : hashable typed_memory_chunk :=
  fun (h:Nfast) (c:typed_memory_chunk) =>
    MIX h
      (match proj1_sig c with
       | Mint8signed => 0
       | Mint8unsigned => 1
       | Mint16signed => 2
       | Mint16unsigned => 3
       | Mint32 => 4
       | Mint64 => 5
       | Mfloat32 => 6
       | Mfloat64 => 7
       | Many32 => 0
       | Many64 => 0
       end%N).