(* Module CellEncoding *)


Require Import
  Utf8
  Coqlib Axioms
  AdomLib Util.

Require Import ArithLib Cells.

Set Implicit Arguments.

Local Open Scope N_scope.

(* Termination measure used by [interleave]: the sum of both components of
   the pair, converted to [nat]. *)
Definition size (mn: N * N) : nat :=
  N.to_nat (fst mn + snd mn).

(* [interleave (m, n)] merges the binary digits of [m] and [n] into a single
   natural number.  At each step the lowest bit of [m] contributes weight 1
   and the lowest bit of [n] contributes weight 2; the remaining bits are
   interleaved recursively and multiplied by 4.  Hence the bits of [m] end up
   at even positions and the bits of [n] at odd positions of the result.

   The match enumerates every combination of the shapes of [m] and [n]
   (zero, one, even > 1, odd > 1).  Once a component has no bits left it is
   passed on as [N0].

   Termination is justified by the measure [size]: every recursive call
   drops the low bit of at least one non-zero component. *)
Function interleave (mn: N * N) {measure size mn} : N :=
  let '(m, n) := mn in
  match m, n with
  | N0, N0 => N0
  | N0, Npos 1 => Npos 2
  | N0, Npos (n'~0) => 4 * interleave (N0, Npos n')
  | N0, Npos (n'~1) => 2 + 4 * interleave (N0, Npos n')
  | Npos 1, N0 => Npos 1
  | Npos 1, Npos 1 => Npos 3
  | Npos 1, Npos (n'~0) => 1 + 4 * interleave (N0, Npos n')
  | Npos 1, Npos (n'~1) => 3 + 4 * interleave (N0, Npos n')
  | Npos (m'~0), N0 => 4 * interleave (Npos m', N0)
  | Npos (m'~0), Npos 1 => 2 + 4 * interleave (Npos m', N0)
  | Npos (m'~0), Npos (n'~0) => 4 * interleave (Npos m', Npos n')
  | Npos (m'~0), Npos (n'~1) => 2 + 4 * interleave (Npos m', Npos n')
  | Npos (m'~1), N0 => 1 + 4 * interleave (Npos m', N0)
  | Npos (m'~1), Npos 1 => 3 + 4 * interleave (Npos m', N0)
  | Npos (m'~1), Npos (n'~0) => 1 + 4 * interleave (Npos m', Npos n')
  | Npos (m'~1), Npos (n'~1) => 3 + 4 * interleave (Npos m', Npos n')
  end.
Proof.
  (* One obligation per recursive call (twelve in total).  Each one unfolds
     the measure, moves the statement to Z with [zify] and lets [lia]
     conclude. *)
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  (* This obligation additionally rewrites with [Pos.add_carry_spec] before
     calling [lia]. *)
  unfold size. intros. simpl. zify. rewrite Pos.add_carry_spec. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
  unfold size. intros. simpl. zify. Psatz.lia.
Defined.

(* [split n] undoes [interleave]: it reads [n] two bits at a time, sending
   the low bit of each pair to the first component and the high bit to the
   second component.  The values 0 to 3 are the base cases; for larger [n]
   the two lowest bits are stripped, the rest is split recursively and both
   results are doubled before the stripped bits are added back.

   Termination is by the measure [N.to_nat n]: the recursive argument has
   two bits fewer than [n]. *)
Function split (n: N) {measure N.to_nat n} : N * N :=
  match n with
  | 0 => (0, 0)
  | 1 => (1, 0)
  | 2 => (0, 1)
  | 3 => (1, 1)
  | Npos (n'~0~0) => let (x, y) := split (Npos n') in (2 * x, 2 * y)
  | Npos (n'~0~1) => let (x, y) := split (Npos n') in (1 + 2 * x, 2 * y)
  | Npos (n'~1~0) => let (x, y) := split (Npos n') in (2 * x, 1 + 2 * y)
  | Npos (n'~1~1) => let (x, y) := split (Npos n') in (1 + 2 * x, 1 + 2 * y)
  end.
Proof.
  (* One obligation per recursive call: the measure strictly decreases. *)
  intros; simpl; zify; Psatz.lia.
  intros; simpl; zify; Psatz.lia.
  intros; simpl; zify; Psatz.lia.
  intros; simpl; zify; Psatz.lia.
Defined.

(* Unfolding lemmas for [split], one per residue of the argument modulo 4.
   They restate the recursive equations of [split] in arithmetic form
   (k + 4 * n) so that they can be used as rewrite rules on the results of
   [interleave].  Each proof unfolds [split] once and then computes in both
   cases of [n] (zero or positive). *)

Lemma split_4n n : split (4 * n) = let '(x, y) := split n in (2 * x, 2 * y).
Proof. rewrite split_equation; destruct n; reflexivity. Qed.

Lemma split_4np1 n : split (1 + 4 * n) = let '(x, y) := split n in (1 + 2 * x, 2 * y).
Proof. rewrite split_equation; destruct n; reflexivity. Qed.

Lemma split_4np2 n : split (2 + 4 * n) = let '(x, y) := split n in (2 * x, 1 + 2 * y).
Proof. rewrite split_equation; destruct n; reflexivity. Qed.

Lemma split_4np3 n : split (3 + 4 * n) = let '(x, y) := split n in (1 + 2 * x, 1 + 2 * y).
Proof. rewrite split_equation; destruct n; reflexivity. Qed.

(* [split] is a left inverse of [interleave]. *)
Lemma split_interleave mn : split (interleave mn) = mn.
Proof.
  (* Functional induction on [interleave] (principle generated by
     [Function]).  Base cases compute; in each recursive case the matching
     [split_4n*] lemma exposes the recursive call, which is then replaced
     using the induction hypothesis [H]. *)
  apply interleave_ind; try reflexivity;
  intros; subst;
  first [ rewrite split_4n | rewrite split_4np1 | rewrite split_4np2 | rewrite split_4np3 ];
  rewrite H; auto.
Qed.

(* [interleave] is a left inverse of [split]; together with
   [split_interleave] this makes the two functions mutually inverse. *)
Lemma interleave_split n : interleave (split n) = n.
Proof.
  (* Functional induction on [split].  In the recursive cases, [e0] records
     the result (x, y) of the recursive call and [H] is the induction
     hypothesis; [interleave] is unfolded once and evaluated by case
     analysis on [x] and [y]. *)
  apply split_ind; try reflexivity;
  intros n0 n' e H x y e0; subst;
  rewrite e0 in H;
  rewrite interleave_equation;
  destruct x; destruct y; try discriminate; simpl; rewrite H; reflexivity.
Qed.

Unset Elimination Schemes.

Section ISO.

  (* [injection A B] packages a conversion function from [A] to [B] as a
     type class, so that the overloaded method [inj] can be resolved from
     the types involved.  The class itself carries no proof obligation;
     invertibility is stated separately by [bijection]. *)
  Class injection (A B: Type) := inj : A → B.

  (* [bijection B_of_A A_of_B] states that the two conversions are mutually
     inverse:
     - [forward]: converting an [A] to [B] and back yields the original;
     - [backward]: converting a [B] to [A] and back yields the original. *)
  Class bijection A B `(B_of_A: injection A B) `(A_of_B: injection B A) : Prop :=
    Bijection
      { forward : ∀ a : A, inj (inj a) = a ;
        backward: ∀ b : B, inj (inj b) = b }.

  (* A bijection read in the opposite direction: the two round-trip
     properties simply swap roles. *)
  Definition inverse_bijection A B `(bijection A B) : @bijection B A _ _ :=
    Bijection _ _ backward forward.

End ISO.

(* Composition of conversions: given a conversion from [A] to [B] and one
   from [B] to [C], build the conversion from [A] to [C] that goes through
   the intermediate type [B].  The intermediate type comes first in the
   argument list because it cannot be inferred from the result type. *)
Definition InjectionTrans B A C `(injection A B) `(injection B C) : injection A C
  := λ a, inj (inj a).
Arguments InjectionTrans B A C {_} {_} _.

(* Bijections compose: if [A] and [B] are in bijection and [B] and [C] are in
   bijection, then the composed conversions built by [InjectionTrans]
   (through [B], in both directions) are mutually inverse. *)
Instance BijectionTrans B A C `(bijection A B) `(bijection B C)
  : bijection (InjectionTrans B A C) (InjectionTrans B C A).
Proof.
  (* After unfolding, each round trip is a chain of four conversions that
     cancels pairwise using [forward] (resp. [backward]). *)
  split; intros; unfold inj, InjectionTrans. repeat rewrite forward; reflexivity.
  repeat rewrite backward; reflexivity.
Qed.

(* Conversion between [positive] and [N] by shifting by one:
   a positive [p] is sent to its predecessor in [N], and a natural [n] to its
   successor in [positive]. *)
Instance NOfPos : injection positive N := Pos.pred_N.
Instance PosOfN : injection N positive := N.succ_pos.

(* The two shifts above are mutually inverse. *)
Instance PosAsN : bijection PosOfN NOfPos.
Proof.
  (* N -> positive -> N is a standard library lemma. *)
  split. exact N.pos_pred_succ.
  (* positive -> N -> positive: by case analysis on the shape of [p]; only
     one case needs a lemma about [Pos.pred_double]. *)
  intros p. unfold inj, NOfPos, PosOfN.
  destruct p as [|p|]; try reflexivity.
  simpl. apply Pos.succ_pred_double.
Qed.

(* Encoding of integers as positive numbers, using the low bit as a sign tag:
   zero is sent to [xH], a positive integer [p] to [xO p], and a negative
   integer with magnitude [p] to [xI p]. *)
Instance PosOfZ : injection Z positive := λ z,
  match z with
  | Z0 => xH
  | Zpos p => xO p
  | Zneg p => xI p
  end.

(* Decoding, the exact mirror image of [PosOfZ]. *)
Instance ZOfPos : injection positive Z := λ p,
  match p with
  | xH => Z0
  | xO p => Zpos p
  | xI p => Zneg p
  end.

(* Both round trips hold by computation on each constructor. *)
Instance ZAsPos : bijection PosOfZ ZOfPos.
Proof.
  split; intros [| |]; reflexivity.
Qed.

(* Encoding of integers as natural numbers: zero is sent to zero, a positive
   integer [p] to the even number [2 * p], and a negative integer with
   magnitude [p] to the odd number [2 * p - 1]. *)
Instance NOfZ : injection Z N := λ z,
  match z with
  | Z0 => N0
  | Zpos p => Npos (xO p)
  | Zneg p => Npos (2 * p - 1)
  end.

(* Decoding of [NOfZ]: even numbers [2 * p] give the positive integer [p];
   odd numbers give negative integers, with [1] giving [-1] and [2 * p + 1]
   giving the negative integer of magnitude [p + 1]. *)
Instance ZOfN : injection N Z := λ p,
  match p with
  | N0 => Z0
  | Npos xH => Zneg xH
  | Npos (xO p) => Zpos p
  | Npos (xI p) => Zneg (p+1)
  end.

(* [NOfZ] and [ZOfN] are mutually inverse. *)
Instance ZAsN : bijection NOfZ ZOfN.
Proof.
  (* Z -> N -> Z: only the negative case is not immediate; it proceeds by
     induction on the magnitude. *)
  split. intros [| |]; try reflexivity. unfold inj.
  induction p as [|p IH|]; try reflexivity. simpl. f_equal.
  simpl in *. rewrite <- Pos.pred_sub in IH. simpl in IH.
  destruct (Pos.pred_double p); inv IH. zify. ring. reflexivity.
  (* N -> Z -> N: only the odd case [2 * p + 1] needs arithmetic, namely
     2 * (p + 1) - 1 = 2 * p + 1. *)
  intros [|p]. reflexivity.
  destruct p; try reflexivity.
  change (N.pos (2 * (p + 1) - 1) = N.pos (2*p+1)).
  rewrite Pos.mul_add_distr_l, <- Pos.add_sub_assoc by reflexivity. reflexivity.
Qed.

(* Encoding of a pair as a natural number: encode each component in [N],
   then merge the two codes with [interleave]. *)
Instance NOfPair A B `(injection A N) `(injection B N) : injection (A * B) N :=
  λ ab, let '(a, b) := ab in interleave (inj a, inj b).

(* Decoding of a pair: [split] the number into two codes, then decode each
   component. *)
Instance PairOfN A B `(injection N A) `(injection N B) : injection N (A * B) :=
  λ n, let '(a, b) := split n in (inj a, inj b).

Arguments NOfPair A B {_} {_} _.
Arguments PairOfN A B {_} {_} _.

(* If both component types are in bijection with [N], so is their product. *)
Instance PairBijection A B `(bijection A N) `(bijection B N) : bijection (NOfPair A B) (PairOfN A B).
Proof.
  split.
  (* pair -> N -> pair: [split_interleave] recovers the two codes, then the
     component bijections recover the components. *)
  intros (a, b). unfold inj, PairOfN. simpl. rewrite split_interleave.
  repeat rewrite forward. reflexivity.
  (* N -> pair -> N: rewrite [n] as the interleaving of its two halves, then
     cancel the component round trips. *)
  intros n. unfold inj, PairOfN, NOfPair.
  generalize (interleave_split n). destruct (split n) as (a, b).
  intros <- . repeat rewrite backward. reflexivity.
Qed.

(* Numbering of typed memory chunks with the codes 0 to 7.  A typed memory
   chunk is a dependent pair of a chunk and a proof; for any chunk not
   listed explicitly, the proof component [H] is eliminated by an empty
   match, so that case cannot occur. *)
Definition NOfMemoryChunk : MemChunkTree.typed_memory_chunk -> N := λ c,
  match c with
    | exist AST.Mint8signed _ => 0
    | exist AST.Mint8unsigned _ => 1
    | exist AST.Mint16signed _ => 2
    | exist AST.Mint16unsigned _ => 3
    | exist AST.Mint32 _ => 4
    | exist AST.Mint64 _ => 5
    | exist AST.Mfloat32 _ => 6
    | exist AST.Mfloat64 _ => 7
    | exist _ H => match H with end
  end%N.

(* Inverse numbering.  The codes 0 to 6 map back to their chunk; every other
   number (7 and above) maps to [AST.Mfloat64], so the function is total but
   only inverts [NOfMemoryChunk] on codes below 8. *)
Definition MemoryChunkOfN : N -> MemChunkTree.typed_memory_chunk := λ n,
  match n with
    | 0 => exist _ AST.Mint8signed I
    | 1 => exist _ AST.Mint8unsigned I
    | 2 => exist _ AST.Mint16signed I
    | 3 => exist _ AST.Mint16unsigned I
    | 4 => exist _ AST.Mint32 I
    | 5 => exist _ AST.Mint64 I
    | 6 => exist _ AST.Mfloat32 I
    | _ => exist _ AST.Mfloat64 I
  end%N.

(* Encoding of a cell as a natural number.  The residue modulo 3 identifies
   the constructor:
   - 0: [AClocal], encoded as 3 * (8 * code + chunk);
   - 1: [ACtemp],  encoded as 3 * code + 1;
   - 2: [ACglobal], encoded as 3 * (8 * code + chunk) + 2;
   where [code] is the pair/triple encoding of the non-chunk arguments and
   [chunk] is the memory chunk code given by [NOfMemoryChunk]. *)
Instance NOfCell : injection cell N := λ ab,
  (match ab with
  | AClocal i i' o c => 24 * inj (i, i', o) + 3 * NOfMemoryChunk c
  | ACtemp i r => 3 * inj (i, r) + 1
  | ACglobal b o c => 24 * inj (b, o) + 2 + 3 * NOfMemoryChunk c
  end)%N.

(* Decoding of [NOfCell].  A division by 3 yields the constructor tag [r];
   for local and global cells a further division of the quotient by 8
   separates the memory chunk code (remainder) from the code of the other
   arguments (quotient). *)
Instance CellOfN : injection N cell := λ n,
  let '(q, r) := N.div_eucl n 3 in
  (match r with
  | 0 =>
    let '(q', r') := N.div_eucl q 8 in
    let c := MemoryChunkOfN r' in
    let '(i, i', o) := inj q' in AClocal i i' o c
  | 1 => let '(i, r) := inj q in ACtemp i r
  (* Any other tag is decoded as a global cell; a remainder of a division
     by 3 other than 0 and 1 is 2. *)
  | _ =>
    let '(q', r') := N.div_eucl q 8 in
    let c := MemoryChunkOfN r' in
    let '(b, o) := inj q' in ACglobal b o c
  end)%N.

(* Conversions between cells and positive numbers, obtained by going
   through [N] with [InjectionTrans]. *)
Instance PosOfCell : injection cell positive := InjectionTrans N cell positive.
Instance CellOfPos : injection positive cell := InjectionTrans N positive cell.

(* Cells are in bijection with positive numbers.  The proof composes the
   bijection between cells and [N] (proved inline, first bullet) with
   [PosAsN] (second bullet). *)
Instance CellAsPos : bijection PosOfCell CellOfPos.
Proof.
  apply BijectionTrans.
  - split.
    (* cell -> N -> cell.  The specification of the division by 3 is
       combined with the shape of the encoding to determine the quotient
       and remainder, one constructor at a time. *)
    + unfold inj, CellOfN. intros a.
      generalize (N_div_eucl_spec (NOfCell a) 3).
      destruct N.div_eucl as (q, r).
      intros (Ha & Hr). lapply Hr; [intro|discriminate].
      unfold NOfCell in Ha.
      destruct a.
      (* Local cell: tag 0, then a division by 8 recovers the chunk code. *)
      * assert (r = 0 ∧ q = 8 * inj (f, l, ofs) + NOfMemoryChunk κ)%N as K.
        { remember (inj (f, l, ofs)). clear Heqy. abstract Psatz.lia. }
        destruct K; subst r q.
        generalize (N_div_eucl_spec (8 * inj (f, l, ofs) + NOfMemoryChunk κ) 8).
        destruct N.div_eucl as (q', r').
        intros (Ha' & Hr'). lapply Hr'; [intro|discriminate].
        assert (r' = NOfMemoryChunk κ ∧ q' = inj (f, l, ofs))%N as K.
        { assert (NOfMemoryChunk κ < 8)%N by (destruct κ as [[][]]; reflexivity).
          remember (inj (f, l, ofs)). clear Heqy. abstract Psatz.lia. }
        destruct K; subst r' q'.
        rewrite @forward. destruct κ as [[][]]; reflexivity.
        repeat apply PairBijection; eauto using inverse_bijection with typeclass_instances.
      (* Temporary: tag 1, the quotient is directly the pair code. *)
      * assert (r = 1 ∧ q = inj (f, r0))%N as K.
        { remember (inj (f, r0)). clear Heqy. abstract Psatz.lia. }
        destruct K; subst r q.
        rewrite @forward. auto.
        repeat apply PairBijection; eauto using inverse_bijection with typeclass_instances.
      (* Global cell: tag 2, same structure as the local case. *)
      * assert (r = 2 ∧ q = 8 * inj (b, ofs) + NOfMemoryChunk κ)%N as K.
        { remember (inj (b, ofs)). clear Heqy. abstract Psatz.lia. }
        destruct K; subst r q.
        generalize (N_div_eucl_spec (8 * inj (b, ofs) + NOfMemoryChunk κ) 8).
        destruct N.div_eucl as (q', r').
        intros (Ha' & Hr'). lapply Hr'; [intro|discriminate].
        assert (r' = NOfMemoryChunk κ ∧ q' = inj (b, ofs))%N as K.
        { assert (NOfMemoryChunk κ < 8)%N by (destruct κ as [[][]]; reflexivity).
          remember (inj (b, ofs)). clear Heqy. abstract Psatz.lia. }
        destruct K; subst r' q'.
        rewrite @forward. destruct κ as [[][]]; reflexivity.
        repeat apply PairBijection; eauto using inverse_bijection with typeclass_instances.
  (* N -> cell -> N.  Case analysis on the remainder [r] of the division by
     3: zero, then the three shapes of a positive remainder. *)
  + unfold inj, CellOfN. intros n.
    generalize (N_div_eucl_spec n 3).
    destruct N.div_eucl as (q, r).
    intros (Hb & Hr). lapply Hr; [intro|discriminate].
    unfold NOfCell.
    destruct r as [|[r|r|]].
    (* r = 0: decoded as a local cell. *)
    * generalize (N_div_eucl_spec q 8).
      destruct N.div_eucl as (q', r').
      intros (Hb' & Hr'). lapply Hr'; [intro|discriminate].
      destruct (inj q') as [[] ?] eqn: Q'.
      rewrite <- Q', @backward by eauto using inverse_bijection with typeclass_instances.
      assert (NOfMemoryChunk (MemoryChunkOfN r') = r').
      { abstract (destruct r' as [|[[[]|[|[]|]|]|[[]|[|[]|]|]|]]; try reflexivity; Psatz.lia). }
      abstract Psatz.lia.
    (* r is a positive number with low bit 1 and further bits: decoded by
       the catch-all branch as a global cell.  Presumably closed by [lia]
       thanks to the bound on the remainder. *)
    * generalize (N_div_eucl_spec q 8).
      destruct N.div_eucl as (q', r').
      intros (Hb' & Hr'). lapply Hr'; [intro|discriminate].
      destruct (inj q') eqn: Q'.
      rewrite <- Q', @backward by eauto using inverse_bijection with typeclass_instances.
      assert (NOfMemoryChunk (MemoryChunkOfN r') = r').
      { abstract (destruct r' as [|[[[]|[]|]|[[]|[]|]|]]; try reflexivity; Psatz.lia). }
      abstract Psatz.lia.
    (* r is a positive number with low bit 0 (this includes r = 2): decoded
       by the catch-all branch as a global cell. *)
    * generalize (N_div_eucl_spec q 8).
      destruct N.div_eucl as (q', r').
      intros (Hb' & Hr'). lapply Hr'; [intro|discriminate].
      destruct (inj q') eqn: Q'.
      rewrite <- Q', @backward by eauto using inverse_bijection with typeclass_instances.
      assert (NOfMemoryChunk (MemoryChunkOfN r') = r').
      { abstract (destruct r' as [|[[[]|[|[]|]|]|[[]|[|[]|]|]|]]; try reflexivity; Psatz.lia). }
      abstract Psatz.lia.
    (* r = 1: decoded as a temporary. *)
    * destruct (inj q) eqn: Q.
      rewrite <- Q, @backward by eauto using inverse_bijection with typeclass_instances.
      Psatz.lia.
  - apply PosAsN.
Qed.