refactor: Add checks

This commit is contained in:
jstoobysmith 2024-10-08 11:55:06 +00:00
parent 1f3a0dd2b6
commit 2e8e32df19

View file

@ -192,10 +192,14 @@ def contrSyntax (l : List ( × )) (T : Term) : Term :=
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
/-- Creates the syntax associated with a tensor node. -/
def syntaxFull (stx : Syntax) : TermElabM Term := do
match stx with
| `(tensorExpr| $T:term | $[$args]*) => do
noIndexCheck T
let indices ← getIndices stx
let rawIndex ← getNoIndicesExact T
if indices.length ≠ rawIndex then
throwError "The number of indices does not match the tensor {T}."
let tensorNodeSyntax := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T]
let evalSyntax := evalSyntax (← getEvalPos stx) tensorNodeSyntax
let dualSyntax := dualSyntax (← getDualPos stx) evalSyntax
@ -216,6 +220,7 @@ For a product node we can take the tensor product, and then contract the indices
-/
/-- Gets the indices associated with a product node. -/
partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => do
@ -252,9 +257,11 @@ def contrSyntax (l : List ( × )) (T : Term) : Term :=
l.foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr)
#[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x0), T']) T
/-- The syntax associated with a product of tensors. -/
def prodSyntax (T1 T2 : Term) : Term :=
Syntax.mkApp (mkIdent ``TensorTree.prod) #[T1, T2]
/-- The full term taking tensor syntax into a term for products and single tensor nodes. -/
partial def syntaxFull (stx : Syntax) : TermElabM Term := do
match stx with
| `(tensorExpr| $_:term | $[$args]*) => TensorNode.syntaxFull stx
@ -284,11 +291,16 @@ elab_rules (kind:=tensorExprSyntax) : term
variable {S : TensorStruct} {c4 : Fin 4 → S.C} (T4 : S.F.obj (OverColor.mk c4))
{c5 : Fin 5 → S.C} (T5 : S.F.obj (OverColor.mk c5))
/-!
example : {T4 | i j}ᵀ = TensorTree.tensorNode T4 := by rfl
# Checks
#check {T4 | i j l ⊗ T5 | i j k }ᵀ
-/
/-
example : {T4 | i j k m}ᵀ = TensorTree.tensorNode T4 := by rfl
#check {(T4 | i j l ⊗ T5 | i j k) ⊗ T5 | i1 i2 i3}ᵀ
#check {T4 | i j l d ⊗ T5 | i j k a b}ᵀ
#check {(T4 | i j l a ⊗ T5 | i j k c d) ⊗ T5 | i1 i2 i3 e f}ᵀ
-/
end ProdNode