feat: Fix elab permutations

This commit is contained in:
jstoobysmith 2024-10-22 16:43:32 +00:00
parent b792be4423
commit a6cd796df2

View file

@ -109,7 +109,7 @@ syntax tensorExpr "+" tensorExpr : tensorExpr
syntax "(" tensorExpr ")" : tensorExpr
/-- Scalar multiplication for tensors. -/
syntax term "•" tensorExpr : tensorExpr
syntax term "•" tensorExpr : tensorExpr
/-- Negation of a tensor tree. -/
syntax "-" tensorExpr : tensorExpr
@ -207,6 +207,16 @@ def specialTypes : List (String × (Term → Term)) := [
Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[
mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.up,
mkIdent ``Fermion.Color.upL,
mkIdent ``Fermion.Color.upR, T]),
("𝟙_ (Rep SL(2, )) ⟶ Fermion.leftHanded ⊗ Fermion.leftHanded", fun T =>
Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
mkIdent ``Fermion.complexLorentzTensor,
mkIdent ``Fermion.Color.upL,
mkIdent ``Fermion.Color.upL, T]),
("𝟙_ (Rep SL(2, )) ⟶ Fermion.rightHanded ⊗ Fermion.rightHanded", fun T =>
Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
mkIdent ``Fermion.complexLorentzTensor,
mkIdent ``Fermion.Color.upR,
mkIdent ``Fermion.Color.upR, T])]
/-- The syntax associated with a terminal node of a tensor tree. -/
@ -266,8 +276,7 @@ partial def getContrPos (stx : Syntax) : TermElabM (List ( × )) := do
return filt
/-- The list of indices after contraction or evaluation. -/
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
let ind ← getIndices stx
def withoutContr (ind : List (TSyntax `indexExpr)) : TermElabM (List (TSyntax `indexExpr)) := do
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
return indFilt.filter (fun x => indFilt.count x ≤ 1)
@ -312,10 +321,10 @@ For a product node we can take the tensor product, and then contract the indices
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => do
return (← TensorNode.withoutContr stx)
return (← TensorNode.withoutContr (← TensorNode.getIndices stx))
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
let indicesA ← getIndices a
let indicesB ← getIndices b
let indicesA ← TensorNode.withoutContr (← getIndices a)
let indicesB ← TensorNode.withoutContr (← getIndices b)
return indicesA ++ indicesB
| `(tensorExpr| ($a:tensorExpr)) => do
return (← getIndices a)
@ -361,8 +370,9 @@ def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List ())
return l2''.map fun x => x.1
/-- Takes two maps `Fin n → Fin n` and returns the equivelance they form. -/
def finMapToEquiv (f1 f2 : Fin n → Fin n) (h : ∀ x, f1 (f2 x) = x := by decide)
(h' : ∀ x, f2 (f1 x) = x := by decide) : Fin n ≃ Fin n where
def finMapToEquiv (f1 : Fin n → Fin m) (f2 : Fin m → Fin n)
(h : ∀ x, f1 (f2 x) = x := by decide)
(h' : ∀ x, f2 (f1 x) = x := by decide) : Fin n ≃ Fin m where
toFun := f1
invFun := f2
left_inv := h'
@ -371,7 +381,7 @@ def finMapToEquiv (f1 f2 : Fin n → Fin n) (h : ∀ x, f1 (f2 x) = x := by deci
/-- Given two lists of indices returns the permutation between them based on `finMapToEquiv`. -/
def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do
let lPerm ← getPermutation l1 l2
let l2Perm ← getPermutation l1 l2
let l2Perm ← getPermutation l2 l1
let permString := "![" ++ String.intercalate ", " (lPerm.map toString) ++ "]"
let perm2String := "![" ++ String.intercalate ", " (l2Perm.map toString) ++ "]"
let P1 ← TensorNode.stringToTerm permString
@ -391,18 +401,27 @@ end negNode
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => do
return (← TensorNode.withoutContr stx)
return (← TensorNode.withoutContr (← TensorNode.getIndices stx))
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
return (← ProdNode.withoutContr stx)
| `(tensorExpr| ($a:tensorExpr)) => do
return (← getIndicesFull a)
| `(tensorExpr| -$a:tensorExpr) => do
return (← getIndicesFull a)
| `(tensorExpr| $_:term •ₜ $a) => do
return (← getIndicesFull a)
| `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
return (← getIndicesFull a)
| _ =>
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
namespace SMul
def smulSyntax (c T : Term) : Term :=
Syntax.mkApp (mkIdent ``TensorTree.smul) #[c, T]
end SMul
namespace Add
/-- Gets the indices associated with the LHS of an addition. -/
@ -483,6 +502,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
return (← syntaxFull a)
| `(tensorExpr| -$a:tensorExpr) => do
return negNode.negSyntax (← syntaxFull a)
| `(tensorExpr| $c:term •ₜ $a:tensorExpr) => do
return SMul.smulSyntax c (← syntaxFull a)
| `(tensorExpr| $a + $b) => do
let indicesLeft ← Add.getIndicesLeft stx
let indicesRight ← Add.getIndicesRight stx