lpaggen/Python-Dimension-Static-Type-Checker
GitHub: lpaggen/Python-Dimension-Static-Type-Checker
基于 Z3 SMT 求解器的 PyTorch 张量维度静态检查器,在编译期验证张量形状与线性代数操作的维度正确性。
Stars: 0 | Forks: 0
# 警告:这是一个正在进行中的项目,README 的内容超前于代码的实际功能
# Torch Shape Checker (由 Z3 驱动)



本项目实现了 Microsoft 的 Z3 SMT 求解器 (https://www.microsoft.com/en-us/research/project/z3-3/)。
你是否在模型训练了几个小时后遇到过简单的 runtime 错误?如果你的错误与线性代数有关,或者是由于对 Pytorch tensor 维度的疏忽造成的,那么这个工具就是为你准备的!
Torch Shape Checker 是一个用 Python 编写的静态类型检查器,它能在编译时而非运行时检查 Pytorch tensor 的维度有效性。它通过利用 Python 3.14 的类型注解来实现这一点,所有功能都在原生 Python 3.14 上运行。因为是在原生 Python 上运行,所以在你的程序中应用此工具几乎是毫不费力的,请看下文!
## 为什么使用 Z3
Z3 是 Microsoft 的开源 SMT 求解器。SMT 求解器接收一组约束条件,并输出它们是否全部可行,令人惊讶的是,这可以直接应用于像 Python 这样的编程语言。借助 Z3,该工具让你能够在进入模型训练数小时之前,就检测出 Pytorch tensor 上的维度不匹配问题。
## 常规 Pytorch 程序示例
```
n = 13
m = 3
k = 3
A = torch.tensor([[1, 2, 3]])
B = torch.tensor([[1, 2, 3]])
C = torch.matmul(A, B)
```
## 带有注解的相同程序
```
n: int = 1
m: int = 3
k: int = 3
A: torch.Tensor[n, m] = torch.tensor([[1, 2, 3]]) # the tool verifies (n, m) matches actual shape
B: torch.Tensor[m, k] = torch.tensor([[1], [2], [3]]])
C: torch.Tensor[n, k] = torch.matmul(A, B) # tool verifies A and B can be multiplied and (n, m) matches shape(A dot B)
out -> VALID
```
该工具使用 Z3 SMT 求解器收集整数类型和 tensor 类型提示,并在编译时强制执行适用于 tensor 声明和线性代数操作的规则。这意味着你不需要运行代码就能发现细微的错误,该工具会检测你的错误并报告给你。查看以下示例和输出:
```
n: int = 1
m: int = 3
k: int = 1
A: torch.Tensor[n, m] = torch.tensor([[1, 2, 3, 4]]) # A's type annotation and its actual shape differ
B: torch.Tensor[m, k] = torch.randn(3, 1)
C: torch.Tensor[n, k] = torch.matmul(A, B)
out -> DeclarationError: tensor A was declared with shape(rows=1, cols=4), but expected shape(rows=1, cols=3)
```
# 工具架构
1. 使用 Python 的 _ast_ 模块将 Python 源代码转换为 AST
2. 自定义访问器遍历 AST,并将整数维度和包含 tensor 的节点转换为表示类型和形状的 IR
3. Z3 wrapper 对 IR 完成一次遍历,并根据类型和线性代数规则应用约束
4. 程序通知用户是否犯了任何维度错误
```
flowchart LR
A["Python Source Code"]
--> B["Python AST"]
B --> C["Semantic IR"]
C --> D["Z3 Constraint Layer"]
D --> E{"SAT ?"}
E -->|SAT| F["Runtime / Valid State"]
E -->|UNSAT| G["Invalid Path"]
%% Styling
style A fill:#2d3436,color:#fff,stroke:#636e72
style B fill:#6c5ce7,color:#fff,stroke:#4834d4
style C fill:#00b894,color:#fff,stroke:#019875
style D fill:#0984e3,color:#fff,stroke:#0767b1
style E fill:#b2bec3,color:#2d3436,stroke:#636e72
style F fill:#00cec9,color:#fff,stroke:#00a8a8
style G fill:#d63031,color:#fff,stroke:#a61e1e
```
# 如何使用
## 安装依赖项
你只需要 Python 3.14+ 和 Z3 即可运行该工具,只有在执行你的代码时才需要 Pytorch
```
python3.14 -m pip install z3-solver torch
pip install z3-solver
```
## 运行工具
```
torchdimchecker **your_file** --verbose
```
# 目前支持的功能
```
torch.matmul
torch.tensor
torch.randn
```
标签:Python, PyTorch, SOC Prime, Z3求解器, 凭据扫描, 开发工具, 无后门, 维度校验, 静态类型检查