feat: Add elab for prod of tensor in index notation
This commit is contained in:
parent
5a34238499
commit
1f3a0dd2b6
1 changed files with 114 additions and 11 deletions
|
@ -20,6 +20,7 @@ open Lean.Meta
|
|||
open Lean.Elab
|
||||
open Lean.Elab.Term
|
||||
open Lean Meta Elab Tactic
|
||||
open IndexNotation
|
||||
|
||||
/-!
|
||||
|
||||
|
@ -93,9 +94,15 @@ syntax term "|" (ppSpace indexExpr)* : tensorExpr
|
|||
/-- The syntax for tensor prod two tensor nodes. -/
|
||||
syntax tensorExpr "⊗" tensorExpr : tensorExpr
|
||||
|
||||
/-- The syntax for tensor addition. -/
|
||||
syntax tensorExpr "+" tensorExpr : tensorExpr
|
||||
|
||||
/-- Allowing brackets to be used in a tensor expression. -/
|
||||
syntax "(" tensorExpr ")" : tensorExpr
|
||||
|
||||
|
||||
namespace TensorNode
|
||||
|
||||
/-!
|
||||
|
||||
## For tensor nodes.
|
||||
|
@ -104,12 +111,14 @@ The operations are done in the following order:
|
|||
- evaluation.
|
||||
- dualization.
|
||||
- contraction.
|
||||
|
||||
We also want to ensure the number of indices is correct.
|
||||
|
||||
-/
|
||||
|
||||
namespace TensorNode
|
||||
|
||||
/-- The indices of a tensor node. Before contraction, dualisation, and evaluation. -/
|
||||
partial def getIndicesNode (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
let indices ← args.toList.mapM fun arg => do
|
||||
|
@ -119,9 +128,25 @@ partial def getIndicesNode (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)
|
|||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesNode: {stx}"
|
||||
|
||||
/-- Uses the structure of the tensor to get the number of indices. -/
|
||||
def getNoIndicesExact (stx : Syntax) : TermElabM ℕ := do
|
||||
let expr ← elabTerm stx none
|
||||
let type ← inferType expr
|
||||
match type with
|
||||
| Expr.app _ (Expr.app _ (Expr.app _ c)) =>
|
||||
let typeC ← inferType c
|
||||
match typeC with
|
||||
| Expr.forallE _ (Expr.app _ (Expr.app (Expr.app _ (Expr.lit (Literal.natVal n))) _)) _ _ =>
|
||||
return n
|
||||
| _ => throwError "Could not extract number of indices from tensor (getNoIndicesExact). "
|
||||
| _ =>
|
||||
throwError "Could not extract number of indices from tensor (getNoIndicesExact)."
|
||||
|
||||
|
||||
|
||||
/-- The positions in getIndicesNode which get evaluated, and the value they take. -/
|
||||
partial def getEvalPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let ind ← getIndices stx
|
||||
let indEnum := ind.enum
|
||||
let evals := indEnum.filter (fun x => indexExprIsNum x.2)
|
||||
let evals2 ← (evals.mapM (fun x => indexToNum x.2))
|
||||
|
@ -134,7 +159,7 @@ def evalSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
|||
|
||||
/-- The positions in getIndicesNode which get dualized. -/
|
||||
partial def getDualPos (stx : Syntax) : TermElabM (List ℕ) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let ind ← getIndices stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
let indEnum := indFilt.enum
|
||||
let duals := indEnum.filter (fun x => indexToDual x.2)
|
||||
|
@ -147,7 +172,7 @@ def dualSyntax (l : List ℕ) (T : Term) : Term :=
|
|||
|
||||
/-- The pairs of positions in getIndicesNode which get contracted. -/
|
||||
partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let ind ← getIndices stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
let indEnum := indFilt.enum
|
||||
let bind := List.bind indEnum (fun a => indEnum.map (fun b => (a, b)))
|
||||
|
@ -158,7 +183,7 @@ partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
|||
|
||||
/-- The list of indices after contraction. -/
|
||||
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
let ind ← getIndicesNode stx
|
||||
let ind ← getIndices stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return ind.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
|
@ -167,19 +192,87 @@ def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
|||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
|
||||
|
||||
/-- An elaborator for tensor nodes. This is to be generalized. -/
|
||||
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
|
||||
def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $T:term | $[$args]*) => do
|
||||
noIndexCheck T
|
||||
let tensorNodeSyntax := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
|
||||
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
|
||||
let dualSyntax := dualSyntax (← getDualPos stx) evalSyntax
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) dualSyntax
|
||||
let tensorExpr ← elabTerm contrSyntax none
|
||||
return tensorExpr
|
||||
return contrSyntax
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
||||
end TensorNode
|
||||
|
||||
namespace ProdNode
|
||||
|
||||
/-!
|
||||
|
||||
## For product nodes.
|
||||
|
||||
For a product node we can take the tensor product, and then contract the indices.
|
||||
|
||||
-/
|
||||
|
||||
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
return (← TensorNode.withoutContr stx)
|
||||
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
|
||||
let indicesA ← getIndices a
|
||||
let indicesB ← getIndices b
|
||||
return indicesA ++ indicesB
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← getIndices a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}"
|
||||
|
||||
|
||||
/-- The pairs of positions in getIndicesNode which get contracted. -/
|
||||
partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do
|
||||
let ind ← getIndices stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
let indEnum := indFilt.enum
|
||||
let bind := List.bind indEnum (fun a => indEnum.map (fun b => (a, b)))
|
||||
let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2)
|
||||
if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then
|
||||
throwError "To many contractions"
|
||||
return filt
|
||||
|
||||
/-- The list of indices after contraction. -/
|
||||
def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
let ind ← getIndices stx
|
||||
let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x)
|
||||
return ind.filter (fun x => indFilt.count x ≤ 1)
|
||||
|
||||
/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/
|
||||
def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
||||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
|
||||
|
||||
def prodSyntax (T1 T2 : Term) : Term :=
|
||||
Syntax.mkApp (mkIdent ``TensorTree.prod) #[T1, T2]
|
||||
|
||||
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => TensorNode.syntaxFull stx
|
||||
| `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do
|
||||
let prodSyntax := prodSyntax (← syntaxFull a) (← syntaxFull b)
|
||||
let contrSyntax := contrSyntax (← getContrPos stx) prodSyntax
|
||||
return contrSyntax
|
||||
| `(tensorExpr| ($a:tensorExpr)) => do
|
||||
return (← syntaxFull a)
|
||||
| _ =>
|
||||
throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}"
|
||||
|
||||
|
||||
/-- An elaborator for tensor nodes. This is to be generalized. -/
|
||||
def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do
|
||||
let tensorExpr ← elabTerm (← syntaxFull stx) none
|
||||
return tensorExpr
|
||||
|
||||
/-- Syntax turning a tensor expression into a term. -/
|
||||
syntax (name := tensorExprSyntax) "{" tensorExpr "}ᵀ" : term
|
||||
|
||||
|
@ -188,4 +281,14 @@ elab_rules (kind:=tensorExprSyntax) : term
|
|||
let tensorTree ← elaborateTensorNode e
|
||||
return tensorTree
|
||||
|
||||
end TensorNode
|
||||
variable {S : TensorStruct} {c4 : Fin 4 → S.C} (T4 : S.F.obj (OverColor.mk c4))
|
||||
{c5 : Fin 5 → S.C} (T5 : S.F.obj (OverColor.mk c5))
|
||||
|
||||
|
||||
example : {T4 | i j}ᵀ = TensorTree.tensorNode T4 := by rfl
|
||||
|
||||
#check {T4 | i j l ⊗ T5 | i j k }ᵀ
|
||||
|
||||
#check {(T4 | i j l ⊗ T5 | i j k) ⊗ T5 | i1 i2 i3}ᵀ
|
||||
|
||||
end ProdNode
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue