/- Source: PhysLean/HepLean/Tensors/IndexNotation/IndexList/Basic.lean
   (snapshot dated 2024-10-03 13:50:18 +00:00; 353 lines, 12 KiB).
   The web viewer flagged "ambiguous Unicode characters" in this file; characters such as
   `ᵘ` and `ᵤ` in index names are intentional Lean notation. -/
/-
Copyright (c) 2024 Joseph Tooby-Smith. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joseph Tooby-Smith
-/
import Mathlib.Data.Set.Finite
import Mathlib.Logic.Equiv.Fin
import Mathlib.Data.Finset.Sort
import HepLean.Tensors.IndexNotation.Basic
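/-!
# Index lists

An `IndexList X` is a list of indices `Index X`. This file defines basic operations on index
lists (`cons`, `tail`, `head`, appending with `++`) and the maps `colorMap` and `idMap`, which
send each position of an index list to the color and to the id of the index at that position.
-/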
namespace IndexNotation
-- Section variables: `X` is the type of colors of indices (see `colorMap` below).
-- The order of these binders fixes the argument order of the declarations that use them.
variable (X : Type) [IndexNotation X]
variable [Fintype X] [DecidableEq X]
/-- The type of lists of indices: a structure wrapping a `List (Index X)`. -/
structure IndexList where
  /-- The list of index values. For example `['ᵘ¹','ᵘ²','ᵤ₁']`. -/
  val : List (Index X)
namespace IndexList
-- Inside `IndexList`, the color type `X` is implicit and `l` is the index list that most
-- of the following declarations act on.
variable {X : Type} [IndexNotation X] [Fintype X] [DecidableEq X]
variable (l : IndexList X)
/-- The number of indices in an index list, i.e. the length of the underlying list. -/
def length : ℕ := l.val.length
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- Two index lists are equal if their underlying lists of indices are equal.
The second list `l2` is bound explicitly rather than relying on auto-bound implicits. -/
lemma ext {l2 : IndexList X} (h : l.val = l2.val) : l = l2 := by
  cases l
  cases l2
  simp_all
/-- Prepend the index `i` to the front of the index list `l`. -/
def cons (i : Index X) : IndexList X := ⟨i :: l.val⟩
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- The underlying list of `l.cons i` is `i` followed by the underlying list of `l`. -/
@[simp]
lemma cons_val (i : Index X) : (l.cons i).val = i :: l.val := rfl
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- Prepending an index increases the length by one. -/
@[simp]
lemma cons_length (i : Index X) : (l.cons i).length = l.length + 1 := rfl
/-- The tail of an index list: the index list obtained by dropping the first index
(computed with `List.tail` on the underlying list). -/
def tail : IndexList X := ⟨l.val.tail⟩
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- The underlying list of `l.tail` is the tail of the underlying list of `l`. -/
@[simp]
lemma tail_val : l.tail.val = l.val.tail := rfl
/-- The first index in a non-empty index list. The hypothesis `h` is turned into the
non-emptiness proof on the underlying list that `List.head` requires. -/
def head (h : l ≠ {val := ∅}) : Index X := l.val.head (by cases' l; simpa using h)
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- A non-empty index list equals its tail with its head prepended. -/
lemma head_cons_tail (h : l ≠ {val := ∅}) : l = (l.tail.cons (l.head h)) := by
  apply ext
  simp only [cons_val, tail_val]
  simp only [head, List.head_cons_tail]
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- Induction principle for index lists: a property that holds for the empty index list and
is preserved by `cons` holds for every index list. Proved by induction on the underlying
list. -/
lemma induction {P : IndexList X → Prop } (h_nil : P {val := ∅})
    (h_cons : ∀ (x : Index X) (xs : IndexList X), P xs → P (xs.cons x)) (l : IndexList X) : P l := by
  cases' l with val
  induction val with
  | nil => exact h_nil
  | cons x xs ih =>
    exact h_cons x ⟨xs⟩ ih
/-- The map from `Fin l.length` into colors associated with an index list: position `i` is
sent to the color (`toColor`) of the `i`-th index of `l`. -/
def colorMap : Fin l.length → X :=
  fun i => (l.val.get i).toColor
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- Transport of `colorMap` along an equality of index lists: the color maps agree once
positions are cast along the induced equality of lengths. -/
lemma colorMap_cast {l1 l2 : IndexList X} (h : l1 = l2) :
    l1.colorMap = l2.colorMap ∘ Fin.cast (congrArg length h) := by
  subst h
  rfl
/-- The map from `Fin l.length` into the natural numbers associated with an index list:
position `i` is sent to the id of the `i`-th index of `l`. -/
def idMap : Fin l.length → Nat :=
  fun i => (l.val.get i).id
-- No `in`: this `omit` applies to all following declarations of the current scope,
-- not only to the next one.
omit [IndexNotation X] [Fintype X] [DecidableEq X]
/-- Transport of `idMap` along an equality of index lists: the ids agree once the position
is cast along the induced equality of lengths. -/
lemma idMap_cast {l1 l2 : IndexList X} (h : l1 = l2) (i : Fin l1.length) :
    l1.idMap i = l2.idMap (Fin.cast (by rw [h]) i) := by
  subst h
  rfl
/-- Two index lists of equal length are equal if their `idMap`s and their `colorMap`s agree
after casting positions along the equality of lengths. -/
lemma ext_colorMap_idMap {l l2 : IndexList X} (hl : l.length = l2.length)
    (hi : l.idMap = l2.idMap ∘ Fin.cast hl) (hc : l.colorMap = l2.colorMap ∘ Fin.cast hl) :
    l = l2 := by
  apply ext
  refine List.ext_get hl ?h.h
  intro n h1 h2
  -- Split equality of the `n`-th indices into equality of colors and equality of ids.
  rw [Index.eq_iff_color_eq_and_id_eq]
  apply And.intro
  · trans l.colorMap ⟨n, h1⟩
    · rfl
    · rw [hc]
      rfl
  · trans l.idMap ⟨n, h1⟩
    · rfl
    · rw [hi]
      rfl
/-- Given an index list `l`, the subset of `Fin l.length × Index X` consisting of the pairs
`(i, l.val.get i)` of a position `i` in `l` and the index at that position. -/
def toPosSet (l : IndexList X) : Set (Fin l.length × Index X) :=
  {(i, l.val.get i) | i : Fin l.length}
/-- The equivalence between `toPosSet` and `Fin l.length`: a pair is sent to its position,
and a position `i` is sent back to the pair `(i, l.val.get i)`. -/
def toPosSetEquiv (l : IndexList X) : l.toPosSet ≃ Fin l.length where
  toFun := fun x => x.1.1
  invFun := fun x => ⟨(x, l.val.get x), by simp [toPosSet]⟩
  left_inv x := by
    -- Membership in `toPosSet` provides a position `i` with `(i, l.val.get i) = x.1`.
    have hx := x.prop
    simp only [toPosSet, List.get_eq_getElem, Set.mem_setOf_eq] at hx
    simp only [List.get_eq_getElem]
    obtain ⟨i, hi⟩ := hx
    -- That position is the first component of `x`.
    have hi2 : i = x.1.1 := by
      obtain ⟨val, property⟩ := x
      obtain ⟨fst, snd⟩ := val
      simp_all only [Prod.mk.injEq]
    subst hi2
    simp_all only [Subtype.coe_eta]
  right_inv := by
    intro x
    rfl
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- `toPosSet` is finite, since it is in bijection with `Fin l.length` via `toPosSetEquiv`. -/
lemma toPosSet_is_finite (l : IndexList X) : l.toPosSet.Finite :=
  Finite.intro l.toPosSetEquiv
/-- `toPosSet` is a finite type: its elements are the images of all positions
`Fin l.length` under `toPosSetEquiv.symm`. -/
instance : Fintype l.toPosSet where
  elems := Finset.map l.toPosSetEquiv.symm.toEmbedding Finset.univ
  complete := by
    intro x
    simp_all only [Finset.mem_map_equiv, Equiv.symm_symm, Finset.mem_univ]
/-- The finite set of pairs `(i, l.val.get i)` of a position `i` in `l` and the index at
that position, obtained by converting `toPosSet` into a `Finset`. -/
def toPosFinset (l : IndexList X) : Finset (Fin l.length × Index X) :=
  Set.toFinset l.toPosSet
/-- The construction of an index list from a map `f : Fin n → Index X`, obtained by
mapping `f` over `Fin.list n`. -/
def fromFinMap {n : ℕ} (f : Fin n → Index X) : IndexList X where
  val := (Fin.list n).map f
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- The index list `fromFinMap f` built from `f : Fin n → Index X` has length `n`. -/
@[simp]
lemma fromFinMap_numIndices {n : ℕ} (f : Fin n → Index X) :
    (fromFinMap f).length = n := by
  simp [fromFinMap, length]
/-!
## Appending index lists.
-/
section append
-- These redeclare `X` with its instances and introduce `l l2 l3` for this section;
-- they shadow the namespace-level `X` and `l` until `end append`.
variable {X : Type} [IndexNotation X] [Fintype X] [DecidableEq X]
variable (l l2 l3 : IndexList X)
/-- Appending index lists concatenates their underlying lists of indices. -/
instance : HAppend (IndexList X) (IndexList X) (IndexList X) where
  hAppend l l2 := ⟨l.val ++ l2.val⟩
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- Prepending an index and then appending `l2` is the same as appending first and
prepending afterwards. -/
@[simp]
lemma cons_append (i : Index X) : (l.cons i) ++ l2 = (l ++ l2).cons i := rfl
-- No `in`: this `omit` applies to all following declarations of this section.
omit [IndexNotation X] [Fintype X] [DecidableEq X]
/-- The length of `l ++ l2` is the sum of the lengths of `l` and `l2`. -/
@[simp]
lemma append_length : (l ++ l2).length = l.length + l2.length := by
  simpa only [length] using List.length_append l.val l2.val
/-- Appending index lists is associative. -/
lemma append_assoc : l ++ l2 ++ l3 = l ++ (l2 ++ l3) := by
  apply ext
  -- On the underlying lists this is exactly `List.append_assoc`.
  change l.val ++ l2.val ++ l3.val = l.val ++ (l2.val ++ l3.val)
  exact List.append_assoc l.val l2.val l3.val
/-- An equivalence between the sum of the types of positions of `l` and `l2`
(`Fin l.length ⊕ Fin l2.length`) and the type of positions of the appended index list
`l ++ l2`. It is `finSumFinEquiv` followed by a cast along `List.length_append`. -/
def appendEquiv {l l2 : IndexList X} : Fin l.length ⊕ Fin l2.length ≃ Fin (l ++ l2).length :=
  finSumFinEquiv.trans (Fin.castOrderIso (List.length_append _ _).symm).toEquiv
/-- The inclusion of the positions of `l` into the positions of `l ++ l2`, given by
`appendEquiv ∘ Sum.inl`. -/
def appendInl : Fin l.length ↪ Fin (l ++ l2).length where
  toFun := appendEquiv ∘ Sum.inl
  -- Injective as an equivalence composed with the injective `Sum.inl`
  -- (same proof shape as `appendInr`).
  inj' i j h := by
    simpa only [Function.comp, EmbeddingLike.apply_eq_iff_eq, Sum.inl.injEq] using h
/-- The inclusion of the positions of `l2` into the positions of `l ++ l2`, given by
`appendEquiv ∘ Sum.inr`. -/
def appendInr : Fin l2.length ↪ Fin (l ++ l2).length where
  toFun := appendEquiv ∘ Sum.inr
  -- Injective as an equivalence composed with the injective `Sum.inr`.
  inj' i j h := by
    simpa only [Function.comp, EmbeddingLike.apply_eq_iff_eq, Sum.inr.injEq] using h
/-- Following `appendInl` by the inverse of `appendEquiv` gives the embedding `Sum.inl`. -/
@[simp]
lemma appendInl_appendEquiv :
    (l.appendInl l2).trans appendEquiv.symm.toEmbedding =
    {toFun := Sum.inl, inj' := Sum.inl_injective} := by
  ext i
  simp [appendInl]
/-- Following `appendInr` by the inverse of `appendEquiv` gives the embedding `Sum.inr`. -/
@[simp]
lemma appendInr_appendEquiv :
    (l.appendInr l2).trans appendEquiv.symm.toEmbedding =
    {toFun := Sum.inr, inj' := Sum.inr_injective} := by
  ext i
  simp [appendInr]
/-- The underlying list of `l ++ l2` is the concatenation of the underlying lists. -/
@[simp]
lemma append_val {l l2 : IndexList X} : (l ++ l2).val = l.val ++ l2.val := rfl
/-- The id at position `appendEquiv (Sum.inl i)` of `l ++ l2` is the id at position `i`
of `l`. -/
@[simp]
lemma idMap_append_inl {l l2 : IndexList X} (i : Fin l.length) :
    (l ++ l2).idMap (appendEquiv (Sum.inl i)) = l.idMap i := by
  simp only [idMap, append_val, appendEquiv, Equiv.trans_apply, finSumFinEquiv_apply_left,
    List.get_eq_getElem]
  -- A position in the left part of `l.val ++ l2.val` reads from `l.val`.
  rw [List.getElem_append_left]
  rfl
/-- The id at position `appendEquiv (Sum.inr i)` of `l ++ l2` is the id at position `i`
of `l2`. -/
@[simp]
lemma idMap_append_inr {l l2 : IndexList X} (i : Fin l2.length) :
    (l ++ l2).idMap (appendEquiv (Sum.inr i)) = l2.idMap i := by
  simp only [idMap, append_val, length, appendEquiv, Equiv.trans_apply, finSumFinEquiv_apply_right,
    RelIso.coe_fn_toEquiv, Fin.castOrderIso_apply, List.get_eq_getElem, Fin.coe_cast,
    Fin.coe_natAdd]
  -- Read from `l2.val` after shifting past `l.val`; the remaining goals (presumably
  -- index-bound side conditions) are closed by `omega`.
  rw [List.getElem_append_right]
  · simp only [Nat.add_sub_cancel_left]
  · omega
  · omega
/-- The color at position `appendEquiv (Sum.inl i)` of `l ++ l2` is the color at
position `i` of `l`. -/
@[simp]
lemma colorMap_append_inl {l l2 : IndexList X} (i : Fin l.length) :
    (l ++ l2).colorMap (appendEquiv (Sum.inl i)) = l.colorMap i := by
  simp only [colorMap, append_val, length, appendEquiv, Equiv.trans_apply,
    finSumFinEquiv_apply_left, RelIso.coe_fn_toEquiv, Fin.castOrderIso_apply, List.get_eq_getElem,
    Fin.coe_cast, Fin.coe_castAdd]
  -- A position in the left part of `l.val ++ l2.val` reads from `l.val`.
  rw [List.getElem_append_left]
/-- Function-level form of `colorMap_append_inl`. -/
@[simp]
lemma colorMap_append_inl' :
    (l ++ l2).colorMap ∘ appendEquiv ∘ Sum.inl = l.colorMap := by
  funext i
  simp
/-- The color at position `appendEquiv (Sum.inr i)` of `l ++ l2` is the color at
position `i` of `l2`. -/
@[simp]
lemma colorMap_append_inr {l l2 : IndexList X} (i : Fin l2.length) :
    (l ++ l2).colorMap (appendEquiv (Sum.inr i)) = l2.colorMap i := by
  simp only [colorMap, append_val, length, appendEquiv, Equiv.trans_apply,
    finSumFinEquiv_apply_right, RelIso.coe_fn_toEquiv, Fin.castOrderIso_apply, List.get_eq_getElem,
    Fin.coe_cast, Fin.coe_natAdd]
  -- Read from `l2.val` after shifting past `l.val`; the remaining goals (presumably
  -- index-bound side conditions) are closed by `omega`.
  rw [List.getElem_append_right]
  · simp only [Nat.add_sub_cancel_left]
  · omega
  · omega
/-- Function-level form of `colorMap_append_inr`. -/
@[simp]
lemma colorMap_append_inr' :
    (l ++ l2).colorMap ∘ appendEquiv ∘ Sum.inr = l2.colorMap := by
  funext i
  simp
/-- `Sum.elim` of the color maps of `l1` and `l2` is the color map of `l1 ++ l2` composed
with `appendEquiv`. The name is spelled `sumELim` and is kept unchanged so that existing
references still resolve. -/
lemma colorMap_sumELim (l1 l2 : IndexList X) :
    Sum.elim l1.colorMap l2.colorMap =
    (l1 ++ l2).colorMap ∘ appendEquiv := by
  funext x
  -- Each side of the sum is handled by `colorMap_append_inl` / `colorMap_append_inr`.
  cases x <;> simp
end append
/-!
## Filter id
-/
/-! TODO: Replace with Mathlib lemma. -/
/-- Filtering the sorted list of a `Finset (Fin n)` by `p` gives the sorted list of the
filtered `Finset`. -/
lemma filter_sort_comm {n : ℕ} (s : Finset (Fin n)) (p : Fin n → Prop) [DecidablePred p] :
    List.filter p (Finset.sort (fun i j => i ≤ j) s) =
    Finset.sort (fun i j => i ≤ j) (Finset.filter p s) := by
  simp only [Finset.sort, Finset.filter]
  -- Prove the statement for all multisets, by lifting from lists with `Quot.ind`.
  have : ∀ (m : Multiset (Fin n)), List.filter p (Multiset.sort (fun i j => i ≤ j) m) =
      Multiset.sort (fun i j => i ≤ j) (Multiset.filter p m) := by
    apply Quot.ind
    intro m
    simp only [Multiset.quot_mk_to_coe'', Multiset.coe_sort, Multiset.filter_coe]
    -- Filtering a sorted list keeps it sorted.
    have h1 : List.Sorted (fun i j => i ≤ j) (List.filter (fun b => decide (p b))
        (List.mergeSort' (fun i j => i ≤ j) m)) := by
      simp only [List.Sorted]
      rw [List.pairwise_filter, List.pairwise_iff_get]
      intro i j h1 _ _
      have hs : List.Sorted (fun i j => i ≤ j) (List.mergeSort' (fun i j => i ≤ j) m) := by
        exact List.sorted_mergeSort' (fun i j => i ≤ j) m
      simp only [List.Sorted] at hs
      rw [List.pairwise_iff_get] at hs
      exact hs i j h1
    -- Both sides are permutations of `List.filter p m`.
    have hp1 : (List.mergeSort' (fun i j => i ≤ j) m).Perm m := by
      exact List.perm_mergeSort' (fun i j => i ≤ j) m
    have hp2 : (List.filter (fun b => decide (p b)) ((List.mergeSort' (fun i j => i ≤ j) m))).Perm
        (List.filter (fun b => decide (p b)) m) := by
      exact List.Perm.filter (fun b => decide (p b)) hp1
    have hp3 : (List.filter (fun b => decide (p b)) m).Perm
        (List.mergeSort' (fun i j => i ≤ j) (List.filter (fun b => decide (p b)) m)) := by
      exact List.Perm.symm (List.perm_mergeSort' (fun i j => i ≤ j)
        (List.filter (fun b => decide (p b)) m))
    have hp4 := hp2.trans hp3
    -- Two sorted permutations of each other (w.r.t. `≤`) are equal.
    refine List.eq_of_perm_of_sorted hp4 h1 ?_
    exact List.sorted_mergeSort' (fun i j => i ≤ j) (List.filter (fun b => decide (p b)) m)
  exact this s.val
omit [IndexNotation X] [Fintype X] [DecidableEq X] in
/-- Filtering `l.val` for the indices whose id equals the id of the `i`-th index is the same
as mapping `l.val.get` over the sorted positions `j` with `l.idMap i = l.idMap j`. -/
lemma filter_id_eq_sort (i : Fin l.length) : l.val.filter (fun J => (l.val.get i).id = J.id) =
    List.map l.val.get (Finset.sort (fun i j => i ≤ j)
    (Finset.filter (fun j => l.idMap i = l.idMap j) Finset.univ)) := by
  have h1 := (List.finRange_map_get l.val).symm
  -- Write `l.val` as `l.val.get` mapped over the sorted list of all positions.
  have h2 : l.val = List.map l.val.get (Finset.sort (fun i j => i ≤ j) Finset.univ) := by
    nth_rewrite 1 [h1, (Fin.sort_univ l.val.length).symm]
    rfl
  nth_rewrite 3 [h2]
  rw [List.filter_map]
  apply congrArg
  -- Move the filter inside the sort, then match the two predicates pointwise.
  rw [← filter_sort_comm]
  apply List.filter_congr
  intro x _
  rfl
end IndexList
end IndexNotation