(** * FastArith: [Z] and [N] wrappers whose operations are extracted to the OCaml module FastArithImpl. *)



Require Import Utf8 ZArith Util Coqlib ToString Recdef.

(** [Zfast] is a one-constructor wrapper around [Z].  In Coq it carries
    exactly a [Z]; at extraction time it is mapped to the OCaml type
    FastArithImpl.ZArith.t rather than to the default extraction of [Z].
    Building a value goes through FastArithImpl.encode_Z, and pattern
    matching goes through FastArithImpl.decode_Z. *)
Inductive Zfast :=
| Zwrap (x:Z).
Extract Inductive Zfast =>
  "FastArithImpl.ZArith.t" [ "FastArithImpl.encode_Z" ]
                           "(fun f z -> f (FastArithImpl.decode_Z z))".
(* Implicit conversion from [Z] into the wrapper. *)
Coercion Zwrap : Z >-> Zfast.

(** Projection back to [Z].  It is also declared as a coercion, so a [Zfast]
    can be used wherever a [Z] is expected; the definitions and proofs
    below rely on this. *)
Definition Zunwrap (x:Zfast) := let 'Zwrap x := x in x.
Coercion Zunwrap : Zfast >-> Z.

(** A [Zfast] is printed as its underlying [Z]. *)
Instance ZfastToString : ToString Zfast := fun z => to_string (Zunwrap z).

(** Arithmetic on [Zfast].  Each operation is specified in Coq by the
    corresponding [Z] function applied to the unwrapped arguments.  The
    [Zunwrap]/[Zwrap] coercions are inserted implicitly.  At extraction
    time each operation is replaced by the named OCaml function.  Nothing
    in this file checks that those OCaml functions agree with the [Z]
    specification; that agreement is a trusted assumption of the
    extraction. *)
Definition Zfastopp (a:Zfast) : Zfast := Z.opp a.
Extract Constant Zfastopp => "FastArithImpl.ZArith.neg".
Definition Zfastadd (a b:Zfast) : Zfast := Z.add a b.
Extract Constant Zfastadd => "FastArithImpl.ZArith.add".
Definition Zfastsub (a b:Zfast) : Zfast := Z.sub a b.
Extract Constant Zfastsub => "FastArithImpl.ZArith.sub".
Definition Zfastmul (a b:Zfast) : Zfast := Z.mul a b.
Extract Constant Zfastmul => "FastArithImpl.ZArith.mul".
(* Division.  [Z.quot] truncates toward zero; [Z.div] and [Z.modulo] round
   toward minus infinity.  In Coq these functions are total, for example
   [Z.div x 0 = 0].  They are extracted to FastArithImpl wrappers rather
   than directly to FastArithImpl.ZArith, presumably to reproduce that
   behaviour on a zero divisor.  NOTE(review): confirm in FastArithImpl. *)
Definition Zfastquot (a b:Zfast) : Zfast := Z.quot a b.
Extract Constant Zfastquot => "FastArithImpl.quot".
Definition Zfastdiv (a b:Zfast) : Zfast := Z.div a b.
Extract Constant Zfastdiv => "FastArithImpl.div".
Definition Zfastmod (a b:Zfast) : Zfast := Z.modulo a b.
Extract Constant Zfastmod => "FastArithImpl.modulo".
(* Bitwise operations, with the semantics of the corresponding [Z]
   functions. *)
Definition Zfastlnot (a:Zfast) : Zfast := Z.lnot a.
Extract Constant Zfastlnot => "FastArithImpl.ZArith.lognot".
Definition Zfastland (a b:Zfast) : Zfast := Z.land a b.
Extract Constant Zfastland => "FastArithImpl.ZArith.logand".
Definition Zfastlor (a b:Zfast) : Zfast := Z.lor a b.
Extract Constant Zfastlor => "FastArithImpl.ZArith.logor".
Definition Zfastlxor (a b:Zfast) : Zfast := Z.lxor a b.
Extract Constant Zfastlxor => "FastArithImpl.ZArith.logxor".
(* Shifts by a [Zfast] amount.  [Z.shiftl] and [Z.shiftr] also accept a
   negative amount, which shifts in the opposite direction.  The extracted
   FastArithImpl.shiftl/shiftr are expected to match that; NOTE(review):
   confirm. *)
Definition Zfastshl (a b:Zfast) : Zfast := Z.shiftl a b.
Extract Constant Zfastshl => "FastArithImpl.shiftl".
Definition Zfastshr (a b:Zfast) : Zfast := Z.shiftr a b.
Extract Constant Zfastshr => "FastArithImpl.shiftr".
(* Three-way comparison; the boolean tests on [Zfast] are derived from it. *)
Definition Zfastcompare (a b:Zfast) : comparison := Z.compare a b.
Extract Constant Zfastcompare => "FastArithImpl.compare".
(** Boolean [a <= b] on [Zfast], derived from [Zfastcompare]: every
    outcome except [Gt] counts as less-than-or-equal. *)
Definition Zfastleb (a b:Zfast) : bool :=
  match Zfastcompare a b with
  | Gt => false
  | _ => true
  end.
(** Boolean equality on [Zfast], derived from [Zfastcompare]. *)
Definition Zfasteqb (a b:Zfast) : bool :=
  match Zfastcompare a b with
  | Eq => true
  | _ => false
  end.
(** [Zfasteqb] computes the same boolean as [Z.eqb] on the unwrapped
    values.  Used by the [Zfasteqb_dec] obligations and by [fastunwrap]. *)
Lemma Zfasteqb_Zeqb:
  ∀ (a b:Zfast), Zfasteqb a b = Z.eqb a b.
Proof.
  unfold Zfasteqb, Zfastcompare. intros.
  (* Case on the comparison and on the equality test; contradictory
     combinations of the two are closed by lia. *)
  destruct (Z.compare_spec a b), (Z.eqb_spec a b); auto; Psatz.lia.
Qed.
(** Decidable equality on [Zfast], where the [EqDec] class is defined
    outside this file.  It is computed with [Zfasteqb], so the extracted
    decision procedure goes through [Zfastcompare] (FastArithImpl.compare).
    The two obligations below justify the [left] and [right] branches. *)
Program Instance Zfasteqb_dec : EqDec Zfast := fun a b =>
  match Zfasteqb a b with true => left _ | false => right _ end.
Next Obligation.
  (* [true] branch: turn the boolean into [Z.eqb], then case on
     [Z.eqb_spec] and unwrap both arguments. *)
  rewrite Zfasteqb_Zeqb in Heq_anonymous.
  destruct (Z.eqb_spec a b), a as [a], b as [b]; simpl in *; congruence.
Qed.
Next Obligation.
  (* [false] branch: same argument, concluding disequality. *)
  rewrite Zfasteqb_Zeqb in Heq_anonymous.
  destruct (Z.eqb_spec a b), a as [a], b as [b]; simpl in *; congruence.
Qed.
(** Maximum on [Zfast]: [b] when [a <= b], otherwise [a]. *)
Definition Zfastmax (a b:Zfast) : Zfast :=
  match Zfastleb a b with
  | true => b
  | false => a
  end.
(** [Zfastmax] agrees with the standard-library [Zmax]. *)
Lemma Zfastmax_Zmax:
  ∀ (a b:Zfast), Zfastmax a b = Zmax a b.
Proof.
  (* Unfold both sides to a match on [Z.compare], then case on the
     comparison and on the wrapped arguments. *)
  intros. unfold Zfastmax, Zfastleb, Zfastcompare, Zmax. simpl.
  destruct (Zcompare_spec a b), a, b; simpl in *; subst; auto.
Qed.
(** Minimum on [Zfast]: [a] when [a <= b], otherwise [b]. *)
Definition Zfastmin (a b:Zfast) : Zfast :=
  match Zfastleb a b with
  | true => a
  | false => b
  end.
(** [Zfastmin] agrees with the standard-library [Zmin]. *)
Lemma Zfastmin_Zmin:
  ∀ (a b:Zfast), Zfastmin a b = Zmin a b.
Proof.
  (* Same approach as for [Zfastmax_Zmax]: unfold to [Z.compare] and
     case on the comparison. *)
  intros. unfold Zfastmin, Zfastleb, Zfastcompare, Zmin. simpl.
  destruct (Zcompare_spec a b), a, b; simpl in *; subst; auto.
Qed.

(** [Nfast] wraps [N] the same way [Zfast] wraps [Z].  It is extracted to
    the same OCaml type, FastArithImpl.ZArith.t, using
    FastArithImpl.encode_N and FastArithImpl.decode_N.  Because the OCaml
    representation is shared, the conversion [Zfast_of_Nfast] can be
    extracted as the identity. *)
Inductive Nfast :=
| Nwrap (x:N).
(* Implicit conversion from [N] into the wrapper. *)
Coercion Nwrap : N >-> Nfast.
Extract Inductive Nfast =>
  "FastArithImpl.ZArith.t" [ "FastArithImpl.encode_N" ]
                           "(fun f z -> f (FastArithImpl.decode_N z))".
(** Projection back to [N]; also a coercion, so an [Nfast] can be used
    wherever an [N] is expected. *)
Definition Nunwrap (x:Nfast) := let 'Nwrap x := x in x.
Coercion Nunwrap : Nfast >-> N.

(** An [Nfast] is printed as its underlying [N]. *)
Instance NfastToString : ToString Nfast := fun n => to_string (Nunwrap n).

(** Arithmetic on [Nfast].  Each operation is specified by the
    corresponding [N] function.  Each is extracted to the same OCaml
    function as the matching [Zfast] operation, which is possible because
    both wrappers share the representation FastArithImpl.ZArith.t.  As for
    [Zfast], agreement between the OCaml code and the [N] specification is
    trusted, not proved.  No [Nfast] subtraction primitive is defined
    here. *)
(* Absolute value of a [Zfast], landing in [Nfast] (spec: [Z.abs_N]). *)
Definition Zfastabs (a:Zfast) : Nfast := Z.abs_N a.
Extract Constant Zfastabs => "FastArithImpl.ZArith.abs".
Definition Nfastadd (a b:Nfast) : Nfast := N.add a b.
Extract Constant Nfastadd => "FastArithImpl.ZArith.add".
Definition Nfastmul (a b:Nfast) : Nfast := N.mul a b.
Extract Constant Nfastmul => "FastArithImpl.ZArith.mul".
(* Division and remainder.  These are total in Coq ([N.div x 0 = 0]) and
   use the same FastArithImpl wrappers as [Zfastdiv]/[Zfastmod]. *)
Definition Nfastdiv (a b:Nfast) : Nfast := N.div a b.
Extract Constant Nfastdiv => "FastArithImpl.div".
Definition Nfastmod (a b:Nfast) : Nfast := N.modulo a b.
Extract Constant Nfastmod => "FastArithImpl.modulo".
Definition Nfastland (a b:Nfast) : Nfast := N.land a b.
Extract Constant Nfastland => "FastArithImpl.ZArith.logand".
Definition Nfastlor (a b:Nfast) : Nfast := N.lor a b.
Extract Constant Nfastlor => "FastArithImpl.ZArith.logor".
Definition Nfastlxor (a b:Nfast) : Nfast := N.lxor a b.
Extract Constant Nfastlxor => "FastArithImpl.ZArith.logxor".
Definition Nfastshl (a b:Nfast) : Nfast := N.shiftl a b.
Extract Constant Nfastshl => "FastArithImpl.shiftl".
Definition Nfastshr (a b:Nfast) : Nfast := N.shiftr a b.
Extract Constant Nfastshr => "FastArithImpl.shiftr".
(* Three-way comparison; the boolean tests on [Nfast] are derived from it. *)
Definition Nfastcompare (a b:Nfast) : comparison := N.compare a b.
Extract Constant Nfastcompare => "FastArithImpl.compare".
(** Boolean [a <= b] on [Nfast], derived from [Nfastcompare]: every
    outcome except [Gt] counts as less-than-or-equal. *)
Definition Nfastleb (a b:Nfast) : bool :=
  match Nfastcompare a b with
  | Gt => false
  | _ => true
  end.

(** Boolean equality on [Nfast], derived from [Nfastcompare]. *)
Definition Nfasteqb (a b:Nfast) : bool :=
  match Nfastcompare a b with
  | Eq => true
  | _ => false
  end.
(** [Nfasteqb] computes the same boolean as [N.eqb] on the unwrapped
    values.  Used by the termination proof of [eucl_alg], by its
    correctness lemmas, and by [fastunwrap]. *)
Lemma Nfasteqb_Neqb:
  ∀ (a b:Nfast), Nfasteqb a b = N.eqb a b.
Proof.
  unfold Nfasteqb, Nfastcompare. intros.
  (* Case on the comparison and on the equality test; contradictory
     combinations are closed by lia. *)
  destruct (N.compare_spec a b), (N.eqb_spec a b); auto; Psatz.lia.
Qed.

(** Maximum on [Nfast]: [b] when [a <= b], otherwise [a]. *)
Definition Nfastmax (a b:Nfast) : Nfast :=
  match Nfastleb a b with
  | true => b
  | false => a
  end.
(** [Nfastmax] agrees with the standard-library [Nmax]. *)
Lemma Nfastmax_Nmax:
  ∀ (a b:Nfast), Nfastmax a b = Nmax a b.
Proof.
  (* Unfold both sides to a match on [N.compare], then case on the
     comparison and on the wrapped arguments. *)
  intros. unfold Nfastmax, Nfastleb, Nfastcompare, Nmax. simpl.
  destruct (Ncompare_spec a b), a, b; simpl in *; subst; auto.
Qed.

(** Minimum on [Nfast]: [a] when [a <= b], otherwise [b]. *)
Definition Nfastmin (a b:Nfast) : Nfast :=
  match Nfastleb a b with
  | true => a
  | false => b
  end.
(** [Nfastmin] agrees with the standard-library [Nmin]. *)
Lemma Nfastmin_Nmin:
  ∀ (a b:Nfast), Nfastmin a b = Nmin a b.
Proof.
  (* Same approach as for [Nfastmax_Nmax]: unfold to [N.compare] and
     case on the comparison. *)
  intros. unfold Nfastmin, Nfastleb, Nfastcompare, Nmin. simpl.
  destruct (Ncompare_spec a b), a, b; simpl in *; subst; auto.
Qed.

(** Widening [Nfast] to [Zfast] (spec: [Z.of_N]).  Both wrappers are
    extracted to FastArithImpl.ZArith.t, so the extracted conversion is the
    identity function.  It is also declared as a coercion. *)
Definition Zfast_of_Nfast (a:Nfast) : Zfast := Z.of_N a.
Extract Constant Zfast_of_Nfast => "(fun x -> x)".
Coercion Zfast_of_Nfast : Nfast >-> Zfast.

(** Narrowing [Zfast] to [Nfast] (spec: [Z.to_N], which maps negative
    inputs to 0).  The OCaml function FastArithImpl.nofz is trusted to do
    the same; NOTE(review): confirm it clamps negatives rather than
    failing. *)
Definition Nfast_of_Zfast (a:Zfast) : Nfast := Z.to_N a.
Extract Constant Nfast_of_Zfast => "FastArithImpl.nofz".

(** Named [Zfast]/[Nfast] literals: [Fm1] is -1, [Fk] is the natural
    number k.  In the extracted code each becomes a top-level value,
    presumably so that a frequently used literal is encoded once instead of
    at every use; NOTE(review): confirm the motivation.  Proofs can unfold
    these names to recover the plain literals. *)
Definition Fm1 : Zfast := -1.
Definition F0 : Nfast := 0%N. Definition F20: Nfast := 20%N.
Definition F1 : Nfast := 1%N. Definition F21: Nfast := 21%N.
Definition F2 : Nfast := 2%N. Definition F22: Nfast := 22%N.
Definition F3 : Nfast := 3%N. Definition F23: Nfast := 23%N.
Definition F4 : Nfast := 4%N. Definition F24: Nfast := 24%N.
Definition F5 : Nfast := 5%N. Definition F25: Nfast := 25%N.
Definition F6 : Nfast := 6%N. Definition F26: Nfast := 26%N.
Definition F7 : Nfast := 7%N. Definition F27: Nfast := 27%N.
Definition F8 : Nfast := 8%N. Definition F28: Nfast := 28%N.
Definition F9 : Nfast := 9%N. Definition F29: Nfast := 29%N.
Definition F10: Nfast := 10%N. Definition F30: Nfast := 30%N.
Definition F11: Nfast := 11%N. Definition F31: Nfast := 31%N.
Definition F12: Nfast := 12%N. Definition F32: Nfast := 32%N.
Definition F13: Nfast := 13%N. Definition F33: Nfast := 33%N.
Definition F14: Nfast := 14%N. Definition F34: Nfast := 34%N.
Definition F15: Nfast := 15%N. Definition F35: Nfast := 35%N.
Definition F16: Nfast := 16%N. Definition F36: Nfast := 36%N.
Definition F17: Nfast := 17%N. Definition F37: Nfast := 37%N.
Definition F18: Nfast := 18%N. Definition F38: Nfast := 38%N.
Definition F19: Nfast := 19%N. Definition F39: Nfast := 39%N.
Definition F64: Nfast := 64%N.
Definition F255 : Nfast := 255%N.

(** [Nfastones n] is [2^n - 1], the number whose [n] low bits are set.  It
    is computed as [(1 shiftl n) - 1] with the [Zfast] primitives and then
    narrowed back to [Nfast].  Going through [Zfast] is presumably needed
    because no [Nfast] subtraction is provided.  [Nfastones_Nones] shows
    the result equals [N.ones n]. *)
Definition Nfastones (n:Nfast) : Nfast :=
  Nfast_of_Zfast (Zfastsub (Zfastshl F1 (Zfast_of_Nfast n)) F1).

Lemma Nfastones_Nones:
  ∀ (n:Nfast), Nfastones n = N.ones n.
Proof.
  intros.
  unfold Nfastones, N.ones, Nfast_of_Zfast, Zfastsub, Zfastshl, Zfast_of_Nfast, Zunwrap.
  (* Reduce to an equation on [Z] using injectivity of [Z.of_N]. *)
  f_equal. apply N2Z.inj.
  (* Rewrite both shifts as multiplication by a power of 2, and push
     [Z.of_N] through [N.pred], multiplication and exponentiation. *)
  rewrite N.shiftl_mul_pow2, Z.shiftl_mul_pow2, N2Z.inj_pred_max,
          N2Z.inj_mul, N2Z.inj_pow by Psatz.lia.
  (* Finish by case analysis on the value of [2^n]. *)
  simpl. destruct (2^Z.of_N n) as [|[]|]; auto.
Qed.

(** [Nfastldiff a b] clears in [a] the bits set in [b].  It is computed as
    [a land (lnot b)] with the [Zfast] primitives, where [lnot b] is
    negative, and then narrowed back to [Nfast].  [Nfastldiff_Nldiff]
    shows it equals [N.ldiff a b]. *)
Definition Nfastldiff (a b:Nfast) : Nfast :=
  Nfast_of_Zfast (Zfastland (Zfast_of_Nfast a) (Zfastlnot (Zfast_of_Nfast b))).

Lemma Nfastldiff_Nldiff:
  ∀ (a b:Nfast), Nfastldiff a b = N.ldiff a b.
Proof.
  (* Case on whether each argument is 0 or positive; [auto] closes the
     cases involving 0. *)
  intros [[|a]] [[|b]]; auto.
  (* Both arguments positive: unfold to the [Z]-level definitions and
     simplify down to the positive-number operations. *)
  unfold Nfastldiff, Zfastlnot, Zfastland, Zfast_of_Nfast, Nfast_of_Zfast. simpl.
  rewrite N2Z.id, Pos.add_comm. simpl. destruct b; auto. simpl.
  rewrite Pos.pred_double_succ. auto.
Qed.

(** Extended Euclidean algorithm on [Nfast].  [eucl_alg a b] returns a
    triple [(g, u, v)] with [g = u * a + v * b] (lemma [eucl_alg_bezout]).
    Its [g] is [N.gcd a b] (lemma [Nfastgcd_Ngcd]).  The coefficients are
    [Zfast] because they can be negative; they are built with
    [Zfastsub].  Each step recurses on [(b, a mod b)], and termination uses
    the measure [N.to_nat b]. *)
Function eucl_alg (a b:Nfast) {measure N.to_nat b} : Nfast * Zfast * Zfast :=
  (* Base case: a = 1 * a + 0 * 0. *)
  if Nfasteqb b F0 then (a, F1:Zfast, F0:Zfast)
  else
    let q := Nfastdiv a b in
    let r := Nfastmod a b in
    (* [bb] and [rr] are the coefficients of [b] and [r] in the recursive
       result, so g = bb * b + rr * r.  Since r = a - q * b, this gives
       g = rr * a + (bb - rr * q) * b. *)
    let '(g, bb, rr) := eucl_alg b r in
    (g, rr, Zfastsub bb (Zfastmul rr (Zfast_of_Nfast q))).
Proof.
  (* Termination: when [b] is nonzero, [a mod b < b]. *)
  intros. rewrite Nfasteqb_Neqb in teq.
  simpl in teq. apply N.eqb_neq in teq.
  pose proof N.mod_upper_bound a b teq.
  unfold Nfastmod. simpl in *. Psatz.lia.
Defined.

(** Bezout identity for [eucl_alg]: the returned coefficients express the
    returned [g] as an integer combination of the two inputs. *)
Lemma eucl_alg_bezout :
  ∀ (a b:Nfast) g aa bb, eucl_alg a b = (g, aa, bb) ->
    Z.of_N g = (aa * Z.of_N a + bb * Z.of_N b)%Z.
Proof.
  (* Strong induction on [N.to_nat b], generalizing over [a]. *)
  intros a [b] g. revert a. rewrite <- (Nnat.N2Nat.id b).
  induction (N.to_nat b) using lt_wf_ind. intros.
  (* Unfold one step of the algorithm and split on [b = 0]. *)
  rewrite eucl_alg_equation, Nfasteqb_Neqb in H0.
  simpl in H0. destruct (N.eqb_spec (N.of_nat n) 0).
  (* Base case: the result is (a, 1, 0). *)
  - inv H0. simpl Zunwrap. simpl Nunwrap. destruct n; inv e. ring.
  (* Recursive case: the remainder is smaller than [b], so the induction
     hypothesis applies to the recursive call.  Conclude with
     a = b * q + r ([N.div_mod']). *)
  - assert (N.to_nat (Nfastmod a (N.of_nat n)) < n)%nat.
    { pose proof N.mod_upper_bound a (N.of_nat n) n0.
      unfold Nfastmod. simpl in *. Psatz.lia. }
    specialize (H _ H1 (N.of_nat n)). unfold Nfastmod in *. simpl Nunwrap in *.
    rewrite Nnat.N2Nat.id in H.
    destruct eucl_alg as [[g' aa'] bb']. inv H0. erewrite H by reflexivity.
    unfold Zfastsub, Zfastmul, Zfast_of_Nfast, Nfastdiv.
    pose proof N.div_mod' a (N.of_nat n). apply (f_equal Z.of_N) in H0.
    rewrite H0. simpl. Psatz.lia.
Qed.

(** gcd on [Nfast]: the first component of [eucl_alg].  The Bezout
    coefficients are discarded. *)
Definition Nfastgcd (a b:Nfast) : Nfast := fst (fst (eucl_alg a b)).

(** [Nfastgcd] agrees with [N.gcd]. *)
Lemma Nfastgcd_Ngcd:
  ∀ (a b:Nfast), Nfastgcd a b = N.gcd a b.
Proof.
  intros. unfold Nfastgcd. destruct (eucl_alg a b) as [[g' aa'] bb'] eqn:EQ.
  symmetry. simpl. destruct g' as [g'].
  (* Use the divisor characterisation of gcd ([N.gcd_unique_alt]): d
     divides g' iff d divides both a and b. *)
  f_equal. apply N.gcd_unique_alt. Psatz.lia. split; intro.
  (* Divisors of g' divide a and b.  Use strong induction on
     [N.to_nat b], following the recursion of [eucl_alg]. *)
  - destruct a as [a], b as [b]. revert a aa' bb' EQ. rewrite <- (Nnat.N2Nat.id b).
    induction (N.to_nat b) using lt_wf_ind. intros.
    rewrite eucl_alg_equation in EQ. rewrite Nfasteqb_Neqb in EQ. simpl in EQ.
    destruct (N.eqb_spec (N.of_nat n) 0).
    + inv EQ. split. auto. exists 0%N. auto.
    + destruct (eucl_alg (N.of_nat n) (Nfastmod a (N.of_nat n))) as [[g aa] bb] eqn:EQ'.
      inv EQ.
      assert (N.to_nat (Nfastmod a (N.of_nat n)) < n)%nat.
      { pose proof N.mod_upper_bound a (N.of_nat n) n0.
        unfold Nfastmod. simpl in *. Psatz.lia. }
      specialize (H0 _ H1 (N.of_nat n)). unfold Nfastmod in *. simpl Nunwrap in *.
      rewrite Nnat.N2Nat.id in H0. specialize (H0 _ _ EQ').
      split. 2:apply H0. rewrite (N.div_mod' a (N.of_nat n)).
      apply N.divide_add_r. 2:apply H0. apply N.divide_mul_l, H0.
  (* Common divisors of a and b divide g'.  Use the Bezout identity
     g' = aa' * a + bb' * b, taking care of the sign of the [Z]
     combination when converting back to [N]. *)
  - destruct H. pose proof eucl_alg_bezout _ _ _ _ _ EQ.
    destruct H, H0. apply (f_equal Z.to_N) in H1.
    rewrite H, H0, !N2Z.inj_mul, !Z.mul_assoc, <- Z.mul_add_distr_r, N2Z.id in H1.
    destruct (Z.leb_spec 0 (aa' * Z.of_N x + bb' * Z.of_N x0)).
    + rewrite Z2N.inj_mul, N2Z.id in H1 by Psatz.lia. eexists. eauto.
    + simpl in H1. rewrite H1. destruct Z.add; try discriminate.
      destruct q; simpl; exists 0%N; auto.
Qed.

(** [fastunwrap] turns goals and hypotheses stated with the [Zfast] and
    [Nfast] wrappers into plain [Z]/[N] statements, so that ordinary
    tactics and stdlib lemmas apply.  It repeats the following until no
    branch applies:
    - cancel [Zunwrap (Zwrap _)] and [Nunwrap (Nwrap _)] round trips;
    - replace [Zfastleb]/[Nfastleb] by [Z.leb]/[N.leb] with [change]
      (they are convertible);
    - rewrite min, max, eqb, gcd, ones and ldiff with their correspondence
      lemmas;
    - otherwise unfold every remaining wrapper operation and every named
      constant.
    The unfold list covers every named constant defined in this file
    (Fm1, F0 to F39, F64, F255).  It also covers both three-way
    comparisons, [Zfastcompare] and [Nfastcompare]. *)
Ltac fastunwrap :=
  repeat match goal with
  | H: context [Zunwrap (Zwrap ?A)] |- _ => change (Zunwrap (Zwrap A)) with A in H
  | |- context [Zunwrap (Zwrap ?A)] => change (Zunwrap (Zwrap A)) with A
  | H: context [Zfastleb ?a ?b] |- _ => change (Zfastleb a b) with (Z.leb a b) in H
  | |- context [Zfastleb ?a ?b] => change (Zfastleb a b) with (Z.leb a b)
  | |- _ => progress rewrite ?Zfastmin_Zmin, ?Zfastmax_Zmax in *
  | |- _ => rewrite Zfasteqb_Zeqb in *
  | H: context [Nunwrap (Nwrap ?A)] |- _ => change (Nunwrap (Nwrap A)) with A in H
  | |- context [Nunwrap (Nwrap ?A)] => change (Nunwrap (Nwrap A)) with A
  | H: context [Nfastleb ?a ?b] |- _ => change (Nfastleb a b) with (N.leb a b) in H
  | |- context [Nfastleb ?a ?b] => change (Nfastleb a b) with (N.leb a b)
  | |- _ => progress rewrite ?Nfastmin_Nmin, ?Nfastmax_Nmax in *
  | |- _ => progress rewrite ?Nfasteqb_Neqb, ?Nfastgcd_Ngcd in *
  | |- _ => rewrite Nfastldiff_Nldiff in *
  | |- _ => rewrite Nfastones_Nones in *
  | |- _ =>
    (* Fallback: expose the [Z]/[N] specification of every remaining
       primitive.  F30 to F39 and Nfastcompare were previously missing
       here, so goals mentioning them stayed folded. *)
    progress unfold Zfastcompare, Zfastopp, Zfastabs, Zfastadd, Zfastsub,
    Zfastmul, Zfastlnot, Zfastland, Zfastlor, Zfastlxor, Zfastshl, Zfastshr,
    Zfastquot, Zfastdiv, Zfastmod,
    Nfastcompare, Nfastadd, Nfastmul, Nfastdiv, Nfastmod,
    Nfastland, Nfastlor, Nfastlxor,
    Nfastshl, Nfastshr, Zfast_of_Nfast, Nfast_of_Zfast,
     Fm1, F0, F1, F2, F3, F4, F5, F6, F7, F8, F9,
     F10, F11, F12, F13, F14, F15, F16, F17, F18, F19,
     F20, F21, F22, F23, F24, F25, F26, F27, F28, F29,
     F30, F31, F32, F33, F34, F35, F36, F37, F38, F39,
     F64, F255 in *
  end.