feat: lemmas relating to index notation
This commit is contained in:
parent
ac11a510cf
commit
672cc1ed8b
11 changed files with 510 additions and 17 deletions
|
@ -108,6 +108,9 @@ syntax "(" tensorExpr ")" : tensorExpr
|
|||
/-- Scalar multiplication for tensors. -/
|
||||
syntax term "•" tensorExpr : tensorExpr
|
||||
|
||||
/-- Equality. -/
|
||||
syntax tensorExpr "=" tensorExpr : tensorExpr
|
||||
|
||||
namespace TensorNode
|
||||
|
||||
/-!
|
||||
|
@ -160,6 +163,17 @@ def stringToType (str : String) : TermElabM Expr := do
|
|||
| Except.error _ => throwError "Could not create type from string (stringToType). "
|
||||
| Except.ok stx => elabTerm stx none
|
||||
|
||||
/-- The construction of an expression corresponding to the type of a given string once parsed. -/
|
||||
def stringToTerm (str : String) : TermElabM Term := do
|
||||
let env ← getEnv
|
||||
let stx := Parser.runParserCategory env `term str
|
||||
match stx with
|
||||
| Except.error _ => throwError "Could not create type from string (stringToType). "
|
||||
| Except.ok stx =>
|
||||
match stx with
|
||||
| `(term| $e) => return e
|
||||
|
||||
|
||||
/-- The syntax associated with a terminal node of a tensor tree. -/
|
||||
def termNodeSyntax (T : Term) : TermElabM Term := do
|
||||
let expr ← elabTerm T none
|
||||
|
@ -167,6 +181,7 @@ def termNodeSyntax (T : Term) : TermElabM Term := do
|
|||
let strType := toString type
|
||||
let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length
|
||||
let const := (String.splitOn strType "Quiver.Hom").length
|
||||
println! "n: {n}, const: {const}"
|
||||
match n, const with
|
||||
| 1, 1 =>
|
||||
match type with
|
||||
|
@ -183,14 +198,26 @@ def termNodeSyntax (T : Term) : TermElabM Term := do
|
|||
| true => return Syntax.mkApp (mkIdent ``TensorTree.twoNodeE)
|
||||
#[mkIdent ``Fermion.complexLorentzTensor,
|
||||
mkIdent ``Fermion.Color.upL, mkIdent ``Fermion.Color.up, T]
|
||||
| _ => return Syntax.mkApp (mkIdent ``TensorTree.twoNode) #[T]
|
||||
| _ =>
|
||||
match ← isDefEq type (← stringToType "ModuleCat.carrier
|
||||
(Lorentz.complexContr ⊗ Lorentz.complexContr).V") with
|
||||
| true =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
|
||||
mkIdent ``Fermion.Color.up, mkIdent ``Fermion.Color.up, T]
|
||||
| _ =>
|
||||
match ← isDefEq type (← stringToType "ModuleCat.carrier
|
||||
(Lorentz.complexCo ⊗ Lorentz.complexCo).V") with
|
||||
| true =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``Fermion.complexLorentzTensor,
|
||||
mkIdent ``Fermion.Color.down, mkIdent ``Fermion.Color.down, T]
|
||||
| _ =>
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.twoNode) #[T]
|
||||
| 3, 1 => return Syntax.mkApp (mkIdent ``TensorTree.threeNode) #[T]
|
||||
| 1, 2 => return Syntax.mkApp (mkIdent ``TensorTree.constVecNode) #[T]
|
||||
| 2, 2 =>
|
||||
match ← isDefEq type (← stringToType
|
||||
"𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexCo ⊗ Lorentz.complexCo") with
|
||||
| true =>
|
||||
println! "here"
|
||||
return Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
|
||||
mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.down,
|
||||
mkIdent ``Fermion.Color.down, T]
|
||||
|
@ -237,11 +264,25 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
|||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return ind.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
def toPairs (l : List ℕ) : List (ℕ × ℕ) :=
|
||||
match l with
|
||||
| x1 :: x2 :: xs => (x1, x2) :: toPairs xs
|
||||
| [] => []
|
||||
| [x] => [(x, 0)]
|
||||
|
||||
def contrListAdjust (l : List (ℕ × ℕ)) : List (ℕ × ℕ ) :=
|
||||
let l' := l.bind (fun p => [p.1, p.2])
|
||||
let l'' := List.mapAccumr
|
||||
(fun x (prev : List ℕ) =>
|
||||
let e := prev.countP (fun y => y < x)
|
||||
(x :: prev, x - e)) l'.reverse []
|
||||
toPairs l''.2.reverse
|
||||
|
||||
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/
|
||||
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1),
|
||||
Syntax.mkNumLit (toString x0), mkIdent ``rfl, T']) T
|
||||
(contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x0),
|
||||
Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
|
||||
|
||||
/-- Creates the syntax associated with a tensor node. -/
|
||||
def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
|
@ -250,7 +291,7 @@ def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
let indices ← getIndices stx
|
||||
let rawIndex ← getNoIndicesExact T
|
||||
if indices.length ≠ rawIndex then
|
||||
throwError "The number of indices does not match the tensor {T}."
|
||||
throwError "The expected number of indices {rawIndex} does not match the tensor {T}."
|
||||
let tensorNodeSyntax ← termNodeSyntax T
|
||||
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) evalSyntax
|
||||
|
@ -303,8 +344,8 @@ def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
|||
|
||||
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/
|
||||
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), mkIdent ``rfl, T']) T
|
||||
(TensorNode.contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x0), Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T
|
||||
|
||||
/-- The syntax associated with a product of tensors. -/
|
||||
def prodSyntax (T1 T2 : Term) : Term :=
|
||||
|
@ -316,6 +357,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
| `(tensorExpr| $_:term | $[$args]*) => TensorNode.syntaxFull stx
|
||||
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
|
||||
let prodSyntax := prodSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
println! (← getContrPos stx)
|
||||
println! TensorNode.contrListAdjust (← getContrPos stx)
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) prodSyntax
|
||||
return contrSyntax
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
|
@ -323,6 +366,90 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
||||
end ProdNode
|
||||
|
||||
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
return (← TensorNode.withoutContr stx)
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
|
||||
return (← ProdNode.withoutContr stx)
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
namespace Equality
|
||||
|
||||
/-!
|
||||
|
||||
## For equality.
|
||||
|
||||
-/
|
||||
|
||||
/-- Gets the indices associated with the LHS of an equality. -/
|
||||
partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
/-- Gets the indices associated with the RHS of an equality. -/
|
||||
partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List (ℕ)) := do
|
||||
let l1' ← l1.mapM (fun x => indexToIdent x)
|
||||
let l2' ← l2.mapM (fun x => indexToIdent x)
|
||||
let l1enum := l1'.enum
|
||||
let l2'' := l2'.filterMap (fun x => l1enum.find? (fun y => Lean.TSyntax.getId y.2 = Lean.TSyntax.getId x))
|
||||
return l2''.map fun x => x.1
|
||||
|
||||
def finMapToEquiv (f1 f2 : Fin n → Fin n) (h : ∀ x, f1 (f2 x) = x := by decide)
|
||||
(h' : ∀ x, f2 (f1 x) = x := by decide) : Fin n ≃ Fin n where
|
||||
toFun := f1
|
||||
invFun := f2
|
||||
left_inv := h'
|
||||
right_inv := h
|
||||
|
||||
def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do
|
||||
let lPerm ← getPermutation l1 l2
|
||||
let l2Perm ← getPermutation l1 l2
|
||||
let permString := "![" ++ String.intercalate ", " (lPerm.map toString) ++ "]"
|
||||
let perm2String := "![" ++ String.intercalate ", " (l2Perm.map toString) ++ "]"
|
||||
let P1 ← TensorNode.stringToTerm permString
|
||||
let P2 ← TensorNode.stringToTerm perm2String
|
||||
let stx := Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2]
|
||||
return stx
|
||||
|
||||
def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do
|
||||
let X1 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[T1]
|
||||
let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax]
|
||||
let X2' := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2]
|
||||
let X2 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[X2']
|
||||
return Syntax.mkApp (mkIdent ``Eq) #[X1, X2]
|
||||
|
||||
/-- Creates the syntax associated with a tensor node. -/
|
||||
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => ProdNode.syntaxFull stx
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => ProdNode.syntaxFull stx
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← syntaxFull a)
|
||||
| `(tensorExpr| $a:tensorExpr = $b:tensorExpr) => do
|
||||
let indicesLeft ← getIndicesLeft stx
|
||||
let indicesRight ← getIndicesRight stx
|
||||
let permSyntax ← getPermutationSyntax indicesLeft indicesRight
|
||||
let equalSyntax ← equalSyntax permSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
return equalSyntax
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
||||
/-- An elaborator for tensor nodes. This is to be generalized. -/
|
||||
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
|
||||
let tensorExpr ← elabTerm (← syntaxFull stx) none
|
||||
|
@ -354,6 +481,6 @@ variable (𝓣 : TensorTree S c4)
|
|||
|
||||
#check {(T4 | i j l a ⊗ T5 | i j k c d) ⊗ T5 | i1 i2 i3 e f}ᵀ
|
||||
-/
|
||||
end ProdNode
|
||||
end Equality
|
||||
|
||||
end TensorTree
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue