(** * Module TreeAl: generic tree constructions (re-keying along a bijection, sums and products of trees). *)


Require Import
  Coqlib Utf8
  Maps
  ShareTree.

(* From here on, arguments that Coq can infer from the types of other
   arguments are made implicit automatically in every declaration of this file. *)
Set Implicit Arguments.

(** A bijection between two types [s] and [t], together with a decidable
    equality on [s].  [t_of_s] and [s_of_t] are mutually inverse.
    [Parameter]/[Axiom] are the idiomatic module-type declarations;
    [Variable]/[Hypothesis] outside a section are deprecated aliases. *)
Module Type TYPE_EQ.
  (** The two types put in bijection. *)
  Parameter s: Type.
  Parameter t: Type.
  (** The two directions of the bijection. *)
  Parameter t_of_s : s -> t.
  Parameter s_of_t : t -> s.
  (** [s_of_t] undoes [t_of_s] ... *)
  Axiom s_of_t_of_s : forall x : s, s_of_t (t_of_s x) = x.
  (** ... and [t_of_s] undoes [s_of_t]. *)
  Axiom t_of_s_of_t: forall x : t, t_of_s (s_of_t x) = x.
  (** Decidable equality on [s]. *)
  Parameter eq: forall (x y: s), {x = y} + {x <> y}.
End TYPE_EQ.

(** Derived facts about a [TYPE_EQ] bijection: each direction is injective,
    because the other direction is a left inverse of it. *)
Module TYPE_EQ_PROP (X:TYPE_EQ).
  (** [t_of_s] is injective: apply [s_of_t] to both sides. *)
  Lemma injective (a b: X.s) :
    X.t_of_s a = X.t_of_s b -> a = b.
  Proof.
    intros Heq.
    now rewrite <- (X.s_of_t_of_s a), Heq, X.s_of_t_of_s.
  Qed.
  (** [s_of_t] is injective: apply [t_of_s] to both sides. *)
  Lemma tinjective (a b: X.t) :
    X.s_of_t a = X.s_of_t b -> a = b.
  Proof.
    intros Heq.
    now rewrite <- (X.t_of_s_of_t a), Heq, X.t_of_s_of_t.
  Qed.
End TYPE_EQ_PROP.

(** Bijection between [Z] and [positive], used to key [Z]-indexed trees by
    [positive]: [Zpos p] |-> [xO p] (2p), [Zneg p] |-> [xI p] (2p+1),
    and [Z0] |-> [xH] (1). *)
Module Z_EQ_POS <: TYPE_EQ
    with Definition s := Z
    with Definition t := positive.
  Definition s := Z.
  Definition t := positive.

  (** Encode an integer: the low bit of the image tags the sign. *)
  Definition t_of_s (z: s) : t :=
    match z with
      | Zpos p => xO p
      | Zneg p => xI p
      | Z0 => xH
    end.

  (** Decode: inverse of [t_of_s]. *)
  Definition s_of_t (p: t) : s :=
    match p with
      | xO x => Zpos x
      | xI x => Zneg x
      | xH => Z0
    end.

  Lemma s_of_t_of_s : forall x : s, s_of_t (t_of_s x) = x.
  Proof. destruct x; reflexivity. Qed.

  Lemma t_of_s_of_t: forall x : t, t_of_s (s_of_t x) = x.
  Proof. destruct x; reflexivity. Qed.

  (** Decidable equality on integers. *)
  Definition eq: forall (x y: s), {x = y} + {x <> y} :=
    Z_eq_dec.
End Z_EQ_POS.

(** Re-keying a tree along a bijection: given [X] relating [X.s] and [X.t]
    and a tree module [TTree] keyed by [X.t], build a tree module keyed by
    [X.s].  Every key is translated with [X.t_of_s] and the operation is
    delegated to [TTree]; the representation type is exactly [TTree.t]. *)
Module BijTree (X:TYPE_EQ) (TTree: TREE with Definition elt := X.t) <: TREE
    with Definition elt := X.s
    with Definition t := TTree.t
    with Definition get := fun A i g => TTree.get (X.t_of_s i) g : option A.

  Module P := TYPE_EQ_PROP(X).
  (* Injectivity hints, presumably used by the [auto] calls in [gso], [gsspec]
     and [gro] to relate [i <> j] with [X.t_of_s i <> X.t_of_s j].
     The [core] database is named explicitly: adding hints to it implicitly
     is deprecated, and naming it keeps exactly the same hint set. *)
  Hint Resolve P.injective P.tinjective : core.
  Definition elt: Type := X.s.
  Definition elt_eq: forall (a b: elt), {a = b} + {a <> b} := X.eq.
  Definition t: Type -> Type := TTree.t.
  Definition empty: forall (A: Type), t A := TTree.empty.
  (* Lookup, update and removal all translate the key with [X.t_of_s]. *)
  Definition get A i g : option A := TTree.get (X.t_of_s i) g.
  Definition set A i v g : t A := TTree.set (X.t_of_s i) v g.
  Definition remove A i g : t A := TTree.remove (X.t_of_s i) g.

  Lemma gempty: forall (A: Type) (i: elt), get i (empty A) = None.
Proof.
intros; apply TTree.gempty. Qed.
  Lemma gss: forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
Proof.
intros; apply TTree.gss. Qed.
  Lemma gso: forall (A: Type) (i j: elt) (x: A) (m: t A),
    i <> j -> get i (set j x m) = get i m.
Proof.
intros; apply TTree.gso; auto. Qed.
  (* [X.eq] on keys and [TTree.elt_eq] on translated keys must agree;
     the cases where they disagree are contradictory by injectivity. *)
  Lemma gsspec: forall (A: Type) (i j: elt) (x: A) (m: t A),
    get i (set j x m) = if elt_eq i j then Some x else get i m.
Proof.
    intros. unfold get, set, elt_eq. rewrite TTree.gsspec.
    destruct X.eq; subst; destruct TTree.elt_eq; try congruence.
    exfalso. auto.
  Qed.
  Lemma grs: forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
Proof.
intros; apply TTree.grs. Qed.
  Lemma gro: forall (A: Type) (i j: elt) (m: t A),
    i <> j -> get i (remove j m) = get i m.
Proof.
intros; apply TTree.gro; auto. Qed.
  Lemma grspec: forall (A: Type) (i j: elt) (m: t A),
    get i (remove j m) = if elt_eq i j then None else get i m.
Proof.
    intros. unfold get, remove. rewrite TTree.grspec.
    destruct TTree.elt_eq; destruct elt_eq; auto; exfalso; auto.
    congruence.
  Qed.

  (* Combined lookup/update at the translated key, delegated to [TTree.get_set]. *)
  Definition get_set A i g : option A * (A -> t A) :=
    TTree.get_set (X.t_of_s i) g.
  Lemma get_set_spec:
    forall (A: Type) (i:elt) (m:t A),
      fst (get_set i m) = get i m /\
      forall v, snd (get_set i m) v = set i v m.
Proof.
    intros. apply TTree.get_set_spec.
  Qed.

  (* Extensional equality, delegated to [TTree.beq].  The correctness proof
     moves from quantification over [X.t] keys to quantification over [X.s]
     keys with [X.t_of_s_of_t]. *)
  Definition beq A cmp (t1 t2: t A) : bool :=
    TTree.beq cmp t1 t2.
  Lemma beq_correct:
    forall (A: Type) (eqA: A -> A -> bool) (t1 t2: t A),
      beq eqA t1 t2 = true <->
      (forall (x: elt),
         match get x t1, get x t2 with
           | None, None => True
           | Some y1, Some y2 => eqA y1 y2 = true
           | _, _ => False
         end).
Proof.
    intros. rewrite TTree.beq_correct.
    unfold get. split.
    - intros. apply H.
    - intros. specialize (H (X.s_of_t x)).
      rewrite X.t_of_s_of_t in H. auto.
  Qed.

  (** Applying a function to all data of a tree. *)
  (* [f] receives the [X.s] key, recovered from the underlying [X.t] key
     with [X.s_of_t]. *)
  Definition map A B (f: elt -> A -> B) (g: t A) : t B :=
    TTree.map (fun e => (f (X.s_of_t e))) g.

  (* After [TTree.gmap], the key passed to [f] is [X.s_of_t (X.t_of_s i)],
     which [X.s_of_t_of_s] reduces back to [i]. *)
  Lemma gmap: forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
    get i (map f m) = option_map (f i) (get i m).
Proof.
    intros. unfold get, map. rewrite TTree.gmap.
    now rewrite X.s_of_t_of_s.
  Qed.

  (** Same as [map], but the function does not receive the [elt] argument. *)
  (* Keys are not involved, so this delegates directly to [TTree.map1]. *)
  Definition map1 A B (f: A -> B) g := TTree.map1 f g.
  Lemma gmap1:
    forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
    get i (map1 f m) = option_map f (get i m).
Proof.
    unfold get, map1. intros. apply TTree.gmap1.
  Qed.

  (** Applying a function pairwise to all data of two trees. *)
  (* [f] does not see keys, so this delegates directly to [TTree.combine]. *)
  Definition combine A B C (f:option A -> option B -> option C) g1 g2 : t C :=
    TTree.combine f g1 g2.
  Lemma gcombine:
    forall (A B C: Type) (f: option A -> option B -> option C),
    f None None = None ->
    forall (m1: t A) (m2: t B) (i: elt),
    get i (combine f m1 m2) = f (get i m1) (get i m2).
Proof.
    intros. unfold get, combine. now apply TTree.gcombine.
  Qed.

  (** Enumerating the bindings of a tree. *)
  (* The bindings of [TTree.elements], with every key mapped back to [X.s]
     through [X.s_of_t]; the order is that of [TTree.elements]. *)
  Definition elements A g : list (elt * A) :=
    List.map (fun q => (X.s_of_t (fst q), snd q))
    (TTree.elements g).
  Lemma elements_correct:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    get i m = Some v -> In (i, v) (elements m).
Proof.
    unfold get, elements. intros A m i v H.
    pose proof (List.in_map (fun q => (X.s_of_t (fst q), snd q)) (TTree.elements m) _ (TTree.elements_correct H)).
    simpl in *. rewrite X.s_of_t_of_s in *.
    auto.
  Qed.
  Lemma elements_complete:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    In (i, v) (elements m) -> get i m = Some v.
Proof.
    unfold elements, get. intros A m i v H.
    apply TTree.elements_complete.
    destruct (list_in_map_inv _ _ _ H) as ((x & y) & X & Y).
    simpl in *. injection X. clear X. intros. subst.
    now rewrite X.t_of_s_of_t.
  Qed.
  (* Keys stay distinct because [X.s_of_t] is injective.  The final [auto]
     presumably closes the injectivity side goal of [list_map_norepet] using
     the [P.tinjective] hint. *)
  Lemma elements_keys_norepet:
    forall (A: Type) (m: t A),
    list_norepet (List.map (@fst elt A) (elements m)).
Proof.
    intros A m.
    unfold elements.
    rewrite list_map_compose. simpl.
    rewrite <- list_map_compose.
    apply list_map_norepet.
    apply TTree.elements_keys_norepet.
    auto.
  Qed.

  (** Folding a function over all bindings of a tree. *)
  (* [f] receives each key translated back to [X.s] with [X.s_of_t]. *)
  Definition fold A B f (g: t A) b : B :=
    TTree.fold (fun x k v => f x (X.s_of_t k) v) g b.
  (* Proved by induction on the list of underlying bindings, generalizing
     the accumulator [v]. *)
  Lemma fold_spec:
    forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
    fold f m v =
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
Proof.
    intros.
    unfold fold. rewrite TTree.fold_spec.
    unfold elements.
    generalize dependent v.
    generalize (TTree.elements m). clear m.
    induction l; auto.
    simpl. intuition.
  Qed.

  (** Folding a function over the values (right-hand sides) of all bindings of a tree. *)
  (* Keys are not involved, so this delegates to [TTree.fold1]. *)
  Definition fold1 A B f (g: t A) b : B :=
    TTree.fold1 (fun x v => f x v) g b.
  (* Same proof pattern as [fold_spec]. *)
  Lemma fold1_spec:
    forall (A B: Type) (f: B -> A -> B) (v: B) (m: t A),
    fold1 f m v =
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
Proof.
    intros.
    unfold fold1. rewrite TTree.fold1_spec.
    unfold elements.
    generalize dependent v.
    generalize (TTree.elements m). clear m.
    induction l; auto.
    simpl. intuition.
  Qed.

End BijTree.

(** [BijTree] extended to [SHARETREE]: the sharing-aware combinators of
    [TTree] are re-exposed with keys translated through the bijection [X]. *)
Module BijShareTree (X:TYPE_EQ) (TTree: SHARETREE with Definition elt := X.t) <: SHARETREE with
  Definition elt := X.s.

  Include BijTree(X)(TTree).

  (* The user function receives [X.s] keys ([X.s_of_t e]).  The [_] hole is
     the proof, derived from [Hf], that the re-keyed function also satisfies
     [f x v v = v]; no [Next Obligation] follows, so Program's default
     obligation tactic must solve it. *)
  Program Definition shcombine A (f:elt -> option A -> option A -> option A)
                                 (Hf:∀ x v, f x v v = v) x y : t A :=
    TTree.shcombine (fun e => f (X.s_of_t e)) _ x y.

  Lemma gshcombine:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m1 m2: t A) (i: elt),
    get i (shcombine f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
    intros. unfold shcombine. unfold get. rewrite TTree.gshcombine, X.s_of_t_of_s. auto.
  Qed.

  Lemma shcombine_eq:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m: t A),
      shcombine f Hf m m = m.
Proof.
intros. apply TTree.shcombine_eq. Qed.

  (* Same re-keying as [shcombine]; here the obligation is the re-keyed
     form of [f x v v = None]. *)
  Program Definition shcombine_diff A B (f:elt -> option A -> option A -> option B)
                                        (Hf:∀ x v, f x v v = None) x y : t B :=
    TTree.shcombine_diff (fun e => f (X.s_of_t e)) _ x y.

  Lemma gshcombine_diff:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m1 m2: t A) (i: elt),
    get i (shcombine_diff f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
    intros. unfold shcombine_diff. unfold get. rewrite TTree.gshcombine_diff, X.s_of_t_of_s. auto.
  Qed.

  Lemma shcombine_diff_eq:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m: t A),
      shcombine_diff f Hf m m = empty _.
Proof.
intros. apply TTree.shcombine_diff_eq. Qed.

  (* Re-keyed [TTree.shforall2]; the obligation is the re-keyed form of
     [f x v v = true]. *)
  Program Definition shforall2 A (f:elt -> option A -> option A -> bool)
                                 (Hf:forall x v, f x v v = true) x y : bool :=
    TTree.shforall2 (fun e => f (X.s_of_t e)) _ x y.

  (* Each direction instantiates the universally quantified hypothesis at
     the translated key and cancels the round trip of the bijection. *)
  Lemma shforall2_correct:
    forall (A: Type)
           (f: elt -> option A -> option A -> bool) (Hf:∀ x v, f x v v = true),
    forall (m1: t A) (m2: t A),
    shforall2 f Hf m1 m2 = true <->
    (forall x, f x (get x m1) (get x m2) = true).
Proof.
    unfold shforall2. intros. rewrite TTree.shforall2_correct. split; intros.
    specialize (H (X.t_of_s x)). rewrite X.s_of_t_of_s in H. auto.
    specialize (H (X.s_of_t x)). unfold get in H. rewrite X.t_of_s_of_t in H. auto.
  Qed.

End BijShareTree.

(* [Z]-keyed trees, obtained by re-keying [PTree] / [PShareTree] along the
   [Z_EQ_POS] bijection between [Z] and [positive]. *)
Module ZTree <: TREE with Definition elt := Z := BijTree(Z_EQ_POS)(PTree).
Module ZShareTree <: SHARETREE with Definition elt := Z := BijShareTree(Z_EQ_POS)(PShareTree).

(** Trees keyed by a sum type [L.elt + R.elt], represented as a pair of an
    [L]-tree (holding the [inl] keys) and an [R]-tree (holding the [inr] keys). *)
Module SumTree (L:TREE) (R:TREE) <: TREE
    with Definition elt := (L.elt + R.elt)%type
    with Definition t := fun A => (L.t A * R.t A)%type
    with Definition get := fun A k m =>
      match k return option A with
        | inl x => L.get x (fst m)
        | inr x => R.get x (snd m)
      end.

  Definition elt := (L.elt + R.elt)%type.
  (* Decidable equality on the sum, built from the component decisions. *)
  Definition elt_eq: forall (a b: elt), {a = b} + {a <> b}.
Proof.
    decide equality. apply L.elt_eq. apply R.elt_eq.
  Defined.
  Definition t (A: Type) : Type := (L.t A * R.t A)%type.
  Definition empty A : t A := (L.empty A, R.empty A).
  (* [inl] keys are looked up in the first component, [inr] keys in the second. *)
  Definition get A k m : option A :=
    match k with
      | inl x => L.get x (fst m)
      | inr x => R.get x (snd m)
    end.
  (* Updates touch only the component selected by the key; the other
     component is returned unchanged. *)
  Definition set A k v m : t A :=
    match k with
      | inl x => (L.set x v (fst m), snd m)
      | inr x => (fst m, R.set x v (snd m))
    end.
  Definition remove A k m : t A :=
    match k with
      | inl x => (L.remove x (fst m), snd m)
      | inr x => (fst m, R.remove x (snd m))
    end.

  Lemma gempty:
    forall (A: Type) (i: elt), get i (empty A) = None.
Proof.
    destruct i; simpl. apply L.gempty. apply R.gempty.
  Qed.
  Lemma gss:
    forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
Proof.
    destruct m. destruct i; simpl. apply L.gss. apply R.gss.
  Qed.
  (* Keys from different sides never interfere; same-side keys reduce to
     the component lemma. *)
  Lemma gso:
    forall (A: Type) (i j: elt) (x: A) (m: t A),
    i <> j -> get i (set j x m) = get i m.
Proof.
    destruct m. destruct i; destruct j; simpl; intros; auto. apply L.gso. congruence. apply R.gso. congruence.
  Qed.
  Lemma gsspec:
    forall (A: Type) (i j: elt) (x: A) (m: t A),
    get i (set j x m) = if elt_eq i j then Some x else get i m.
Proof.
    destruct m. destruct i; destruct j; auto.
    destruct elt_eq; simpl; rewrite L.gsspec; destruct L.elt_eq; auto; congruence.
    destruct elt_eq; simpl; rewrite R.gsspec; destruct R.elt_eq; auto; congruence.
  Qed.

  Lemma grs:
    forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
Proof.
    destruct m; destruct i; simpl. apply L.grs. apply R.grs.
  Qed.
  Lemma gro:
    forall (A: Type) (i j: elt) (m: t A),
    i <> j -> get i (remove j m) = get i m.
Proof.
    destruct m; destruct i; destruct j; simpl; intros; auto.
    apply L.gro; congruence.
    apply R.gro; congruence.
  Qed.
  Lemma grspec:
    forall (A: Type) (i j: elt) (m: t A),
    get i (remove j m) = if elt_eq i j then None else get i m.
Proof.
    destruct m; destruct i; destruct j; auto.
    destruct elt_eq; simpl; rewrite L.grspec; destruct L.elt_eq; auto; congruence.
    destruct elt_eq; simpl; rewrite R.grspec; destruct R.elt_eq; auto; congruence.
  Qed.

  (* Delegates to the selected component's [get_set]; the returned setter
     rebuilds the pair around the updated component. *)
  Definition get_set A i m : option A * (A -> t A) :=
    match i with
      | inl x => let '(r, s) := L.get_set x (fst m) in
                 (r, fun v => (s v, snd m))
      | inr x => let '(r, s) := R.get_set x (snd m) in
                 (r, fun v => (fst m, s v))
    end.
  Lemma get_set_spec:
    forall (A: Type) (i:elt) (m:t A),
      fst (get_set i m) = get i m /\
      forall v, snd (get_set i m) v = set i v m.
Proof.
    intros. destruct i; simpl.
    - rewrite (surjective_pairing (L.get_set e (fst m))). simpl.
      split. apply L.get_set_spec. intros. f_equal. apply L.get_set_spec.
    - rewrite (surjective_pairing (R.get_set e (snd m))). simpl.
      split. apply R.get_set_spec. intros. f_equal. apply R.get_set_spec.
  Qed.

  (** Extensional equality between trees. *)
  (* Both components must be extensionally equal under [cmp]. *)
  Definition beq A (cmp: A -> A -> bool) (m1 m2: t A) : bool :=
    let '(l1, r1) := m1 in
    let '(l2, r2) := m2 in
    L.beq cmp l1 l2 && R.beq cmp r1 r2.
  (* Splits the boolean conjunction into the two component statements and
     instantiates the sum-keyed hypothesis at [inl _] / [inr _]. *)
  Lemma beq_correct:
    forall (A: Type) (eqA: A -> A -> bool) (t1 t2: t A),
      beq eqA t1 t2 = true <->
      (forall (x: elt),
         match get x t1, get x t2 with
           | None, None => True
           | Some y1, Some y2 => eqA y1 y2 = true
           | _, _ => False
         end).
Proof.
    destruct t1; destruct t2; simpl.
    rewrite andb_true_iff, L.beq_correct, R.beq_correct.
    intuition.
    - destruct x. apply H0. apply H1.
    - apply (H (inl _)).
    - apply (H (inr _)).
  Qed.

  (** Applying a function to all data of a tree. *)
  (* Maps each component; [f] receives the key re-injected into the sum. *)
  Definition map A B (f: elt -> A -> B) (m: t A) : t B :=
    let '(l, r) := m in
    (L.map (fun k => f (inl _ k)) l,
     R.map (fun k => f (inr _ k)) r).
  Lemma gmap:
    forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
    get i (map f m) = option_map (f i) (get i m).
Proof.
    destruct m. destruct i; simpl. now rewrite L.gmap. now rewrite R.gmap.
  Qed.

  (* Key-independent map, applied to both components. *)
  Definition map1 A B (f: A -> B) (m: t A) : t B :=
    (L.map1 f (fst m), R.map1 f (snd m)).
  Lemma gmap1:
    forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
    get i (map1 f m) = option_map f (get i m).
Proof.
    destruct m; destruct i; simpl. apply L.gmap1. apply R.gmap1.
  Qed.

  (* Pairwise combination, performed component by component. *)
  Definition combine A B C (f: option A -> option B -> option C) (m1: t A) (m2: t B) : t C :=
    let '(l1, r1) := m1 in
    let '(l2, r2) := m2 in
    (L.combine f l1 l2, R.combine f r1 r2).
  Lemma gcombine:
    forall (A B C: Type) (f: option A -> option B -> option C),
    f None None = None ->
    forall (m1: t A) (m2: t B) (i: elt),
    get i (combine f m1 m2) = f (get i m1) (get i m2).
Proof.
    destruct m1; destruct m2; destruct i; simpl. apply L.gcombine; auto. apply R.gcombine; auto.
  Qed.

  (** Enumerating the bindings of a tree. *)
  (* All [L] bindings (keys wrapped in [inl]) followed by all [R] bindings
     (keys wrapped in [inr]). *)
  Definition elements A (m: t A) : list (elt * A) :=
    (List.map (fun k => (inl _ (fst k), snd k)) (L.elements (fst m)))
      ++
    (List.map (fun k => (inr _ (fst k), snd k)) (R.elements (snd m)))
  .
  (* NOTE(review): the [inl] branch poses the general [L.elements_correct]
     while the [inr] branch instantiates [R.elements_correct H]; both rely on
     [auto] to finish. *)
  Lemma elements_correct:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    get i m = Some v -> In (i, v) (elements m).
Proof.
    intros A m i v H.
    apply in_app.
    destruct m as (m & n); destruct i as [i|i];[left|right].
    pose proof (L.elements_correct).
    pose proof (List.in_map (fun k => (inl R.elt (fst k), snd k)) (L.elements m) (i,v)). auto.
    pose proof (R.elements_correct H).
    pose proof (List.in_map (fun k => (inr L.elt (fst k), snd k)) (R.elements n) (i,v)). auto.
  Qed.
  Lemma elements_complete:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    In (i, v) (elements m) -> get i m = Some v.
Proof.
    intros A m i v H. apply in_app in H.
    destruct m as [m n].
    destruct H as [H|H]; destruct (list_in_map_inv _ _ _ H) as ((x & y) & X & Y);
    injection X; intros; subst; clear X; simpl.
    apply L.elements_complete; auto.
    apply R.elements_complete; auto.
  Qed.
  (* Keys are distinct within each half (injectivity of [inl]/[inr], closed
     by [congruence]), and the two halves are disjoint because an [inl] key
     never equals an [inr] key (closed by [discriminate]). *)
  Lemma elements_keys_norepet:
    forall (A: Type) (m: t A),
    list_norepet (List.map (@fst elt A) (elements m)).
Proof.
    intros A (m & n). unfold elements. rewrite List.map_app. apply list_norepet_app. split.
    rewrite list_map_compose. simpl. rewrite <- list_map_compose.
    apply list_map_norepet. apply L.elements_keys_norepet. congruence.
    split.
    rewrite list_map_compose. simpl. rewrite <- list_map_compose.
    apply list_map_norepet. apply R.elements_keys_norepet. congruence.
    repeat rewrite list_map_compose. simpl.
    generalize (L.elements m).
    generalize (R.elements n). clear m n.
    induction l as [|(a & b) l]. intros l x y H Z. elim Z.
    intros m x y X Y.
    destruct (list_in_map_inv _ _ _ X) as ((u & v) & U & V). subst. clear X. simpl.
    destruct (list_in_map_inv _ _ _ Y). intuition subst; discriminate.
  Qed.

  (* Helper: [fold_left] over a mapped list equals [fold_left] with the map
     fused into the step function. *)
  Lemma fold_left_map:
    forall {A B C} (f: A -> B) (g: C -> B -> C) xs a,
      fold_left g (List.map f xs) a = fold_left (fun a x => g a (f x)) xs a.
Proof.
now induction xs; simpl. Qed.

  (* Folds the [L] component first and feeds its result to the fold over the
     [R] component, matching the order of [elements]. *)
  Definition fold A B (f: B -> elt -> A -> B) (m: t A) (v: B) : B :=
    (R.fold (fun b k => f b (inr _ k)) (snd m)
    (L.fold (fun b k => f b (inl _ k)) (fst m) v)).
  Lemma fold_spec:
    forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
    fold f m v =
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
Proof.
    intros. destruct m as (m & n).
    unfold fold, elements.
    rewrite List.fold_left_app, L.fold_spec, R.fold_spec. simpl.
    now repeat rewrite fold_left_map.
  Qed.

  (* Value-only fold, same [L]-then-[R] order as [fold]. *)
  Definition fold1 A B (f: B -> A -> B) (m: t A) (v: B) : B :=
    (R.fold1 (fun b => f b) (snd m)
    (L.fold1 (fun b => f b) (fst m) v)).
  Lemma fold1_spec:
    forall (A B: Type) (f: B -> A -> B) (v: B) (m: t A),
    fold1 f m v =
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
Proof.
    intros. destruct m as (m & n).
    unfold fold1, elements.
    rewrite List.fold_left_app, L.fold1_spec, R.fold1_spec. simpl.
    now repeat rewrite fold_left_map.
  Qed.

End SumTree.

(** [SumTree] extended to [SHARETREE].  The combinators first try physical
    equality shortcuts ([ifeq ... == ...], presumably provided by [ShareTree])
    so that unchanged inputs are returned as-is, preserving sharing. *)
Module SumShareTree (L:SHARETREE) (R:SHARETREE) <: SHARETREE
    with Definition elt := (L.elt + R.elt)%type.

  Include SumTree(L)(R).

  (* If [m1] and [m2] are physically equal, return [m1]; [shcombine_eq] on
     the components justifies this.  Otherwise combine component-wise, then
     return [m2] if both results are physically equal to [m2]'s components,
     else [m1] if they match [m1]'s components, else the fresh pair.  [sh1]
     takes a dummy argument and is applied to [tt] at its two use sites. *)
  Program Definition shcombine A (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v)
             (m1 m2:t A): t A :=
    ifeq m1 == m2 then m1 else
      let '(l1, r1) := m1 in
      let '(l2, r2) := m2 in
      let l := L.shcombine (fun e => f (inl e)) _ l1 l2 in
      let r := R.shcombine (fun e => f (inr e)) _ r1 r2 in
      let sh1 _ :=
        ifeq l == l1 then
          ifeq r == r1 then m1 else (l, r)
        else (l, r)
      in
      ifeq l == l2 then
        ifeq r == r2 then m2 else sh1 tt
      else sh1 tt.
  (* Obligations produced by the shortcuts and holes above; their exact
     statements depend on the [ifeq] notation, which is not visible here. *)
Next Obligation.
congruence. Qed.
Next Obligation.
unfold physEq. congruence. Qed.
Next Obligation.
unfold physEq. destruct m2. rewrite L.shcombine_eq, R.shcombine_eq. auto. Qed.

  Lemma gshcombine:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m1: t A) (m2: t A) (i: elt),
    get i (shcombine f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
    unfold shcombine, physEq. intros. destruct i, m1, m2; simpl.
    apply L.gshcombine. apply R.gshcombine.
  Qed.

  Lemma shcombine_eq:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m: t A),
      shcombine f Hf m m = m.
Proof.
    unfold shcombine, physEq. intros. destruct m.
    rewrite L.shcombine_eq, R.shcombine_eq. auto.
  Qed.

  (* Physically equal inputs give [empty _] directly; otherwise the two
     components are combined independently. *)
  Program Definition shcombine_diff A B (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None)
             (m1 m2:t A): t B :=
    ifeq m1 == m2 then empty _
    else
      let '(l1, r1) := m1 in
      let '(l2, r2) := m2 in
      let l := L.shcombine_diff (fun e => f (inl e)) _ l1 l2 in
      let r := R.shcombine_diff (fun e => f (inr e)) _ r1 r2 in
      (l, r).
Next Obligation.
destruct m2. rewrite L.shcombine_diff_eq, R.shcombine_diff_eq. auto. Qed.

  Lemma gshcombine_diff:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m1: t A) (m2: t A) (i: elt),
    get i (shcombine_diff f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
    unfold shcombine_diff. intros. destruct i, m1, m2; simpl.
    apply L.gshcombine_diff. apply R.gshcombine_diff.
  Qed.

  Lemma shcombine_diff_eq:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m: t A),
      shcombine_diff f Hf m m = empty _.
Proof.
    unfold shcombine_diff, physEq. intros. destruct m.
    rewrite L.shcombine_diff_eq, R.shcombine_diff_eq. auto.
  Qed.

  (* Physically equal inputs give [true] directly; otherwise both
     components must satisfy the component-wise check. *)
  Program Definition shforall2 A (f:elt -> option A -> option A -> bool)
                                 (Hf:forall x v, f x v v = true) m1 m2 : bool :=
    ifeq m1 == m2 then true
    else
      let '(l1, r1) := m1 in
      let '(l2, r2) := m2 in
      L.shforall2 (fun e => f (inl e)) _ l1 l2 &&
      R.shforall2 (fun e => f (inr e)) _ r1 r2.
Next Obligation.
    symmetry. rewrite Bool.andb_true_iff, L.shforall2_correct, R.shforall2_correct. auto.
  Qed.

  Lemma shforall2_correct:
    forall (A: Type)
           (f: elt -> option A -> option A -> bool) (Hf:∀ x v, f x v v = true),
    forall (m1: t A) (m2: t A),
    shforall2 f Hf m1 m2 = true <->
    (forall x, f x (get x m1) (get x m2) = true).
Proof.
    intros. destruct m1, m2. simpl.
    rewrite Bool.andb_true_iff, L.shforall2_correct, R.shforall2_correct.
    split. intros [] []; auto. split; intros; apply H.
  Qed.

End SumShareTree.

(** Trees keyed by pairs [(k1, k2)], represented as an [M1]-tree whose values
    are [M2]-trees: [k1] selects an inner tree, [k2] a value inside it. *)
Module ProdTree (M1:TREE) (M2:TREE) <: TREE
    with Definition elt := (M1.elt * M2.elt)%type
    with Definition t := fun A => M1.t (M2.t A)
    with Definition get := fun A a (m:M1.t (M2.t A)) =>
                             let '(a1,a2) := a in
                             match M1.get a1 m with
                               | None => None
                               | Some m => M2.get a2 m
                             end.

  Definition elt: Type := (M1.elt * M2.elt)%type.
  (* Pairs are equal iff both components are.  The [left] case's equality
     proof is the single goal left by [refine], closed under [abstract].
     [Proof.] now precedes the first tactic instead of following [refine]. *)
  Definition elt_eq: forall (a b: elt), {a = b} + {a <> b}.
  Proof.
    refine (fun a b =>
      match M1.elt_eq (fst a) (fst b) with
      | left E1 =>
        match M2.elt_eq (snd a) (snd b) with
        | left E2 => left _
        | right N2 => right (fun K => N2 (f_equal snd K))
        end
      | right N1 => right (fun K => N1 (f_equal fst K))
      end).
    abstract (destruct a; destruct b; simpl in *; congruence).
  Defined.

  Definition t (A:Type) : Type := M1.t (M2.t A).

  Definition empty (A: Type) : t A := M1.empty _.

  (* Look up the inner tree at [a1], then [a2] inside it. *)
  Definition get (A: Type) (a:elt) (m: t A) : option A :=
    let '(a1,a2) := a in
      match M1.get a1 m with
        | None => None
        | Some m => M2.get a2 m
      end.

  (* Update through [M1.get_set]: if there is no inner tree at [a1], store a
     fresh [M2] tree holding only the new binding. *)
  Definition set (A: Type) (a:elt) (v: A) (m:t A) : t A :=
    let (a1,a2) := a in
    match M1.get_set a1 m with
      | (None, s) => s (M2.set a2 v (M2.empty _))
      | (Some m1, s) => s (M2.set a2 v m1)
    end.

  (* If there is no inner tree at [a1], [m] is returned unchanged; otherwise
     the inner tree with [a2] removed is stored back (an inner tree that
     becomes empty is kept, not pruned). *)
  Definition remove (A: Type) (a:elt) (m:t A) : t A :=
    let (a1,a2) := a in
    match M1.get_set a1 m with
      | (None, _) => m
      | (Some m1, s) => s (M2.remove a2 m1)
    end.

  (** The ``good variables'' properties for trees, expressing commutations between [get], [set] and [remove]. *)
  Lemma gempty:
    forall (A: Type) (i: elt), get i (empty A) = None.
Proof.
    intros A [i1 i2]; unfold get, empty.
    rewrite M1.gempty; auto.
  Qed.

  (* [set] and [remove] go through [M1.get_set], so these proofs first expose
     its two components with [M1.get_set_spec] and [surjective_pairing]. *)
  Lemma gss:
    forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
Proof.
    intros A [i1 i2] x m; unfold get, set.
    destruct (M1.get_set_spec i1 m). rewrite (surjective_pairing (M1.get_set i1 m)), H.
    destruct (M1.get i1 m); rewrite H0, M1.gss, M2.gss; auto.
  Qed.

  (* Case split on whether the outer keys coincide: if they do, the inner
     keys must differ (handled by [M2.gso]); otherwise [M1.gso] applies. *)
  Lemma gso:
    forall (A: Type) (i j: elt) (x: A) (m: t A),
    i <> j -> get i (set j x m) = get i m.
Proof.
    intros A [i1 i2] [j1 j2] x m H; unfold get, set.
    destruct (M1.get_set_spec j1 m). rewrite (surjective_pairing (M1.get_set j1 m)), H0.
    destruct (M1.elt_eq i1 j1).
    - subst. destruct (M1.get j1 m); rewrite H1, M1.gss, M2.gso by congruence.
      auto. apply M2.gempty.
    - destruct (M1.get j1 m); rewrite H1, M1.gso by auto; auto.
  Qed.

  Lemma gsspec:
    forall (A: Type) (i j: elt) (x: A) (m: t A),
    get i (set j x m) = if elt_eq i j then Some x else get i m.
Proof.
    intros A i j x m.
    destruct (elt_eq i j); subst.
    apply gss.
    apply gso; auto.
  Qed.

  Lemma grs:
    forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
Proof.
    intros A [i1 i2] m; unfold get, remove.
    destruct (M1.get_set_spec i1 m). rewrite (surjective_pairing (M1.get_set i1 m)), H.
    destruct (M1.get i1 m) eqn:EQ. rewrite H0, M1.gss. apply M2.grs. rewrite EQ. auto.
  Qed.

  Lemma gro:
    forall (A: Type) (i j: elt) (m: t A),
    i <> j -> get i (remove j m) = get i m.
Proof.
    intros A [i1 i2] [j1 j2] m H; unfold get, remove.
    destruct (M1.get_set_spec j1 m). rewrite (surjective_pairing (M1.get_set j1 m)), H0.
    destruct (M1.elt_eq i1 j1).
    - subst. destruct (M1.get j1 m) eqn:EQ.
      rewrite H1, M1.gss. apply M2.gro. congruence. rewrite EQ. auto.
    - destruct (M1.get j1 m). rewrite H1, M1.gso by auto; auto. auto.
  Qed.

  Lemma grspec:
    forall (A: Type) (i j: elt) (m: t A),
    get i (remove j m) = if elt_eq i j then None else get i m.
Proof.
    intros A i j m.
    destruct (elt_eq i j); subst.
    apply grs.
    apply gro; auto.
  Qed.

  (* Combined lookup/update.  When no inner tree exists at [a1] the lookup is
     [None] and the setter stores a fresh [M2] tree with the new binding. *)
  Definition get_set A a m : option A * (A -> t A) :=
    let '(a1,a2) := a in
    match M1.get_set a1 m with
      | (None, s) => (None, fun v => s (M2.set a2 v (M2.empty _)))
      | (Some m, s) => (M2.get a2 m, fun v => s (M2.set a2 v m))
    end.

  Lemma get_set_spec:
    forall (A: Type) (i:elt) (m:t A),
      fst (get_set i m) = get i m /\
      forall v, snd (get_set i m) v = set i v m.
Proof.
    intros A [i1 i2] m. unfold get_set. simpl.
    destruct (M1.get_set_spec i1 m). rewrite (surjective_pairing (M1.get_set i1 m)), H.
    destruct M1.get; simpl; auto.
  Qed.

  (** Extensional equality between trees. *)
  (* Combine the two outer trees into an [M1]-tree of booleans: at each outer
     key, compare the inner trees with [M2.beq f], using [M2.empty] for a
     missing inner tree; keys absent from both sides stay absent.  The result
     is the conjunction of all those booleans, computed with [M1.fold]
     (accumulator first, then key, then value). *)
  Definition beq (A: Type) (f:A -> A -> bool) (m1 m2: t A) : bool :=
    let mb := M1.combine
                (fun m'1 m'2 =>
                   match m'1, m'2 with
                     | Some m'1, Some m'2 => Some (M2.beq f m'1 m'2)
                     | Some m'1, None => Some (M2.beq f m'1 (M2.empty _))
                     | None, Some m'2 => Some (M2.beq f (M2.empty _) m'2)
                     | None, None => None
                   end)
                m1 m2
    in M1.fold (fun x _ y => y && x) mb true.

  (* The fold is rewritten into [forallb] over [M1.elements mb], turned into
     a statement about [M1.get] through [elements_complete]/[elements_correct]
     and [M1.gcombine], and each outer key is then handled with
     [M2.beq_correct] (using [M2.gempty] for a missing inner tree). *)
  Lemma beq_correct:
    forall (A: Type) (eqA: A -> A -> bool) (t1 t2: t A),
    beq eqA t1 t2 = true <->
    (forall (x: elt),
     match get x t1, get x t2 with
     | None, None => True
     | Some y1, Some y2 => eqA y1 y2 = true
     | _, _ => False
    end).
Proof.
    intros A eqA t1 t2. unfold get, beq.
    rewrite M1.fold_spec.
    rewrite <- fold_left_rev_right.
    unfold fold_right. fold (forallb (snd (A:=M1.elt))).
    rewrite forallb_forall.
    setoid_rewrite <- (in_rev).
    assert (forall (A B:Type) (P:A*B->Prop), (forall p, P p) <-> (forall a b, P (a, b))) by intuition.
    rewrite H. clear H.
    setoid_rewrite (fun A m i v => conj (@M1.elements_complete A m i v) (@M1.elements_correct A m i v) : _ <-> _).
    setoid_rewrite (@M1.gcombine _ _ _ _ eq_refl).
    split.
    - intros H [x1 x2]. specialize (H x1 false). simpl in H.
      destruct (M1.get); destruct (M1.get).
      + apply M2.beq_correct. destruct (M2.beq eqA t0 t3); auto.
      + assert (M2.beq eqA t0 (M2.empty A) = true) by (destruct M2.beq; auto).
        rewrite M2.beq_correct in H0. specialize (H0 x2). rewrite (M2.gempty) in H0. trivial.
      + assert (M2.beq eqA (M2.empty A) t0 = true) by (destruct M2.beq; auto).
        rewrite M2.beq_correct in H0. specialize (H0 x2). rewrite (M2.gempty) in H0. trivial.
      + trivial.
    - intros H x1 []. trivial.
      pose proof (fun x => H (x1, x)). clear H. simpl in H0.
      destruct (M1.get); destruct (M1.get).
      + rewrite <- M2.beq_correct in H0. congruence.
      + assert (M2.beq eqA t0 (M2.empty A) = true).
        rewrite M2.beq_correct. intro. rewrite M2.gempty. apply H0.
        congruence.
      + assert (M2.beq eqA (M2.empty A) t0 = true).
        rewrite M2.beq_correct. intro. rewrite M2.gempty. apply H0.
        congruence.
      + congruence.
  Qed.

  (** Applying a function to all data of a tree. *)
  (* Map every inner tree; [f] receives the full pair key [(a1, a2)]. *)
  Definition map (A B: Type) (f: elt -> A -> B) (m:t A) : t B :=
    M1.map (fun a1 => M2.map (fun a2 => f (a1,a2))) m.

  Lemma gmap:
    forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
    get i (map f m) = option_map (f i) (get i m).
Proof.
    intros A B f [i1 i2] m; unfold get, map.
    rewrite M1.gmap.
    case_eq (M1.get i1 m); intros; simpl; auto.
    rewrite M2.gmap; auto.
  Qed.

  (** Same as [map], but the function does not receive the [elt] argument. *)
  (* Key-independent map: apply [M2.map1 f] to every inner tree. *)
  Definition map1 (A B: Type) (f:A -> B) (m:t A) : t B :=
    M1.map1 (M2.map1 f) m.

  Lemma gmap1:
    forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
    get i (map1 f m) = option_map f (get i m).
Proof.
    intros A B f [i1 i2] m; unfold get, map1.
    rewrite M1.gmap1.
    case_eq (M1.get i1 m); intros; simpl; auto.
    rewrite M2.gmap1; auto.
  Qed.

  (** Applying a function pairwise to all data of two trees. *)
  (* Combine outer trees with [M1.combine]; at each outer key present on at
     least one side, combine the inner trees with [M2.combine f], using
     [M2.empty] for a missing side.  Keys absent from both stay absent. *)
  Definition combine (A B C: Type) (f:option A -> option B -> option C) (m1:t A) (m2:t B) : t C :=
    M1.combine (fun om1 om2 =>
      match om1, om2 with
        | None, None => None
        | None, Some m2 => Some (M2.combine f (@M2.empty _) m2)
        | Some m2, None => Some (M2.combine f m2 (@M2.empty _))
        | Some m1, Some m2 => Some (M2.combine f m1 m2)
      end) m1 m2.

  (* Case analysis on the presence of the inner tree on each side; the
     [f None None = None] hypothesis is needed both for [M1.gcombine] and
     for [M2.gcombine]. *)
  Lemma gcombine:
    forall (A B C: Type) (f: option A -> option B -> option C),
    f None None = None ->
    forall (m1: t A) (m2: t B) (i: elt),
    get i (combine f m1 m2) = f (get i m1) (get i m2).
Proof.
    intros A B C f Hf m1 m2 [i1 i2]; unfold get, combine.
    rewrite M1.gcombine; auto.
    case_eq (M1.get i1 m1); intros; simpl; auto.
    case_eq (M1.get i1 m2); intros; simpl; auto.
    rewrite M2.gcombine; auto.
    rewrite M2.gcombine; auto.
    rewrite M2.gempty; auto.
    case_eq (M1.get i1 m2); intros; simpl; auto.
    rewrite M2.gcombine; auto.
    rewrite M2.gempty; auto.
  Qed.

  (** Enumerating the bindings of a tree. *)
  Definition elements (A: Type) (m:t A) : list (elt * A) :=
    rev (M1.fold (fun l x1 m1 => M2.fold (fun l x2 a => ((x1,x2),a)::l) m1 l) m (@nil _)).

  Module P1 := Tree_Properties(M1).
  Module P2 := Tree_Properties(M2).

  Definition xelements (A: Type) (m:t A) : list (elt * A) :=
    let ll := List.map (fun p =>
                          let '(x1, m1) := p in
                          List.map (fun p => let '(x2, a) := p in ((x1,x2),a))
                                   (M2.elements m1))
                       (M1.elements m)
    in
    List.fold_left (@app _) ll (@nil _).

  (** The list-based enumeration [xelements] coincides with [elements]. *)
  Lemma xelements_eq : forall A (m:t A), xelements m = elements m.
Proof.
    (* Express the outer fold as a [fold_left] over [M1.elements m]. Write
       the initial [nil] as [rev nil] and generalize it, so that induction on
       the outer binding list can carry an arbitrary accumulator. *)
    intros; unfold xelements, elements.
    rewrite M1.fold_spec.
    change nil with (rev (nil (A:=M1.elt * M2.elt * A))) at 1.
    generalize (nil (A:=M1.elt * M2.elt * A)).
    induction (M1.elements m); intros; simpl. auto.
    (* For one outer binding, the inner [M2.fold] is likewise turned into a
       [fold_left] and handled by induction on [M2.elements]. [app_nil_r]
       and [app_assoc] reassociate the appended lists. *)
    rewrite <- IHl. rewrite M2.fold_spec. f_equal. clear.
    destruct a. simpl.
    revert l0; induction (M2.elements t0); simpl; intro.
    apply app_nil_r.
    destruct a; simpl. rewrite <- IHl. simpl.
    rewrite <- app_assoc. auto.
  Qed.

  (** Every binding readable with [get] appears in [elements]. *)
  Lemma elements_correct:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    get i m = Some v -> In (i, v) (elements m).
Proof.
    (* Split the key. If the outer key is unbound, the hypothesis
       [get i m = Some v] is contradictory. *)
    unfold get, elements; intros A m [i1 i2] v.
    case_eq (M1.get i1 m); [idtac| intros H1 H; discriminate].
    (* Move past the [rev] and induct on the outer fold. The invariant:
       any value reachable through an outer lookup followed by an inner
       lookup is in the accumulated list. *)
    intros. apply -> in_rev.
    revert t0 H H0.
    apply (P1.fold_rec _
      (fun m l => forall i1 i2 v m2,
        M1.get i1 m = Some m2 ->
        M2.get i2 m2 = Some v ->
        In (i1, i2, v) l)); clear i1 i2 v.

    (* Transfer of the invariant between outer maps whose lookups agree
       ([H1]). *)
    intros m' m'' l H1 H2 i1 i2 v m2 Heq1 Heq2.
    eapply H2; eauto.
    rewrite H1; auto.

    (* Empty outer map: no outer key is bound. *)
    intros i1 i2 v m2 Heq1 Heq2.
    rewrite M1.gempty in *; congruence.

    (* Auxiliary fact: an inner [M2.fold] that prepends tagged bindings
       keeps every element already present in its accumulator. *)
    assert (forall i x m2 l,
      In x l ->
      In x (M2.fold
        (fun (l0 : list (M1.elt * M2.elt * A)) (x2 : M2.elt) (a : A) =>
         (i, x2, a) :: l0) m2 l)).
      intros i x m2 l Hl.
      apply (P2.fold_rec
        (fun l0 x2 a => (i, x2, a) :: l0)
        (fun m2 l' => In x l')); auto with datatypes.

    (* Outer step: the outer map gains a binding for key [i]. The case of a
       different outer key is closed by [eauto]; for key [i] itself the
       inner map is identified and the inner fold is analysed. *)
    intros m' l i m2 H1 H2 H3 i1 i2 v m0 Heq1 Heq2.
    rewrite M1.gsspec in Heq1; destruct M1.elt_eq; subst; eauto.
    inv Heq1.
    clear dependent m'.

    (* Inner fold invariant: every binding of the inner map is present in
       the list, tagged with outer key [i]. *)
    generalize i2 v Heq2; clear i2 v Heq2.
    apply (P2.fold_rec (fun l0 x2 a => (i, x2, a) :: l0)
      (fun m2 l => forall i2 v,
           M2.get i2 m2 = Some v ->
           In (i, i2, v) l) l); eauto.
    (* Transfer between inner maps whose lookups agree. *)
    intros m2 m2' l1 H1 H3 i2 v.
    rewrite <- H1; auto.
    (* Empty inner map: nothing to find. *)
    intros i2 v; rewrite M2.gempty; congruence.
    (* Inner step: the new binding is the head of the list; any other
       binding is in the tail by the inner induction hypothesis. *)
    intros m1 a i2 v Heq1 Heq2 Heq3 i2' v'.
    rewrite M2.gsspec; destruct M2.elt_eq; intros T.
    inv T; left; auto.
    right; eauto.
  Qed.

  (** Conversely, every pair listed by [elements] is a binding of the tree. *)
  Lemma elements_complete:
    forall (A: Type) (m: t A) (i: elt) (v: A),
        In (i, v) (elements m) -> get i m = Some v.
Proof.
    (* Move past the [rev] and induct on the outer fold. The invariant:
       every pair in the accumulated list is a binding of the map folded so
       far. *)
    unfold elements; intros A m.
    intros. apply in_rev in H. revert i v H.
    apply (P1.fold_rec _
      (fun m l =>
        (forall i v, In (i,v) l -> get i m = Some v))).
    (* Transfer of the invariant between outer maps whose lookups agree
       ([Hyp1]). *)
    intros m'' m' a Hyp1 Hyp0; auto.
    intros [i1 i2] Hv; unfold get.
    rewrite <- Hyp1 in *; apply (Hyp0 (i1,i2)).
    (* Empty outer map: the accumulated list is empty. *)
    simpl; intuition.
    (* Outer step: outer key [i] is bound to the inner map [v1], and the
       inner fold over [v1] prepends tagged pairs to [l]. Induct on that
       inner fold, stating the invariant relative to the updated outer map
       [M1.set i v1 m1]. *)
    intros m1 l i v1 H1 H2 H3.
    apply (P2.fold_rec _
      (fun m l =>
        (forall i0 v, In (i0,v) l -> get i0 (M1.set i v1 m1) = Some v))); auto.
    (* Pairs already in [l] are bindings of [m1] by [H3]. One whose outer
       key is [i] would contradict [H1]; the others are unaffected by the
       update. *)
    intros [i1 i2] v Hv.
    unfold get. rewrite M1.gsspec; destruct M1.elt_eq; subst; auto.
    generalize (H3 _ _ Hv); unfold get.
    rewrite H1; congruence.
    generalize (H3 _ _ Hv); unfold get; auto.
    (* Inner step: the new head pair is an inner binding of [v1]; the
       remaining pairs follow from the inner induction hypothesis [H4]. *)
    intros m0 a k v0 H H0 H4 [i1 i2] v; unfold get in *.
    simpl; destruct 1.
    inv H5.
    rewrite M1.gss; auto.
    rewrite M1.gsspec; destruct M1.elt_eq; subst.
    generalize (H4 _ _ H5); rewrite M1.gss; auto.
    generalize (H4 _ _ H5); rewrite M1.gso; auto.
  Qed.

  (** The product keys listed by [elements] are pairwise distinct. *)
  Lemma elements_keys_norepet:
    forall (A: Type) (m: t A),
      list_norepet (List.map (@fst elt A) (elements m)).
Proof.
    (* Work on the list-based [xelements] form instead of the folds. *)
    intros A m; rewrite <- xelements_eq.
    unfold xelements.
    (* Invariants for the [fold_left app] accumulator, which is initially
       [nil]: its keys have no repetition, and its outer keys are disjoint
       from the outer keys still to be processed. *)
    assert (list_norepet (List.map fst (@nil ((M1.elt * M2.elt) * A)))) by constructor.
    assert (list_disjoint (List.map (fun (p:(M1.elt * M2.elt) * A) => fst (fst p)) nil)
                          (List.map fst (M1.elements m)))
      by destruct 1.
    revert H H0. generalize (@nil ((M1.elt * M2.elt) * A)).
    (* Induction over the outer binding list, whose outer keys are
       themselves distinct. *)
    generalize (M1.elements_keys_norepet m).
    induction (M1.elements); simpl.
    auto.
    destruct a.
    intros. inv H.
    apply IHl; rewrite ?rev_append_rev, ?map_app, ?map_rev.
    (* The remaining outer keys are still distinct. *)
    - auto.
    (* The extended accumulator has no repeated keys. The new inner keys
       are distinct among themselves, and their outer key differs from every
       outer key already in the accumulator. *)
    - apply list_norepet_append; auto.
      + rewrite !map_map. pose proof M2.elements_keys_norepet t0. clear -H.
        induction (M2.elements t0) as [|[]]; simpl. constructor.
        inv H. specialize (IHl H3). constructor. 2:auto. contradict H2. clear -H2.
        induction l as [|[]]. auto. destruct H2. inv H. left. auto. right. auto.
      + rewrite map_map. repeat intro.
        apply list_in_map_inv in H2. destruct H2 as [[] []].
        apply in_map with (f:=fst) in H. rewrite map_map in H.
        apply (f_equal fst) in H3. apply H1 in H3. auto. auto.
        left. subst. auto.
    (* Disjointness from the remaining outer keys is preserved. *)
    - rewrite map_map. repeat intro.
      apply in_app in H. destruct H.
      + eapply H1; eauto. right; auto.
      + apply list_in_map_inv in H. destruct H as [[] []].
        subst. subst. auto.
  Qed.

  (** [fold f m b] folds [f] over all bindings of the product tree. An
      outer [M1.fold] visits each inner map, and an inner [M2.fold] threads
      the same accumulator through that map, passing [f] the product key
      [(x1, x2)]. *)
  (* [fold_spec] unfolds this exact term. *)
  Definition fold (A B: Type) (f: B -> elt -> A -> B) (m: t A) (b:B) : B :=
    M1.fold
      (fun acc x1 m1 => M2.fold (fun acc x2 a => f acc (x1, x2) a) m1 acc) m b.

  (** [fold] agrees with a left fold over [elements]. *)
  Lemma fold_spec:
    forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
    fold f m v =
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
Proof.
    (* Turn the outer folds into [fold_left]s. Write the initial value [v]
       as a [fold_right] over the empty list, then generalize that list so
       the induction can track the accumulated bindings. *)
    intros. unfold fold, elements.
    rewrite !M1.fold_spec.
    change v with (fold_right (fun p a => f a (fst p) (snd p)) v nil) at 1.
    fold elt. generalize (nil (A:=elt * A)).
    (* Induct on the outer binding list, then on each inner binding list. *)
    induction (M1.elements m).
    simpl. intros. rewrite <- fold_left_rev_right. rewrite rev_involutive. auto.
    intros. destruct a. simpl. rewrite !M2.fold_spec.
    revert l0. induction (M2.elements t0); simpl. auto.
    intros. rewrite <- IHl0. simpl. auto.
  Qed.

  (** [fold1 f m b] folds [f] over the values of the product tree, ignoring
      keys. An outer [M1.fold1] visits each inner map, and an inner
      [M2.fold1] threads the same accumulator through it. *)
  (* [fold1_spec] unfolds this exact term. *)
  Definition fold1 (A B: Type) (f: B -> A -> B) (m: t A) (b:B) : B :=
    M1.fold1
      (fun acc m1 => M2.fold1 (fun acc a => f acc a) m1 acc) m b.

  (** [fold1] agrees with a left fold over the values listed by
      [elements]. *)
  Lemma fold1_spec:
    forall (A B: Type) (f: B -> A -> B) (v: B) (m: t A),
    fold1 f m v =
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
Proof.
    (* Same strategy as [fold_spec]: replace the outer folds by their
       list specifications, write [v] as a [fold_right] over the empty list
       and generalize that list. *)
    intros. unfold fold1, elements.
    rewrite M1.fold_spec, M1.fold1_spec.
    change v with (fold_right (fun (p:elt*A) a => f a (snd p)) v nil) at 1.
    fold elt. generalize (nil (A:=elt * A)).
    (* Induct on the outer binding list, then on each inner binding list. *)
    induction (M1.elements m).
    simpl. intros. rewrite <- fold_left_rev_right. rewrite rev_involutive. auto.
    intros. destruct a. simpl. rewrite M2.fold_spec, M2.fold1_spec.
    revert l0. induction (M2.elements t0); simpl. auto.
    intros. rewrite <- IHl0. simpl. auto.
  Qed.

End ProdTree.

(** Product of two share-trees. A key [(x1, x2)] is looked up in the inner
    [M2] tree found under [x1] in the outer [M1] tree. The plain tree
    operations come from [ProdTree]; this module adds the [SHARETREE]
    combinators on top of the component combinators. *)
Module ProdShareTree (M1:SHARETREE) (M2:SHARETREE) <: SHARETREE with
  Definition elt := (M1.elt * M2.elt)%type.

  Include ProdTree(M1)(M2).

  (* Properties of [M2]; used here for its [is_empty] test and the
     [is_empty_correct] lemma. *)
  Module M2Prop := ShareTree_Properties(M2).

  (** [shcombine f Hf x y] merges the two trees outer key by outer key with
      [M1.shcombine]. For each outer key, a missing inner tree is treated as
      [M2.empty A], and the two inner trees are merged with [M2.shcombine]
      using [f] specialised to the current outer key.
      - A non-empty inner result is always kept.
      - An empty inner result is kept as [Some res] only when both inner
        trees are present and both are empty; otherwise it is dropped
        ([None]). *)
  (* Keeping [Some res] in the present-and-empty case makes the per-key
     function return its argument unchanged on equal inputs, which the
     obligation below requires. *)
  Program Definition shcombine A (f:elt -> option A -> option A -> option A)
                                 (Hf:∀ x v, f x v v = v) x y : t A :=
    M1.shcombine (fun e1 a1 b1 =>
      let a1' := match a1 return _ with Some a1' => a1' | None => M2.empty A end in
      let b1' := match b1 return _ with Some b1' => b1' | None => M2.empty A end in
      let res := M2.shcombine (fun e2 => f (e1, e2)) _ a1' b1' in
      if M2Prop.is_empty res return _ then
        match a1, b1 return _ with
          | None, _ | _, None => None
          | Some _, Some _ =>
            if M2Prop.is_empty a1' && M2Prop.is_empty b1' then Some res
            else None
        end
      else Some res)
      _ x y.
Next Obligation.
    (* Equal arguments: the inner merge returns its input
       ([M2.shcombine_eq]), so every branch yields the original argument.
       For [None], [M2.empty A] is shown empty with [is_empty_correct] and
       [M2.gempty]. *)
    rewrite M2.shcombine_eq. destruct v.
    - destruct M2Prop.is_empty; auto.
    - replace (M2Prop.is_empty (M2.empty A)) with true. auto.
      symmetry. rewrite M2Prop.is_empty_correct. intro. rewrite M2.gempty. auto.
  Qed.

  (** Lookup in [shcombine f Hf m1 m2] is [f] applied pointwise. *)
  Lemma gshcombine:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m1 m2: t A) (i: elt),
    get i (shcombine f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
    (* Split the key and push the outer lookup through [M1.gshcombine].
       Then case on the emptiness test of the inner result, keeping its
       meaning from [is_empty_correct] in [H]. *)
    intros. unfold shcombine, get. destruct i.
    rewrite M1.gshcombine.
    match goal with
      | |- context [M2Prop.is_empty ?m] =>
        pose proof M2Prop.is_empty_correct _ m;
        destruct (M2Prop.is_empty m)
    end.
    (* Empty inner result: its read at the inner key is [None], which is
       compared with [f] in each presence case of the two outer lookups. *)
    - destruct H as (H & _). specialize (H eq_refl e0).
      rewrite M2.gshcombine in H. rewrite <- H.
      destruct (M1.get e m1), (M1.get e m2).
      + destruct andb. rewrite M2.gshcombine. auto. auto.
      + rewrite M2.gempty in *. rewrite H. auto.
      + rewrite M2.gempty in *. rewrite H. auto.
      + rewrite M2.gempty, Hf, Hf. auto.
    (* Non-empty inner result: read it with [M2.gshcombine]; a missing
       inner tree reads as [None] via [M2.gempty]. *)
    - rewrite M2.gshcombine. destruct M1.get, M1.get; rewrite ?M2.gempty; auto.
  Qed.

  (* Follows directly from [M1.shcombine_eq]. *)
  Lemma shcombine_eq:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m: t A),
      shcombine f Hf m m = m.
Proof.
intros. apply M1.shcombine_eq. Qed.

  (** [shcombine_diff f Hf x y] merges outer key by outer key with
      [M1.shcombine_diff]. For each outer key, the inner trees (missing
      ones read as [M2.empty A]) are merged with [M2.shcombine_diff] using
      [f] specialised to the outer key. An empty inner result is dropped
      ([None]); a non-empty one is kept. *)
  Program Definition shcombine_diff A B (f:elt -> option A -> option A -> option B)
                                        (Hf:∀ x v, f x v v = None) x y : t B :=
    M1.shcombine_diff (fun e1 a1 b1 =>
      let a1' := match a1 return _ with Some a1' => a1' | None => M2.empty A end in
      let b1' := match b1 return _ with Some b1' => b1' | None => M2.empty A end in
      let res := M2.shcombine_diff (fun e2 => f (e1, e2)) _ a1' b1' in
      if M2Prop.is_empty res return _ then None
      else Some res)
      _ x y.
Next Obligation.
    (* Equal arguments: the inner result is [M2.empty B]
       ([M2.shcombine_diff_eq]), which tests empty, so the result is
       [None]. *)
    rewrite M2.shcombine_diff_eq.
    replace (M2Prop.is_empty (M2.empty B)) with true. auto.
    symmetry. rewrite M2Prop.is_empty_correct. intro. rewrite M2.gempty. auto.
  Qed.

  (** Lookup in [shcombine_diff f Hf m1 m2] is [f] applied pointwise. *)
  Lemma gshcombine_diff:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m1 m2: t A) (i: elt),
    get i (shcombine_diff f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
    (* Same case analysis as [gshcombine], on the emptiness of the inner
       result. *)
    intros. unfold shcombine_diff, get. destruct i.
    rewrite M1.gshcombine_diff.
    match goal with
      | |- context [M2Prop.is_empty ?m] =>
        pose proof M2Prop.is_empty_correct _ m;
        destruct (M2Prop.is_empty m)
    end.
    - destruct H as (H & _). specialize (H eq_refl e0).
      rewrite M2.gshcombine_diff in H. rewrite <- H.
      destruct (M1.get e m1), (M1.get e m2).
      + auto.
      + rewrite M2.gempty. auto.
      + rewrite M2.gempty. auto.
      + rewrite M2.gempty. auto.
    - rewrite M2.gshcombine_diff. destruct M1.get, M1.get; rewrite ?M2.gempty; auto.
  Qed.

  (* Follows directly from [M1.shcombine_diff_eq]. *)
  Lemma shcombine_diff_eq:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m: t A),
      shcombine_diff f Hf m m = empty _.
Proof.
intros. apply M1.shcombine_diff_eq. Qed.

  (** [shforall2 f Hf x y] checks [f] at every key. It runs
      [M1.shforall2] over the outer trees and, for each outer key,
      [M2.shforall2] over the inner trees, with a missing inner tree treated
      as [M2.empty A]. *)
  Program Definition shforall2 A (f:elt -> option A -> option A -> bool)
                                 (Hf:forall x v, f x v v = true) x y : bool :=
    M1.shforall2 (fun e1 a1 b1 =>
      let a1' := match a1 return _ with Some a1' => a1' | None => M2.empty A end in
      let b1' := match b1 return _ with Some b1' => b1' | None => M2.empty A end in
      M2.shforall2 (fun e2 => f (e1, e2)) _ a1' b1')
      _ x y.
Next Obligation.
(* Reduce with [M2.shforall2_correct]; [auto] closes the pointwise goal. *)
rewrite M2.shforall2_correct. auto. Qed.

  (** [shforall2] holds exactly when [f] holds at every product key. *)
  Lemma shforall2_correct:
    forall (A: Type)
           (f: elt -> option A -> option A -> bool) (Hf:∀ x v, f x v v = true),
    forall (m1: t A) (m2: t A),
    shforall2 f Hf m1 m2 = true <->
    (forall x, f x (get x m1) (get x m2) = true).
Proof.
    (* Rewrite both levels with their correctness lemmas, then relate the
       two-level quantification to quantification over product keys. A
       missing inner tree reads as [None] via [M2.gempty]. *)
    intros. unfold shforall2.
    rewrite M1.shforall2_correct.
    setoid_rewrite M2.shforall2_correct.
    split.
    - destruct x. specialize (H e e0). unfold get.
      destruct M1.get, M1.get;
      rewrite ?M2.gempty in H; auto.
    - intros. specialize (H (x, x0)).
      unfold get in H.
      destruct M1.get, M1.get;
      rewrite ?M2.gempty; auto.
  Qed.

End ProdShareTree.

(** Trees indexed by [unit]. The only key is [tt], so a tree is an optional
    value and every operation ignores its key argument. *)
Module UnitTree <: TREE with Definition elt := unit.

  Definition elt := unit.
  (* Any two keys are both [tt], so the decision is always [left]. It is
     closed with [Defined] rather than [Qed] so that
     [if elt_eq i j then .. else ..] computes; the [reflexivity] proofs of
     [gsspec] and [grspec] depend on this. *)
  Definition elt_eq : forall (x y:elt), {x = y}+{x <> y}.
Proof.
left. destruct x, y. auto. Defined.
  (* A tree stores at most one value: the one bound to [tt]. *)
  Definition t (A:Type) := option A.
  Definition empty := @None.
  (* Lookup returns the whole tree, since the key can only be [tt]. *)
  Definition get {A:Type} (x:elt) (m:t A) := m.
  (* Setting or removing the only key replaces the whole tree. *)
  Definition set {A:Type} (x:elt) (y:A) (m:t A) := Some y.
  Definition remove {A:Type} (x:elt) (m:t A) := @None A.

  Lemma gempty: forall (A: Type) (i: elt), get i (empty A) = None.
Proof.
reflexivity. Qed.
  Lemma gss: forall (A: Type) (i: elt) (x: A) (m: t A), get i (set i x m) = Some x.
Proof.
reflexivity. Qed.
  (* The hypothesis [i <> j] is absurd once both keys are [tt]. *)
  Lemma gso: forall (A: Type) (i j: elt) (x: A) (m: t A),
    i <> j -> get i (set j x m) = get i m.
Proof.
destruct i, j; congruence. Qed.
  Lemma gsspec: forall (A: Type) (i j: elt) (x: A) (m: t A),
    get i (set j x m) = if elt_eq i j then Some x else get i m.
Proof.
reflexivity. Qed.
  Lemma grs: forall (A: Type) (i: elt) (m: t A), get i (remove i m) = None.
Proof.
reflexivity. Qed.
  Lemma gro: forall (A: Type) (i j: elt) (m: t A),
    i <> j -> get i (remove j m) = get i m.
Proof.
destruct i, j; congruence. Qed.
  Lemma grspec: forall (A: Type) (i j: elt) (m: t A),
    get i (remove j m) = if elt_eq i j then None else get i m.
Proof.
reflexivity. Qed.
  (* Returns the current contents together with the update function, which
     is just [Some]. *)
  Definition get_set A (i:elt) m : option A * (A -> t A) :=
    (m, Some).
  Lemma get_set_spec:
    forall (A: Type) (i:elt) (m:t A),
      fst (get_set i m) = get i m /\
      forall v, snd (get_set i m) v = set i v m.
Proof.
split; reflexivity. Qed.
  (* Two trees are equal when both are empty, or when both hold values
     related by [cmp]. *)
  Definition beq A cmp (t1 t2: t A) : bool :=
    match t1, t2 with
    | None, None => true
    | Some y1, Some y2 => cmp y1 y2
    | _, _ => false
    end.
  Lemma beq_correct:
    forall (A: Type) (eqA: A -> A -> bool) (t1 t2: t A),
      beq eqA t1 t2 = true <->
      (forall (x: elt),
         match get x t1, get x t2 with
           | None, None => True
           | Some y1, Some y2 => eqA y1 y2 = true
           | _, _ => False
         end).
Proof.
    unfold get, beq. split; intro; [|specialize (H tt)]; destruct t1, t2; now intuition.
  Qed.
  (* [map] applies [f] at the only key, [tt]. *)
  Definition map {A B} f := @option_map A B (f tt).
  Lemma gmap:
    forall (A B: Type) (f: elt -> A -> B) (i: elt) (m: t A),
    get i (map f m) = option_map (f i) (get i m).
Proof.
destruct i; reflexivity. Qed.
  Definition map1 := option_map.
  Lemma gmap1:
    forall (A B: Type) (f: A -> B) (i: elt) (m: t A),
    get i (map1 f m) = option_map f (get i m).
Proof.
reflexivity. Qed.
  (* Combining is [f] itself: the trees are exactly the values read at
     [tt]. *)
  Definition combine A B C (f:option A -> option B -> option C) := f.
  Lemma gcombine:
    forall (A B C: Type) (f: option A -> option B -> option C),
    f None None = None ->
    forall (m1: t A) (m2: t B) (i: elt),
      get i (combine f m1 m2) = f (get i m1) (get i m2).
Proof.
reflexivity. Qed.
  (* Lists the single binding, if any. *)
  Definition elements A m : list (elt * A) :=
    match m with
    | None => nil
    | Some x => (tt, x)::nil
    end.
  Lemma elements_correct:
    forall (A: Type) (m: t A) (i: elt) (v: A),
    get i m = Some v -> In (i, v) (elements m).
Proof.
unfold get. intros. subst. left. destruct i; reflexivity. Qed.
  Lemma elements_complete:
    forall (A: Type) (m: t A) (i: elt) (v: A),
      In (i, v) (elements m) -> get i m = Some v.
Proof.
destruct m; destruct 1. inv H; reflexivity. destruct H. Qed.
  Lemma elements_keys_norepet:
    forall (A: Type) (m: t A),
    list_norepet (List.map (@fst elt A) (elements m)).
Proof.
destruct m; constructor. auto. constructor. Qed.

  (* Applies [f] once to the stored binding, if any. *)
  Definition fold A B f (m: t A) b : B :=
    match m with
    | None => b
    | Some x => f b tt x
    end.
  Lemma fold_spec:
    forall (A B: Type) (f: B -> elt -> A -> B) (v: B) (m: t A),
    fold f m v =
    List.fold_left (fun a p => f a (fst p) (snd p)) (elements m) v.
Proof.
destruct m; reflexivity. Qed.

  (* Applies [f] once to the stored value, if any. *)
  Definition fold1 A B f (m: t A) b : B :=
    match m with
    | None => b
    | Some x => f b x
    end.
  Lemma fold1_spec:
    forall (A B: Type) (f: B -> A -> B) (v: B) (m: t A),
    fold1 f m v =
    List.fold_left (fun a p => f a (snd p)) (elements m) v.
Proof.
destruct m; reflexivity. Qed.

End UnitTree.

(** Share-tree with the single key [tt], built on [UnitTree]. A tree is
    an [option A], and each combinator short-circuits on equal arguments
    with the answer required by its [Hf] hypothesis. *)
(* NOTE(review): the [ifeq x == y then .. else ..] notation is defined
   outside this chunk. The proofs below close by [reflexivity] and by
   [apply Hf] against [f tt m1 m2], which suggests it reduces to its else
   branch inside Coq. Confirm against its definition. *)
Module UnitShareTree <: SHARETREE with Definition elt := unit.

  Include UnitTree.

  (* On equal arguments returns [x] (consistent with [Hf]); otherwise
     [f tt x y]. *)
  Program Definition shcombine A (f:elt -> option A -> option A -> option A)
          (Hf:∀ x v, f x v v = v) x y : t A :=
    ifeq x == y then x else f tt x y.
  Lemma gshcombine:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m1 m2: t A) (i: elt),
    get i (shcombine f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
destruct i; reflexivity. Qed.
  Lemma shcombine_eq:
    forall (A: Type)
           (f: elt -> option A -> option A -> option A) (Hf:∀ x v, f x v v = v),
    forall (m: t A),
      shcombine f Hf m m = m.
Proof.
intros. apply Hf. Qed.

  (* On equal arguments returns [None] (consistent with [Hf]); otherwise
     [f tt x y]. *)
  Program Definition shcombine_diff A B (f:elt -> option A -> option A -> option B)
                                        (Hf:∀ x v, f x v v = None) x y : t B :=
    ifeq x == y then None else f tt x y.
  Lemma gshcombine_diff:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m1 m2: t A) (i: elt),
    get i (shcombine_diff f Hf m1 m2) = f i (get i m1) (get i m2).
Proof.
destruct i; reflexivity. Qed.
  Lemma shcombine_diff_eq:
    forall (A B: Type)
           (f: elt -> option A -> option A -> option B) (Hf:∀ x v, f x v v = None),
    forall (m: t A),
      shcombine_diff f Hf m m = empty _.
Proof.
intros. apply Hf. Qed.

  (* On equal arguments returns [true] (consistent with [Hf]); otherwise
     [f tt x y]. *)
  Program Definition shforall2 A (f:elt -> option A -> option A -> bool)
                                 (Hf:forall x v, f x v v = true) x y : bool :=
    ifeq x == y then true else f tt x y.
  Lemma shforall2_correct:
    forall (A: Type)
           (f: elt -> option A -> option A -> bool) (Hf:∀ x v, f x v v = true),
    forall (m1: t A) (m2: t A),
    shforall2 f Hf m1 m2 = true <->
    (forall x, f x (get x m1) (get x m2) = true).
Proof.
split; intros. destruct x; apply H. apply H. Qed.

End UnitShareTree.