refactor: Add checks
This commit is contained in:
parent
1f3a0dd2b6
commit
2e8e32df19
1 changed files with 16 additions and 4 deletions
|
@ -192,10 +192,14 @@ def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
|||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
|
||||
|
||||
/-- Creates the syntax associated with a tensor node. -/
|
||||
def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $T:term | $[$args]*) => do
|
||||
noIndexCheck T
|
||||
let indices ← getIndices stx
|
||||
let rawIndex ← getNoIndicesExact T
|
||||
if indices.length ≠ rawIndex then
|
||||
throwError "The number of indices does not match the tensor {T}."
|
||||
let tensorNodeSyntax := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
|
||||
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
|
||||
let dualSyntax := dualSyntax (← getDualPos stx) evalSyntax
|
||||
|
@ -216,6 +220,7 @@ For a product node we can take the tensor product, and then contract the indices
|
|||
|
||||
-/
|
||||
|
||||
/-- Gets the indices associated with a product node. -/
|
||||
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => do
|
||||
|
@ -252,9 +257,11 @@ def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term :=
|
|||
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
|
||||
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
|
||||
|
||||
/-- The syntax associated with a product of tensors. -/
|
||||
def prodSyntax (T1 T2 : Term) : Term :=
|
||||
Syntax.mkApp (mkIdent ``TensorTree.prod) #[T1, T2]
|
||||
|
||||
/-- The full term taking tensor syntax into a term for products and single tensor nodes. -/
|
||||
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
|
||||
match stx with
|
||||
| `(tensorExpr| $_:term | $[$args]*) => TensorNode.syntaxFull stx
|
||||
|
@ -284,11 +291,16 @@ elab_rules (kind:=tensorExprSyntax) : term
|
|||
variable {S : TensorStruct} {c4 : Fin 4 → S.C} (T4 : S.F.obj (OverColor.mk c4))
|
||||
{c5 : Fin 5 → S.C} (T5 : S.F.obj (OverColor.mk c5))
|
||||
|
||||
/-!
|
||||
|
||||
example : {T4 | i j}ᵀ = TensorTree.tensorNode T4 := by rfl
|
||||
# Checks
|
||||
|
||||
#check {T4 | i j l ⊗ T5 | i j k }ᵀ
|
||||
-/
|
||||
/-
|
||||
example : {T4 | i j k m}ᵀ = TensorTree.tensorNode T4 := by rfl
|
||||
|
||||
#check {(T4 | i j l ⊗ T5 | i j k) ⊗ T5 | i1 i2 i3}ᵀ
|
||||
#check {T4 | i j l d ⊗ T5 | i j k a b}ᵀ
|
||||
|
||||
#check {(T4 | i j l a ⊗ T5 | i j k c d) ⊗ T5 | i1 i2 i3 e f}ᵀ
|
||||
-/
|
||||
end ProdNode
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue