feat: Fix elab permutations
This commit is contained in:
parent
b792be4423
commit
a6cd796df2
1 changed files with 31 additions and 10 deletions
|
@ -109,7 +109,7 @@ syntax tensorExpr "+" tensorExpr : tensorExpr
|
|||
syntax "(" tensorExpr ")" : tensorExpr
|
||||
|
||||
/-- Scalar multiplication for tensors. -/
|
||||
syntax term "•" tensorExpr : tensorExpr
|
||||
syntax term "•ₜ" tensorExpr : tensorExpr
|
||||
|
||||
/-- Negation of a tensor tree. -/
|
||||
syntax "-" tensorExpr : tensorExpr
|
||||
|
@ -207,6 +207,16 @@ def specialTypes : List (String × (Term → Term)) := [
|
|||
Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[
|
||||
mkIdent ``Fermion.complexLorentzTensor, mkIdent ``Fermion.Color.up,
|
||||
mkIdent ``Fermion.Color.upL,
|
||||
mkIdent ``Fermion.Color.upR, T]),
|
||||
("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.leftHanded ⊗ Fermion.leftHanded", fun T =>
|
||||
Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
|
||||
mkIdent ``Fermion.complexLorentzTensor,
|
||||
mkIdent ``Fermion.Color.upL,
|
||||
mkIdent ``Fermion.Color.upL, T]),
|
||||
("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.rightHanded ⊗ Fermion.rightHanded", fun T =>
|
||||
Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[
|
||||
mkIdent ``Fermion.complexLorentzTensor,
|
||||
mkIdent ``Fermion.Color.upR,
|
||||
mkIdent ``Fermion.Color.upR, T])]
|
||||
|
||||
/-- The syntax associated with a terminal node of a tensor tree. -/
|
||||
|
@ -266,8 +276,7 @@ partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
|||
return filt
|
||||
|
||||
/-- The list of indices after contraction or evaluation. -/
|
||||
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
let ind ← getIndices stx
|
||||
def withoutContr (ind : List (TSyntax `indexExpr)) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return indFilt.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
|
@ -312,10 +321,10 @@ For a product node we can take the tensor product, and then contract the indices
|
|||
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
return (← TensorNode.withoutContr stx)
|
||||
return (← TensorNode.withoutContr (← TensorNode.getIndices stx))
|
||||
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
|
||||
let indicesA ← getIndices a
|
||||
let indicesB ← getIndices b
|
||||
let indicesA ← TensorNode.withoutContr (← getIndices a)
|
||||
let indicesB ← TensorNode.withoutContr (← getIndices b)
|
||||
return indicesA ++ indicesB
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← getIndices a)
|
||||
|
@ -361,8 +370,9 @@ def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List (ℕ))
|
|||
return l2''.map fun x => x.1
|
||||
|
||||
/-- Takes two maps `Fin n → Fin n` and returns the equivelance they form. -/
|
||||
def finMapToEquiv (f1 f2 : Fin n → Fin n) (h : ∀ x, f1 (f2 x) = x := by decide)
|
||||
(h' : ∀ x, f2 (f1 x) = x := by decide) : Fin n ≃ Fin n where
|
||||
def finMapToEquiv (f1 : Fin n → Fin m) (f2 : Fin m → Fin n)
|
||||
(h : ∀ x, f1 (f2 x) = x := by decide)
|
||||
(h' : ∀ x, f2 (f1 x) = x := by decide) : Fin n ≃ Fin m where
|
||||
toFun := f1
|
||||
invFun := f2
|
||||
left_inv := h'
|
||||
|
@ -371,7 +381,7 @@ def finMapToEquiv (f1 f2 : Fin n → Fin n) (h : ∀ x, f1 (f2 x) = x := by deci
|
|||
/-- Given two lists of indices returns the permutation between them based on `finMapToEquiv`. -/
|
||||
def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do
|
||||
let lPerm ← getPermutation l1 l2
|
||||
let l2Perm ← getPermutation l1 l2
|
||||
let l2Perm ← getPermutation l2 l1
|
||||
let permString := "![" ++ String.intercalate ", " (lPerm.map toString) ++ "]"
|
||||
let perm2String := "![" ++ String.intercalate ", " (l2Perm.map toString) ++ "]"
|
||||
let P1 ← TensorNode.stringToTerm permString
|
||||
|
@ -391,18 +401,27 @@ end negNode
|
|||
partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
return (← TensorNode.withoutContr stx)
|
||||
return (← TensorNode.withoutContr (← TensorNode.getIndices stx))
|
||||
| `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do
|
||||
return (← ProdNode.withoutContr stx)
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← getIndicesFull a)
|
||||
| `(tensorExpr| -$a:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| `(tensorExpr| $_:term •ₜ $a) => do
|
||||
return (← getIndicesFull a)
|
||||
| `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do
|
||||
return (← getIndicesFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
namespace SMul
|
||||
|
||||
def smulSyntax (c T : Term) : Term :=
|
||||
Syntax.mkApp (mkIdent ``TensorTree.smul) #[c, T]
|
||||
|
||||
end SMul
|
||||
|
||||
namespace Add
|
||||
|
||||
/-- Gets the indices associated with the LHS of an addition. -/
|
||||
|
@ -483,6 +502,8 @@ partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
|||
return (← syntaxFull a)
|
||||
| `(tensorExpr| -$a:tensorExpr) => do
|
||||
return negNode.negSyntax (← syntaxFull a)
|
||||
| `(tensorExpr| $c:term •ₜ $a:tensorExpr) => do
|
||||
return SMul.smulSyntax c (← syntaxFull a)
|
||||
| `(tensorExpr| $a + $b) => do
|
||||
let indicesLeft ← Add.getIndicesLeft stx
|
||||
let indicesRight ← Add.getIndicesRight stx
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue