Skip to content

Commit 686f378

Browse files
committed
Add MLIR support (#1044)
1 parent 9f7629c commit 686f378

File tree

7 files changed

+19383
-85
lines changed

7 files changed

+19383
-85
lines changed

source/mlir-metadata.json

Lines changed: 17034 additions & 0 deletions
Large diffs are not rendered by default.

source/mlir.js

Lines changed: 1013 additions & 75 deletions
Large diffs are not rendered by default.

source/view.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6298,7 +6298,7 @@ view.ModelFactoryService = class {
62986298
this.register('./nnef', ['.nnef', '.dat']);
62996299
this.register('./onednn', ['.json']);
63006300
this.register('./espresso', ['.espresso.net', '.espresso.shape', '.espresso.weights'], ['.mlmodelc']);
6301-
this.register('./mlir', ['.mlir', '.mlir.txt', '.mlirbc']);
6301+
this.register('./mlir', ['.mlir', '.mlir.txt', '.mlirbc', '.txt']);
63026302
this.register('./sentencepiece', ['.model']);
63036303
this.register('./hailo', ['.hn', '.har', '.metadata.json']);
63046304
this.register('./tvm', ['.json', '.params']);

test/models.json

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3283,15 +3283,13 @@
32833283
"target": "dogs_vs_cats.stablehlo.mlir",
32843284
"source": "https://github.com/user-attachments/files/22517931/dogs_vs_cats.stablehlo.mlir.zip[dogs_vs_cats.stablehlo.mlir]",
32853285
"format": "MLIR",
3286-
"error": "Unexpected value 'array' at 7:20.",
32873286
"link": "https://github.com/lutzroeder/netron/issues/1044"
32883287
},
32893288
{
32903289
"type": "mlir",
32913290
"target": "example.mlir",
32923291
"source": "https://github.com/user-attachments/files/17792104/example.mlir.zip[example.mlir]",
32933292
"format": "MLIR",
3294-
"error": "Expected token of type '=', but got '%' at 10:59.",
32953293
"link": "https://github.com/lutzroeder/netron/issues/1044"
32963294
},
32973295
{
@@ -3334,7 +3332,6 @@
33343332
"target": "gemm.onnx.mlir",
33353333
"source": "https://github.com/user-attachments/files/17775324/gemm.onnx.mlir.zip[gemm.onnx.mlir]",
33363334
"format": "MLIR",
3337-
"error": "Expected token of type '=', but got ')' at 50:37.",
33383335
"link": "https://github.com/lutzroeder/netron/issues/1044"
33393336
},
33403337
{
@@ -3407,7 +3404,6 @@
34073404
"target": "stablehlo_gpt_125M.mlir",
34083405
"source": "https://github.com/user-attachments/files/21081167/stablehlo_gpt_125M.mlir.zip[stablehlo_gpt_125M.mlir]",
34093406
"format": "MLIR",
3410-
"error": "Expected token of type ')', but got 'id' at 14:40.",
34113407
"link": "https://github.com/lutzroeder/netron/issues/1044"
34123408
},
34133409
{
@@ -3422,7 +3418,6 @@
34223418
"target": "stablehlo_resnet18.mlir",
34233419
"source": "https://github.com/user-attachments/files/21081166/stablehlo_resnet18.mlir.zip[stablehlo_resnet18.mlir]",
34243420
"format": "MLIR",
3425-
"error": "Unexpected value '=' at 52:56.",
34263421
"link": "https://github.com/lutzroeder/netron/issues/1044"
34273422
},
34283423
{
@@ -3459,6 +3454,7 @@
34593454
"target": "wcr.mlir",
34603455
"source": "https://github.com/user-attachments/files/17788767/wcr.mlir.zip[wcr.mlir]",
34613456
"format": "MLIR",
3457+
"error": "Unexpected operation name '=' at 9:8.",
34623458
"link": "https://github.com/lutzroeder/netron/issues/1044"
34633459
},
34643460
{

tools/mlir

Lines changed: 120 additions & 4 deletions
Large diffs are not rendered by default.

tools/mlir_script.js

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
2+
import * as fs from 'fs/promises';
3+
import * as path from 'path';
4+
import * as tablegen from './tablegen.js';
5+
import * as url from 'url';
6+
7+
class Operator {
8+
9+
constructor(def) {
10+
this.def = def;
11+
let opInfo = null;
12+
for (const parent of this.def.parents) {
13+
const parentClass = this.def.parser.classes.get(parent.name);
14+
if (parentClass) {
15+
opInfo = this._findOpParent(parentClass, parent.args, {});
16+
if (opInfo) {
17+
break;
18+
}
19+
}
20+
}
21+
this.dialectName = opInfo?.dialect || null;
22+
this.opName = opInfo?.mnemonic || null;
23+
}
24+
25+
getDialectName() {
26+
return this.dialectName || '';
27+
}
28+
29+
getOperationName() {
30+
return this.dialectName && this.opName ? `${this.dialectName}.${this.opName}` : null;
31+
}
32+
33+
_findOpParent(parentClass, parentArgs, substitutions) {
34+
const subs = { ...substitutions };
35+
if (parentClass.templateArgs && parentArgs) {
36+
for (let i = 0; i < Math.min(parentClass.templateArgs.length, parentArgs.length); i++) {
37+
const paramName = parentClass.templateArgs[i].name;
38+
const argValue = parentArgs[i];
39+
subs[paramName] = (typeof argValue === 'string' && substitutions[argValue])
40+
? substitutions[argValue] : argValue;
41+
}
42+
}
43+
if (parentClass.name === 'Op' && parentArgs.length >= 2) {
44+
let [dialectArg, mnemonicArg] = parentArgs;
45+
if (typeof dialectArg === 'string' && subs[dialectArg]) {
46+
dialectArg = subs[dialectArg];
47+
}
48+
if (typeof mnemonicArg === 'string' && subs[mnemonicArg]) {
49+
mnemonicArg = subs[mnemonicArg];
50+
}
51+
let dialectName = null;
52+
if (typeof dialectArg === 'string') {
53+
const dialectDef = this.def.parser.defs.get(dialectArg) || this.def.parser.classes.get(dialectArg);
54+
if (dialectDef) {
55+
dialectName = dialectDef.getValueAsString('name');
56+
}
57+
}
58+
const mnemonic = typeof mnemonicArg === 'string' ? mnemonicArg.replace(/^"|"$/g, '') : null;
59+
if (dialectName && mnemonic) {
60+
return { dialect: dialectName, mnemonic };
61+
}
62+
}
63+
for (const grandparent of parentClass.parents) {
64+
const grandparentClass = this.def.parser.classes.get(grandparent.name);
65+
if (grandparentClass) {
66+
const resolvedArgs = grandparent.args.map((arg) =>
67+
(typeof arg === 'string' && subs[arg]) ? subs[arg] : arg
68+
);
69+
const result = this._findOpParent(grandparentClass, resolvedArgs, subs);
70+
if (result) {
71+
return result;
72+
}
73+
}
74+
}
75+
return null;
76+
}
77+
}
78+
79+
const main = async () => {
80+
const dirname = path.dirname(url.fileURLToPath(import.meta.url));
81+
const source = path.join(dirname, '..', 'third_party', 'source');
82+
const paths = [
83+
path.join(source, 'llvm-project', 'mlir', 'include'),
84+
path.join(source, 'stablehlo',),
85+
path.join(source, 'onnx-mlir'),
86+
path.join(source, 'torch-mlir', 'include')
87+
];
88+
const dialects = [
89+
'stablehlo/dialect/StablehloOps.td',
90+
'stablehlo/dialect/ChloOps.td',
91+
'mlir/Dialect/Affine/IR/AffineOps.td',
92+
'mlir/Dialect/Linalg/IR/LinalgOps.td',
93+
'mlir/Dialect/MemRef/IR/MemRefOps.td',
94+
'mlir/Dialect/Vector/IR/VectorOps.td',
95+
'mlir/Dialect/IRDL/IR/IRDLOps.td',
96+
'src/Dialect/ONNX/ONNX.td',
97+
'src/Dialect/ONNX/ONNXOps.td.inc',
98+
'src/Dialect/ONNX/AdditionalONNXOps.td',
99+
'torch-mlir/Dialect/Torch/IR/TorchOps.td',
100+
];
101+
const operations = [];
102+
const parser = new tablegen.Parser();
103+
await parser.parse(dialects, paths);
104+
for (const [, def] of parser.defs) {
105+
const op = new Operator(def);
106+
const operationName = op.getOperationName();
107+
if (!operationName) {
108+
continue;
109+
}
110+
const metadata = {
111+
name: operationName,
112+
dialect: op.getDialectName()
113+
};
114+
const summary = def.resolveField('summary');
115+
if (summary && summary.value) {
116+
metadata.summary = summary.value.value;
117+
}
118+
const description = def.resolveField('description');
119+
if (description && description.value) {
120+
metadata.description = description.value.value;
121+
}
122+
const argsField = def.resolveField('arguments');
123+
if (argsField && argsField.value && argsField.value.type === 'dag') {
124+
const dag = argsField.value.value;
125+
if (dag.operator === 'ins') {
126+
metadata.inputs = [];
127+
metadata.attributes = [];
128+
for (const operand of dag.operands) {
129+
if (!operand.value || !operand.name) {
130+
continue;
131+
}
132+
let typeName = '';
133+
if (operand.value.type === 'def') {
134+
typeName = operand.value.value;
135+
} else {
136+
// Try to extract from other value types
137+
typeName = String(operand.value.value);
138+
}
139+
if (typeName.includes('Attr')) {
140+
metadata.attributes.push({
141+
name: operand.name,
142+
type: typeName
143+
});
144+
} else {
145+
metadata.inputs.push({
146+
name: operand.name,
147+
type: typeName
148+
});
149+
}
150+
}
151+
}
152+
}
153+
const resultsField = def.resolveField('results');
154+
if (resultsField && resultsField.value && resultsField.value.type === 'dag') {
155+
const dag = resultsField.value.value;
156+
if (dag.operator === 'outs') {
157+
metadata.outputs = [];
158+
for (const operand of dag.operands) {
159+
if (!operand.value || !operand.name) {
160+
continue;
161+
}
162+
let typeName = '';
163+
if (operand.value.type === 'def') {
164+
typeName = operand.value.value;
165+
} else {
166+
typeName = String(operand.value.value);
167+
}
168+
metadata.outputs.push({
169+
name: operand.name,
170+
type: typeName
171+
});
172+
}
173+
}
174+
}
175+
const assemblyFormatField = def.resolveField('assemblyFormat');
176+
if (assemblyFormatField && assemblyFormatField.value) {
177+
metadata.assemblyFormat = assemblyFormatField.value.value;
178+
}
179+
const regionsField = def.resolveField('regions');
180+
if (regionsField) {
181+
metadata.hasRegions = true;
182+
}
183+
const operation = {};
184+
if (metadata.name) {
185+
operation.name = metadata.name;
186+
}
187+
if (metadata.category) {
188+
operation.category = metadata.category;
189+
}
190+
if (metadata.summary) {
191+
let summary = metadata.summary.trim();
192+
summary = summary.replace(/^"|"$/g, '');
193+
if (summary) {
194+
operation.summary = summary;
195+
}
196+
}
197+
if (metadata.description) {
198+
let desc = metadata.description.trim();
199+
desc = desc.replace(/^\[\{\s*|\s*\}\]$/g, '');
200+
desc = desc.trim();
201+
if (desc) {
202+
operation.description = desc;
203+
}
204+
}
205+
if (metadata.inputs && metadata.inputs.length > 0) {
206+
operation.inputs = metadata.inputs;
207+
}
208+
if (metadata.outputs && metadata.outputs.length > 0) {
209+
operation.outputs = metadata.outputs;
210+
}
211+
if (metadata.attributes && metadata.attributes.length > 0) {
212+
operation.attributes = metadata.attributes;
213+
}
214+
if (metadata.assemblyFormat) {
215+
let format = metadata.assemblyFormat.trim();
216+
format = format.replace(/^\[\{\s*|\s*\}\]$/g, '');
217+
if (format) {
218+
operation.assemblyFormat = format;
219+
}
220+
}
221+
if (Object.keys(operation).length > 1) {
222+
if (!operation.category) {
223+
const name = operation.name.replace(/^(stablehlo|chlo|affine|linalg|memref|vector|onnx|torch)\./, '');
224+
if (['reshape', 'broadcast_in_dim', 'dynamic_reshape', 'Reshape', 'Shape', 'Size', 'ConstantOfShape'].includes(name)) {
225+
operation.category = 'Shape';
226+
} else if (['transpose', 'reverse', 'pad', 'Transpose', 'Pad'].includes(name)) {
227+
operation.category = 'Transform';
228+
} else if (['slice', 'dynamic_slice', 'gather', 'scatter', 'Slice', 'Gather', 'Scatter'].includes(name)) {
229+
operation.category = 'Tensor';
230+
} else if (['tanh', 'Sigmoid', 'Tanh', 'Relu', 'Softmax'].includes(name)) {
231+
operation.category = 'Activation';
232+
} else if (['convolution', 'Conv', 'matmul', 'batch_matmul'].includes(name)) {
233+
operation.category = 'Layer';
234+
}
235+
}
236+
operations.push(operation);
237+
}
238+
}
239+
operations.sort((a, b) => a.name.localeCompare(b.name));
240+
let output = JSON.stringify(operations, null, 2);
241+
output = output.replace(/\{\s+"name":\s+"([^"]+)",\s+"type":\s+"([^"]+)"\s+\}/g, '{ "name": "$1", "type": "$2" }');
242+
const file = path.join(dirname, '..', 'source', 'mlir-metadata.json');
243+
await fs.writeFile(file, output, 'utf-8');
244+
};
245+
246+
await main();

0 commit comments

Comments
 (0)