// Gist: stellaraccident/1b189c42a4c64c6d0854c34f7a61e34a (created 2026-01-28)
// Mixtral compilation failure: expand_shape dimension inference through moe_ffn_block
// Reduced Mixtral-style MoE transformer test case; keep IR intact to preserve
// the reproduction.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d1)>
#map4 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map5 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map7 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map8 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map9 = affine_map<(d0, d1, d2) -> (d2)>
#map10 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>
#map11 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map12 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
#map13 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map14 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map15 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map16 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
#map17 = affine_map<(d0, d1, d2, d3, d4) -> ()>
#map18 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
#map19 = affine_map<(d0) -> (d0)>
#map20 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, 0)>
#map21 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, 1)>
#map22 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
#map23 = affine_map<(d0, d1, d2, d3, d4) -> (d3)>
#map24 = affine_map<(d0, d1) -> (d0)>
module @mixtral {
// ===== Hyperparameter accessors: scalar constants for this reduced Mixtral
// test configuration (2 blocks, 2048 embd, 16 q heads / 4 kv heads, 4 experts,
// top-2 routing). =====

// Vocabulary size.
util.func private @hparams.vocab_size() -> i64 {
%c32000_i64 = arith.constant 32000 : i64
util.return %c32000_i64 : i64
}
// Number of transformer blocks.
util.func private @hparams.block_count() -> i64 {
%c2_i64 = arith.constant 2 : i64
util.return %c2_i64 : i64
}
// Model embedding dimension.
util.func private @hparams.embedding_length() -> i64 {
%c2048_i64 = arith.constant 2048 : i64
util.return %c2048_i64 : i64
}
// Number of attention (query) heads.
util.func private @hparams.attention_head_count() -> i64 {
%c16_i64 = arith.constant 16 : i64
util.return %c16_i64 : i64
}
// Number of key/value heads (grouped-query attention).
util.func private @hparams.attention_head_count_kv() -> i64 {
%c4_i64 = arith.constant 4 : i64
util.return %c4_i64 : i64
}
// Feed-forward (expert) hidden dimension.
util.func private @hparams.feed_forward_length() -> i64 {
%c8192_i64 = arith.constant 8192 : i64
util.return %c8192_i64 : i64
}
// Total number of experts per MoE layer.
util.func private @hparams.expert_count() -> i64 {
%c4_i64 = arith.constant 4 : i64
util.return %c4_i64 : i64
}
// Number of experts routed per token (top-k).
util.func private @hparams.expert_used_count() -> i64 {
%c2_i64 = arith.constant 2 : i64
util.return %c2_i64 : i64
}
// RoPE base frequency.
util.func private @hparams.rope_freq_base() -> f32 {
%cst = arith.constant 1.000000e+04 : f32
util.return %cst : f32
}
// Epsilon added inside the RMS norm before the sqrt.
util.func private @hparams.layer_norm_rms_epsilon() -> f32 {
%cst = arith.constant 9.99999974E-6 : f32
util.return %cst : f32
}
// ===== Global model parameters. Splat placeholder values cast to dynamic
// shape; util.unfoldable_constant keeps the compiler from constant-folding
// through them, preserving the dynamic-shape code paths under test. =====

// Token embedding table, statically 32000x2048 (vocab x embd).
util.func private @model_params.token_embd_weight() -> tensor<?x?xf32> {
%0 = util.unfoldable_constant dense<0.000000e+00> : tensor<32000x2048xf32>
%cast = tensor.cast %0 : tensor<32000x2048xf32> to tensor<?x?xf32>
util.return %cast : tensor<?x?xf32>
}
// Final RMS-norm weight, statically 2048 (embd).
util.func private @model_params.output_norm_weight() -> tensor<?xf32> {
%0 = util.unfoldable_constant dense<1.000000e+00> : tensor<2048xf32>
%cast = tensor.cast %0 : tensor<2048xf32> to tensor<?xf32>
util.return %cast : tensor<?xf32>
}
// Output (lm-head) projection, statically 2048x32000 (embd x vocab).
util.func private @model_params.output_weight() -> tensor<?x?xf32> {
%0 = util.unfoldable_constant dense<0.000000e+00> : tensor<2048x32000xf32>
%cast = tensor.cast %0 : tensor<2048x32000xf32> to tensor<?x?xf32>
util.return %cast : tensor<?x?xf32>
}
// Embedding lookup.
//   %arg0: (vocab, embd) embedding table
//   %arg1: (batch, seq) token ids
// Flattens the ids to 1-D, gathers one table row per id, then restores the
// (batch, seq) leading dims to return (batch, seq, embd).
util.func private @embedding_components.embedding_lookup(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xi64>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
%dim_1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
// batch * seq: row count of the flattened gather result.
%0 = arith.muli %dim, %dim_0 : index
%collapsed = tensor.collapse_shape %arg1 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
%1 = tensor.empty(%0, %dim_1) : tensor<?x?xf32>
// One embedding row per flattened token id (indexing dim 0 of the table).
%2 = iree_linalg_ext.gather dimension_map = [0] ins(%arg0, %collapsed : tensor<?x?xf32>, tensor<?xi64>) outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
// Split the flattened token dim back into (batch, seq).
%expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [%dim, %dim_0, %dim_1] : tensor<?x?xf32> into tensor<?x?x?xf32>
util.return %expanded : tensor<?x?x?xf32>
}
// One MoE transformer layer: pre-norm attention with residual add, then
// pre-norm MoE feed-forward with residual add.
//   %arg0: (batch, seq, embd) hidden states; %arg1: (batch, seq) positions
//   %arg2: layer index (selects per-layer weights)
// NOTE(review): index/flag argument roles inferred from how the callees use
// them (the caller @forward is truncated in this dump) — %arg3/%arg4 q/kv head
// counts, %arg5 embd, %arg6 ffn length, %arg7 expert count, %arg8 experts
// used, %arg9 rms eps, %arg10/%arg11 rope freq base/scale, %arg12 attention
// bias flag, %arg13 MoE weight-renormalization flag. Confirm against caller.
util.func private @transformer_layer_moe_components.transformer_layer_moe(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: i32, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: f32, %arg10: f32, %arg11: f32, %arg12: i1, %arg13: i1) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
// batch * seq (computed but only the factors are used below).
%0 = arith.muli %dim, %dim_0 : index
// Per-layer parameter fetches.
%1 = util.call @model_params.attn_norm_weight(%arg2) : (i32) -> tensor<?xf32>
%2 = util.call @model_params.ffn_norm_weight(%arg2) : (i32) -> tensor<?xf32>
%3 = util.call @model_params.attn_q_weight(%arg2) : (i32) -> tensor<?x?xf32>
%4 = util.call @model_params.attn_k_weight(%arg2) : (i32) -> tensor<?x?xf32>
%5 = util.call @model_params.attn_v_weight(%arg2) : (i32) -> tensor<?x?xf32>
%6 = util.call @model_params.attn_output_weight(%arg2) : (i32) -> tensor<?x?xf32>
%7 = util.call @model_params.attn_q_bias(%arg2) : (i32) -> tensor<?xf32>
%8 = util.call @model_params.attn_k_bias(%arg2) : (i32) -> tensor<?xf32>
%9 = util.call @model_params.attn_v_bias(%arg2) : (i32) -> tensor<?xf32>
%10 = util.call @model_params.attn_output_bias(%arg2) : (i32) -> tensor<?xf32>
%11 = util.call @model_params.ffn_gate_inp_weight(%arg2) : (i32) -> tensor<?x?xf32>
%12 = util.call @model_params.ffn_up_exps_weight(%arg2) : (i32) -> tensor<?x?x?xf32>
%13 = util.call @model_params.ffn_gate_exps_weight(%arg2) : (i32) -> tensor<?x?x?xf32>
%14 = util.call @model_params.ffn_down_exps_weight(%arg2) : (i32) -> tensor<?x?x?xf32>
// Pre-attention RMS norm over flattened (batch*seq, embd) tokens.
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
%15 = util.call @rms_norm_components.rms_norm_linalg(%collapsed, %1, %arg9) : (tensor<?x?xf32>, tensor<?xf32>, f32) -> tensor<?x?xf32>
%expanded = tensor.expand_shape %15 [[0, 1], [2]] output_shape [%dim, %dim_0, %arg5] : tensor<?x?xf32> into tensor<?x?x?xf32>
%16 = util.call @attention_block_components.attention_block(%expanded, %arg1, %3, %4, %5, %6, %7, %8, %9, %10, %arg12, %arg3, %arg4, %arg5, %arg10, %arg11) : (tensor<?x?x?xf32>, tensor<?x?xi64>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, i1, index, index, index, f32, f32) -> tensor<?x?x?xf32>
// Residual add: input + attention output.
%17 = tensor.empty(%dim, %dim_0, %arg5) : tensor<?x?x?xf32>
%18 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %16 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%17 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_3: f32, %out: f32):
%23 = arith.addf %in, %in_3 : f32
linalg.yield %23 : f32
} -> tensor<?x?x?xf32>
// Pre-FFN RMS norm, again over flattened tokens.
%collapsed_1 = tensor.collapse_shape %18 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
%19 = util.call @rms_norm_components.rms_norm_linalg(%collapsed_1, %2, %arg9) : (tensor<?x?xf32>, tensor<?xf32>, f32) -> tensor<?x?xf32>
%20 = util.call @moe_ffn_components.moe_ffn_block(%19, %11, %12, %13, %14, %arg7, %arg8, %arg5, %arg6, %arg13) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, index, index, index, index, i1) -> tensor<?x?xf32>
// NOTE(review): per the gist title, this expand_shape of the moe_ffn_block
// result is where the reported dimension-inference failure occurs.
%expanded_2 = tensor.expand_shape %20 [[0, 1], [2]] output_shape [%dim, %dim_0, %arg5] : tensor<?x?xf32> into tensor<?x?x?xf32>
// Residual add: post-attention state + FFN output.
%21 = tensor.empty(%dim, %dim_0, %arg5) : tensor<?x?x?xf32>
%22 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%18, %expanded_2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%21 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_3: f32, %out: f32):
%23 = arith.addf %in, %in_3 : f32
linalg.yield %23 : f32
} -> tensor<?x?x?xf32>
util.return %22 : tensor<?x?x?xf32>
}
// Mixture-of-experts feed-forward block over flattened tokens.
//   %arg0: (tokens, embd) hidden states
//   %arg1: (n_expert, embd) router (ffn_gate_inp) weight
//   %arg2/%arg3/%arg4: up / gate / down expert weight stacks, laid out
//     (rows, cols, n_expert) per the visible constants
//   %arg5: n_expert; %arg6: n_expert_used (top-k); %arg7: embd
//   %arg8: unused in this body (ffn length at the call site)
//   %arg9: when true, renormalize the top-k routing weights to sum to 1
// NOTE(review): roles inferred from uses — confirm against the caller.
util.func private @moe_ffn_components.moe_ffn_block(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?x?xf32>, %arg3: tensor<?x?x?xf32>, %arg4: tensor<?x?x?xf32>, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: i1) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
// Transpose hidden states to (embd, tokens) for the router matmul.
%0 = tensor.empty(%arg7, %dim) : tensor<?x?xf32>
%1 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?xf32>
%cst = arith.constant 0.000000e+00 : f32
// Router logits: (n_expert, embd) x (embd, tokens) -> (n_expert, tokens).
%2 = tensor.empty(%arg5, %dim) : tensor<?x?xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = linalg.matmul ins(%arg1, %1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%3 : tensor<?x?xf32>) -> tensor<?x?xf32>
// Softmax over the expert dimension (dim 0).
%5 = tensor.empty(%arg5, %dim) : tensor<?x?xf32>
%6 = linalg.softmax dimension(0) ins(%4 : tensor<?x?xf32>) outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
// Top-k expert probabilities (%9#0) and indices (%9#1) per token, k = %arg6.
%7 = tensor.empty(%arg6, %dim) : tensor<?x?xf32>
%8 = tensor.empty(%arg6, %dim) : tensor<?x?xi32>
%9:2 = iree_linalg_ext.topk dimension(0) ins(%6 : tensor<?x?xf32>) outs(%7, %8 : tensor<?x?xf32>, tensor<?x?xi32>) {
^bb0(%arg10: f32, %arg11: f32):
%24 = arith.cmpf ogt, %arg10, %arg11 : f32
iree_linalg_ext.yield %24 : i1
} -> tensor<?x?xf32>, tensor<?x?xi32>
// Optionally renormalize each token's selected probabilities by their sum.
%10 = scf.if %arg9 -> (tensor<?x?xf32>) {
%24 = tensor.empty(%dim) : tensor<?xf32>
%25 = linalg.fill ins(%cst : f32) outs(%24 : tensor<?xf32>) -> tensor<?xf32>
// Per-token sum over the k selected experts (reduce d0).
%26 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["reduction", "parallel"]} ins(%9#0 : tensor<?x?xf32>) outs(%25 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
%29 = arith.addf %in, %out : f32
linalg.yield %29 : f32
} -> tensor<?xf32>
%27 = tensor.empty(%arg6, %dim) : tensor<?x?xf32>
%28 = linalg.generic {indexing_maps = [#map2, #map3, #map2], iterator_types = ["parallel", "parallel"]} ins(%9#0, %26 : tensor<?x?xf32>, tensor<?xf32>) outs(%27 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%29 = arith.divf %in, %in_0 : f32
linalg.yield %29 : f32
} -> tensor<?x?xf32>
scf.yield %28 : tensor<?x?xf32>
} else {
scf.yield %9#0 : tensor<?x?xf32>
}
// Broadcast hidden states to (embd, n_expert_used, tokens) so each selected
// expert sees the same input column.
%11 = tensor.empty(%arg7, %arg6, %dim) : tensor<?x?x?xf32>
%12 = linalg.generic {indexing_maps = [#map4, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%11 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
// Up (%13) and gate (%14) projections through the selected experts, SwiGLU
// activation, then down projection. NOTE(review): see the operand-order note
// on @activation_components.swiglu.
%13 = util.call @moe_components.mul_mat_id(%arg2, %12, %9#1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?xf32>
%14 = util.call @moe_components.mul_mat_id(%arg3, %12, %9#1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?xf32>
%15 = util.call @activation_components.swiglu(%14, %13) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%16 = util.call @moe_components.mul_mat_id(%arg4, %15, %9#1) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?xf32>
// Weight each expert's output by its routing probability (%10 is (k, tokens)).
%17 = tensor.empty(%arg7, %arg6, %dim) : tensor<?x?x?xf32>
%18 = linalg.generic {indexing_maps = [#map, #map5, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%16, %10 : tensor<?x?x?xf32>, tensor<?x?xf32>) outs(%17 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%24 = arith.mulf %in, %in_0 : f32
linalg.yield %24 : f32
} -> tensor<?x?x?xf32>
// Sum over the n_expert_used dimension (d1) -> (embd, tokens).
%19 = tensor.empty(%arg7, %dim) : tensor<?x?xf32>
%20 = linalg.fill ins(%cst : f32) outs(%19 : tensor<?x?xf32>) -> tensor<?x?xf32>
%21 = linalg.generic {indexing_maps = [#map, #map6], iterator_types = ["parallel", "reduction", "parallel"]} ins(%18 : tensor<?x?x?xf32>) outs(%20 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
%24 = arith.addf %in, %out : f32
linalg.yield %24 : f32
} -> tensor<?x?xf32>
// Transpose back to (tokens, embd).
%22 = tensor.empty(%dim, %arg7) : tensor<?x?xf32>
%23 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%21 : tensor<?x?xf32>) outs(%22 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?xf32>
util.return %23 : tensor<?x?xf32>
}
// Elementwise SwiGLU-style activation: returns %arg0 * silu(%arg1), where
// silu(x) = x * sigmoid(x) is computed inline as x / (1 + exp(-x)) * x.
// NOTE(review): the only visible caller (@moe_ffn_components.moe_ffn_block)
// passes the gate-projection result as %arg0 and the up-projection result as
// %arg1, so SiLU is applied to the UP projection — the conventional SwiGLU
// applies SiLU to the gate. Verify this is intentional (with the identical
// splat weights in this test it is numerically indistinguishable).
util.func private @activation_components.swiglu(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%0 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%0 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
// sigmoid(%in_2) = 1 / (1 + exp(-%in_2))
%2 = arith.negf %in_2 : f32
%3 = math.exp %2 : f32
%cst = arith.constant 1.000000e+00 : f32
%4 = arith.addf %cst, %3 : f32
%5 = arith.divf %cst, %4 : f32
// silu(%in_2) = %in_2 * sigmoid(%in_2); result = %in * silu(%in_2)
%6 = arith.mulf %in_2, %5 : f32
%7 = arith.mulf %in, %6 : f32
linalg.yield %7 : f32
} -> tensor<?x?x?xf32>
util.return %1 : tensor<?x?x?xf32>
}
// Indirect (expert-indexed) matrix multiply: for every (k, token) slot, pick
// the expert matrix named by %arg2 and apply it to that slot's input vector.
//   %arg0: expert weight stack, (rows, cols, n_expert)
//   %arg1: inputs, (cols, k, tokens)
//   %arg2: selected expert ids, (k, tokens)
// Returns (rows, k, tokens).
util.func private @moe_components.mul_mat_id(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?xi32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%dim_2 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_3 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
// Transpose weights to expert-major (n_expert, rows, cols) so experts can be
// gathered along dim 0.
%0 = tensor.empty(%dim_1, %dim, %dim_0) : tensor<?x?x?xf32>
%1 = linalg.generic {indexing_maps = [#map, #map7], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf32>) outs(%0 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
// k * tokens: batch size of the gathered per-slot matmuls.
%2 = arith.muli %dim_2, %dim_3 : index
%collapsed = tensor.collapse_shape %arg2 [[0, 1]] : tensor<?x?xi32> into tensor<?xi32>
// Gather one (rows, cols) expert matrix per flattened (k, token) slot.
%3 = tensor.empty(%2, %dim, %dim_0) : tensor<?x?x?xf32>
%4 = iree_linalg_ext.gather dimension_map = [0] ins(%1, %collapsed : tensor<?x?x?xf32>, tensor<?xi32>) outs(%3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// Rearrange inputs (cols, k, tokens) -> (k, tokens, cols), then reshape to
// (k*tokens, cols, 1) column vectors for the batched matvec.
%5 = tensor.empty(%dim_2, %dim_3, %dim_0) : tensor<?x?x?xf32>
%6 = linalg.generic {indexing_maps = [#map, #map8], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x?x?xf32>) outs(%5 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
%collapsed_4 = tensor.collapse_shape %6 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
%expanded = tensor.expand_shape %collapsed_4 [[0], [1, 2]] output_shape [%2, %dim_0, %c1] : tensor<?x?xf32> into tensor<?x?x1xf32>
%cst = arith.constant 0.000000e+00 : f32
// Batched matvec: (k*tokens, rows, cols) x (k*tokens, cols, 1).
%7 = tensor.empty(%2, %dim) : tensor<?x?x1xf32>
%8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<?x?x1xf32>) -> tensor<?x?x1xf32>
%9 = linalg.batch_matmul ins(%4, %expanded : tensor<?x?x?xf32>, tensor<?x?x1xf32>) outs(%8 : tensor<?x?x1xf32>) -> tensor<?x?x1xf32>
// Drop the unit dim, split the batch back into (k, tokens, rows), and
// transpose to the (rows, k, tokens) return layout.
%collapsed_5 = tensor.collapse_shape %9 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
%expanded_6 = tensor.expand_shape %collapsed_5 [[0, 1], [2]] output_shape [%dim_2, %dim_3, %dim] : tensor<?x?xf32> into tensor<?x?x?xf32>
%10 = tensor.empty(%dim, %dim_2, %dim_3) : tensor<?x?x?xf32>
%11 = linalg.generic {indexing_maps = [#map, #map7], iterator_types = ["parallel", "parallel", "parallel"]} ins(%expanded_6 : tensor<?x?x?xf32>) outs(%10 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
util.return %11 : tensor<?x?x?xf32>
}
// Multi-head self-attention block (GQA) with RoPE.
//   %arg0: (batch, seq, embd) input; %arg1: (batch, seq) positions
//   %arg2..%arg5: Q/K/V/output projection weights, each (embd, out_dim)
//   %arg6..%arg9: Q/K/V/output biases, applied only when %arg10 is true
//   %arg11: q head count; %arg12: kv head count; %arg13: embd
//   %arg14/%arg15: RoPE frequency base / scale
util.func private @attention_block_components.attention_block(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>, %arg6: tensor<?xf32>, %arg7: tensor<?xf32>, %arg8: tensor<?xf32>, %arg9: tensor<?xf32>, %arg10: i1, %arg11: index, %arg12: index, %arg13: index, %arg14: f32, %arg15: f32) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg3, %c1 : tensor<?x?xf32>
%dim_2 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
%dim_3 = tensor.dim %arg3, %c0 : tensor<?x?xf32>
%dim_4 = tensor.dim %arg4, %c0 : tensor<?x?xf32>
%dim_5 = tensor.dim %arg5, %c0 : tensor<?x?xf32>
// Broadcast each projection weight across the batch dim so batch_matmul can
// be used for the projections.
%0 = tensor.empty(%dim, %dim_2, %arg13) : tensor<?x?x?xf32>
%1 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg2 : tensor<?x?xf32>) outs(%0 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
%2 = tensor.empty(%dim, %dim_3, %dim_1) : tensor<?x?x?xf32>
%3 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg3 : tensor<?x?xf32>) outs(%2 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
%4 = tensor.empty(%dim, %dim_4, %dim_1) : tensor<?x?x?xf32>
%5 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg4 : tensor<?x?xf32>) outs(%4 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
%cst = arith.constant 0.000000e+00 : f32
// Q/K/V projections: (batch, seq, embd) x (batch, embd, out_dim).
%6 = tensor.empty(%dim, %dim_0, %arg13) : tensor<?x?x?xf32>
%7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%8 = linalg.batch_matmul ins(%arg0, %1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%7 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%9 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32>
%10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%11 = linalg.batch_matmul ins(%arg0, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%10 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%12 = tensor.empty(%dim, %dim_0, %dim_1) : tensor<?x?x?xf32>
%13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%14 = linalg.batch_matmul ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%13 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// Optional Q bias.
%15 = scf.if %arg10 -> (tensor<?x?x?xf32>) {
%32 = linalg.generic {indexing_maps = [#map, #map9, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%8, %arg6 : tensor<?x?x?xf32>, tensor<?xf32>) outs(%6 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_8: f32, %out: f32):
%33 = arith.addf %in, %in_8 : f32
linalg.yield %33 : f32
} -> tensor<?x?x?xf32>
scf.yield %32 : tensor<?x?x?xf32>
} else {
scf.yield %8 : tensor<?x?x?xf32>
}
// Optional K bias.
%16 = scf.if %arg10 -> (tensor<?x?x?xf32>) {
%32 = linalg.generic {indexing_maps = [#map, #map9, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%11, %arg7 : tensor<?x?x?xf32>, tensor<?xf32>) outs(%9 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_8: f32, %out: f32):
%33 = arith.addf %in, %in_8 : f32
linalg.yield %33 : f32
} -> tensor<?x?x?xf32>
scf.yield %32 : tensor<?x?x?xf32>
} else {
scf.yield %11 : tensor<?x?x?xf32>
}
// Optional V bias.
%17 = scf.if %arg10 -> (tensor<?x?x?xf32>) {
%32 = linalg.generic {indexing_maps = [#map, #map9, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%14, %arg8 : tensor<?x?x?xf32>, tensor<?xf32>) outs(%12 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_8: f32, %out: f32):
%33 = arith.addf %in, %in_8 : f32
linalg.yield %33 : f32
} -> tensor<?x?x?xf32>
scf.yield %32 : tensor<?x?x?xf32>
} else {
scf.yield %14 : tensor<?x?x?xf32>
}
// Head dims: %18 = embd / q_heads, %19 = kv_dim / kv_heads; split the last
// dim into (heads, head_dim).
%18 = arith.divsi %arg13, %arg11 : index
%19 = arith.divsi %dim_1, %arg12 : index
%expanded = tensor.expand_shape %15 [[0], [1], [2, 3]] output_shape [%dim, %dim_0, %arg11, %18] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
%expanded_6 = tensor.expand_shape %16 [[0], [1], [2, 3]] output_shape [%dim, %dim_0, %arg12, %19] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
%expanded_7 = tensor.expand_shape %17 [[0], [1], [2, 3]] output_shape [%dim, %dim_0, %arg12, %19] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
// Rotary position embedding on Q and K (not V).
%20 = util.call @position_components.rope(%expanded, %arg1, %arg14, %arg15) : (tensor<?x?x?x?xf32>, tensor<?x?xi64>, f32, f32) -> tensor<?x?x?x?xf32>
%21 = util.call @position_components.rope(%expanded_6, %arg1, %arg14, %arg15) : (tensor<?x?x?x?xf32>, tensor<?x?xi64>, f32, f32) -> tensor<?x?x?x?xf32>
// Attention scale = 1/sqrt(head_dim).
%22 = arith.index_cast %18 : index to i32
%23 = arith.sitofp %22 : i32 to f32
%24 = math.rsqrt %23 : f32
%25 = util.call @attention_components.attention_gqa(%20, %21, %expanded_7, %24) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, f32) -> tensor<?x?x?x?xf32>
// Merge heads back into embd and apply the output projection.
%collapsed = tensor.collapse_shape %25 [[0], [1], [2, 3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%26 = tensor.empty(%dim, %dim_5, %arg13) : tensor<?x?x?xf32>
%27 = linalg.generic {indexing_maps = [#map5, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg5 : tensor<?x?xf32>) outs(%26 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?xf32>
%28 = tensor.empty(%dim, %dim_0, %arg13) : tensor<?x?x?xf32>
%29 = linalg.fill ins(%cst : f32) outs(%28 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%30 = linalg.batch_matmul ins(%collapsed, %27 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%29 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// Optional output bias.
%31 = scf.if %arg10 -> (tensor<?x?x?xf32>) {
%32 = linalg.generic {indexing_maps = [#map, #map9, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%30, %arg9 : tensor<?x?x?xf32>, tensor<?xf32>) outs(%28 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %in_8: f32, %out: f32):
%33 = arith.addf %in, %in_8 : f32
linalg.yield %33 : f32
} -> tensor<?x?x?xf32>
scf.yield %32 : tensor<?x?x?xf32>
} else {
scf.yield %30 : tensor<?x?x?xf32>
}
util.return %31 : tensor<?x?x?xf32>
}
// Grouped-query attention.
//   %arg0: Q, (batch, seq, q_heads, head_dim)
//   %arg1: K, %arg2: V, each (batch, seq_kv, kv_heads, head_dim)
//   %arg3: attention scale
// Repeats each KV head q_heads/kv_heads times, runs iree_linalg_ext.attention
// with batch and heads folded together, and returns (batch, seq, q_heads,
// head_dim).
util.func private @attention_components.attention_gqa(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>, %arg3: f32) -> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?x?xf32>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?x?xf32>
%dim_2 = tensor.dim %arg0, %c3 : tensor<?x?x?x?xf32>
%dim_3 = tensor.dim %arg1, %c2 : tensor<?x?x?x?xf32>
// Transpose Q/K/V from (batch, seq, heads, hd) to (batch, heads, seq, hd).
%0 = tensor.empty(%dim, %dim_1, %dim_0, %dim_2) : tensor<?x?x?x?xf32>
%1 = linalg.generic {indexing_maps = [#map10, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?x?xf32>) outs(%0 : tensor<?x?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?x?xf32>
%2 = tensor.empty(%dim, %dim_3, %dim_0, %dim_2) : tensor<?x?x?x?xf32>
%3 = linalg.generic {indexing_maps = [#map10, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<?x?x?x?xf32>) outs(%2 : tensor<?x?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?x?xf32>
%4 = tensor.empty(%dim, %dim_3, %dim_0, %dim_2) : tensor<?x?x?x?xf32>
%5 = linalg.generic {indexing_maps = [#map10, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<?x?x?x?xf32>) outs(%4 : tensor<?x?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?x?xf32>
// Group size: q_heads / kv_heads.
%6 = arith.divui %dim_1, %dim_3 : index
// Broadcast K across the group dim, then collapse (kv_heads, group) into a
// q_heads axis.
%7 = tensor.empty(%dim, %dim_3, %6, %dim_0, %dim_2) : tensor<?x?x?x?x?xf32>
%8 = linalg.generic {indexing_maps = [#map12, #map13], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%3 : tensor<?x?x?x?xf32>) outs(%7 : tensor<?x?x?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?x?x?xf32>
%collapsed = tensor.collapse_shape %8 [[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
// Same for V.
%9 = tensor.empty(%dim, %dim_3, %6, %dim_0, %dim_2) : tensor<?x?x?x?x?xf32>
%10 = linalg.generic {indexing_maps = [#map12, #map13], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<?x?x?x?xf32>) outs(%9 : tensor<?x?x?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?x?x?xf32>
%collapsed_4 = tensor.collapse_shape %10 [[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
// Fold (batch, q_heads) into a single batch dim for the attention op.
%11 = arith.muli %dim, %dim_1 : index
%collapsed_5 = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%collapsed_6 = tensor.collapse_shape %collapsed [[0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%collapsed_7 = tensor.collapse_shape %collapsed_4 [[0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%12 = tensor.empty(%11, %dim_0, %dim_2) : tensor<?x?x?xf32>
%13 = iree_linalg_ext.attention {indexing_maps = [#map14, #map15, #map16, #map17, #map18]} ins(%collapsed_5, %collapsed_6, %collapsed_7, %arg3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32) outs(%12 : tensor<?x?x?xf32>) {
^bb0(%arg4: f32):
iree_linalg_ext.yield %arg4 : f32
} -> tensor<?x?x?xf32>
// Unfold the batch dim and transpose back to (batch, seq, heads, head_dim).
%expanded = tensor.expand_shape %13 [[0, 1], [2], [3]] output_shape [%dim, %dim_1, %dim_0, %dim_2] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
%14 = tensor.empty(%dim, %dim_0, %dim_1, %dim_2) : tensor<?x?x?x?xf32>
%15 = linalg.generic {indexing_maps = [#map10, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<?x?x?x?xf32>) outs(%14 : tensor<?x?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<?x?x?x?xf32>
util.return %15 : tensor<?x?x?x?xf32>
}
// Rotary position embedding over (batch, seq, heads, head_dim).
//   %arg1: (batch, seq) integer positions
//   %arg2: frequency base; %arg3: frequency scale
// Splits the last dim into (head_dim/2, 2) pairs of consecutive elements and
// rotates pair i at position p by angle p * %arg3 * %arg2^(-2i/head_dim).
util.func private @position_components.rope(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: f32, %arg3: f32) -> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?x?xf32>
%dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?x?xf32>
%dim_2 = tensor.dim %arg0, %c3 : tensor<?x?x?x?xf32>
// head_dim / 2 rotation pairs.
%0 = arith.divsi %dim_2, %c2 : index
%expanded = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [%dim, %dim_0, %dim_1, %0, 2] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x2xf32>
%1 = arith.index_cast %dim_2 : index to i32
%2 = arith.sitofp %1 : i32 to f32
// Per-pair frequency table: %arg3 * %arg2^(-2i/head_dim).
%3 = tensor.empty(%0) : tensor<?xf32>
%4 = linalg.generic {indexing_maps = [#map19], iterator_types = ["parallel"]} outs(%3 : tensor<?xf32>) {
^bb0(%out: f32):
%7 = linalg.index 0 : index
%8 = arith.index_cast %7 : index to i32
%9 = arith.sitofp %8 : i32 to f32
%cst = arith.constant 2.000000e+00 : f32
%10 = arith.mulf %cst, %9 : f32
%11 = arith.divf %10, %2 : f32
%12 = arith.negf %11 : f32
%13 = math.powf %arg2, %12 : f32
%14 = arith.mulf %13, %arg3 : f32
linalg.yield %14 : f32
} -> tensor<?xf32>
// Rotate each pair. The input is read twice: #map20 fixes the pair index to 0
// (even element x0) and #map21 fixes it to 1 (odd element x1). %arg1 supplies
// the position per (batch, seq); %4 the frequency per pair.
%5 = tensor.empty(%dim, %dim_0, %dim_1, %0) : tensor<?x?x?x?x2xf32>
%6 = linalg.generic {indexing_maps = [#map20, #map21, #map22, #map23, #map13], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%expanded, %expanded, %arg1, %4 : tensor<?x?x?x?x2xf32>, tensor<?x?x?x?x2xf32>, tensor<?x?xi64>, tensor<?xf32>) outs(%5 : tensor<?x?x?x?x2xf32>) {
^bb0(%in: f32, %in_3: f32, %in_4: i64, %in_5: f32, %out: f32):
%7 = arith.trunci %in_4 : i64 to i32
%8 = arith.sitofp %7 : i32 to f32
// angle = position * frequency
%9 = arith.mulf %8, %in_5 : f32
%10 = math.cos %9 : f32
%11 = math.sin %9 : f32
// Output element 0 gets x0*cos - x1*sin; element 1 gets x0*sin + x1*cos.
%12 = linalg.index 4 : index
%c0_6 = arith.constant 0 : index
%13 = arith.cmpi eq, %12, %c0_6 : index
%14 = scf.if %13 -> (f32) {
%15 = arith.mulf %in, %10 : f32
%16 = arith.mulf %in_3, %11 : f32
%17 = arith.subf %15, %16 : f32
scf.yield %17 : f32
} else {
%15 = arith.mulf %in, %11 : f32
%16 = arith.mulf %in_3, %10 : f32
%17 = arith.addf %15, %16 : f32
scf.yield %17 : f32
}
linalg.yield %14 : f32
} -> tensor<?x?x?x?x2xf32>
// Merge the pair dim back into head_dim.
%collapsed = tensor.collapse_shape %6 [[0], [1], [2], [3, 4]] : tensor<?x?x?x?x2xf32> into tensor<?x?x?x?xf32>
util.return %collapsed : tensor<?x?x?x?xf32>
}
// RMS normalization over the last dim of (tokens, embd) input:
//   out[t, e] = x[t, e] / sqrt(mean(x[t, :]^2) + eps) * weight[e]
//   %arg1: (embd) scale weight; %arg2: epsilon
util.func private @rms_norm_components.rms_norm_linalg(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: f32) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%0 = tensor.empty(%dim) : tensor<?xf32>
%cst = arith.constant 0.000000e+00 : f32
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
// Per-token sum of squares (reduce over d1).
%2 = linalg.generic {indexing_maps = [#map2, #map24], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<?x?xf32>) outs(%1 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
%9 = arith.mulf %in, %in : f32
%10 = arith.addf %out, %9 : f32
linalg.yield %10 : f32
} -> tensor<?xf32>
%3 = arith.index_cast %dim_0 : index to i32
%4 = arith.sitofp %3 : i32 to f32
// Per-token denominator: sqrt(mean + eps).
%5 = tensor.empty(%dim) : tensor<?xf32>
%6 = linalg.generic {indexing_maps = [#map19, #map19], iterator_types = ["parallel"]} ins(%2 : tensor<?xf32>) outs(%5 : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
%9 = arith.divf %in, %4 : f32
%10 = arith.addf %9, %arg2 : f32
%11 = math.sqrt %10 : f32
linalg.yield %11 : f32
} -> tensor<?xf32>
// Normalize and apply the elementwise weight.
%7 = tensor.empty(%dim, %dim_0) : tensor<?x?xf32>
%8 = linalg.generic {indexing_maps = [#map2, #map24, #map3, #map2], iterator_types = ["parallel", "parallel"]} ins(%arg0, %6, %arg1 : tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>) outs(%7 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_1: f32, %in_2: f32, %out: f32):
%9 = arith.divf %in, %in_1 : f32
%10 = arith.mulf %9, %in_2 : f32
linalg.yield %10 : f32
} -> tensor<?x?xf32>
util.return %8 : tensor<?x?xf32>
}
// ===== Per-layer MoE parameters (layer index %arg0 is ignored in this
// reduced test; all layers share the same splat placeholder). =====

// Down-projection expert stack, statically 2048x8192x4 (embd, ffn, n_expert).
util.func private @model_params.ffn_down_exps_weight(%arg0: i32) -> tensor<?x?x?xf32> {
%0 = util.unfoldable_constant dense<1.000000e-01> : tensor<2048x8192x4xf32>
%cast = tensor.cast %0 : tensor<2048x8192x4xf32> to tensor<?x?x?xf32>
util.return %cast : tensor<?x?x?xf32>
}
// Gate-projection expert stack, statically 8192x2048x4 (ffn, embd, n_expert).
util.func private @model_params.ffn_gate_exps_weight(%arg0: i32) -> tensor<?x?x?xf32> {
%0 = util.unfoldable_constant dense<1.000000e-01> : tensor<8192x2048x4xf32>
%cast = tensor.cast %0 : tensor<8192x2048x4xf32> to tensor<?x?x?xf32>
util.return %cast : tensor<?x?x?xf32>
}
// Up-projection expert stack, statically 8192x2048x4 (ffn, embd, n_expert).
util.func private @model_params.ffn_up_exps_weight(%arg0: i32) -> tensor<?x?x?xf32> {
%0 = util.unfoldable_constant dense<1.000000e-01> : tensor<8192x2048x4xf32>
%cast = tensor.cast %0 : tensor<8192x2048x4xf32> to tensor<?x?x?xf32>
util.return %cast : tensor<?x?x?xf32>
}
// Router weight, statically 4x2048 (n_expert, embd).
util.func private @model_params.ffn_gate_inp_weight(%arg0: i32) -> tensor<?x?xf32> {
%0 = util.unfoldable_constant dense<1.000000e-01> : tensor<4x2048xf32>
%cast = tensor.cast %0 : tensor<4x2048xf32> to tensor<?x?xf32>
util.return %cast : tensor<?x?xf32>
}
util.func private @model_params.attn_output_bias(%arg0: i32) -> tensor<?xf32> {
%0 = util.unfoldable_constant dense<0.000000e+00> : tensor<2048xf32>
%cast = tensor.cast %0 : tensor<2048xf32> to tensor<?xf32>
util.return %cast : tensor<?xf32>
}
util.func private @model_params.attn_v_bias(%arg0: i32) -> tensor<?xf32> {
%0 = util.unfoldable_constant dense<0.000000e+00> : tensor<512xf32>
%cast = tensor.cast %0 : tensor<512xf32> to tensor<?xf32>
util.return %cast : tensor<?xf32>
}
util.func private @model_params.attn_k_bias(%arg0: i32) -> tensor<?xf32> {
%0 = util.unfoldable_constant dense<0.000000e+00> : tensor<512xf32>
%cast = tensor.cast %0 : tensor<512xf32> to tensor<?xf32>
util.return %cast : tensor<?xf32>
}
util.func private @model_params.attn_q_bias(%arg0: i32) -> tensor<?xf32> {
  // Zero Q-projection bias, statically 2048 elements; %arg0 unused.
  %zeros = util.unfoldable_constant dense<0.000000e+00> : tensor<2048xf32>
  // Dynamic-shape the result for the caller.
  %erased = tensor.cast %zeros : tensor<2048xf32> to tensor<?xf32>
  util.return %erased : tensor<?xf32>
}
util.func private @model_params.attn_output_weight(%arg0: i32) -> tensor<?x?xf32> {
  // Dummy attention output-projection weights, statically 2048x2048,
  // constant 0.1; the i32 argument is ignored.
  %fill = util.unfoldable_constant dense<1.000000e-01> : tensor<2048x2048xf32>
  // Erase both static dimensions for the caller.
  %erased = tensor.cast %fill : tensor<2048x2048xf32> to tensor<?x?xf32>
  util.return %erased : tensor<?x?xf32>
}
util.func private @model_params.attn_v_weight(%arg0: i32) -> tensor<?x?xf32> {
  // Dummy V-projection weights, statically 2048x512, constant 0.1;
  // %arg0 unused.
  %fill = util.unfoldable_constant dense<1.000000e-01> : tensor<2048x512xf32>
  // Erase both static dimensions for the caller.
  %erased = tensor.cast %fill : tensor<2048x512xf32> to tensor<?x?xf32>
  util.return %erased : tensor<?x?xf32>
}
util.func private @model_params.attn_k_weight(%arg0: i32) -> tensor<?x?xf32> {
  // Dummy K-projection weights, statically 2048x512 (same static shape as the
  // V weights), constant 0.1; %arg0 unused.
  %fill = util.unfoldable_constant dense<1.000000e-01> : tensor<2048x512xf32>
  // Erase both static dimensions for the caller.
  %erased = tensor.cast %fill : tensor<2048x512xf32> to tensor<?x?xf32>
  util.return %erased : tensor<?x?xf32>
}
util.func private @model_params.attn_q_weight(%arg0: i32) -> tensor<?x?xf32> {
  // Dummy Q-projection weights, statically 2048x2048, constant 0.1;
  // %arg0 unused.
  %fill = util.unfoldable_constant dense<1.000000e-01> : tensor<2048x2048xf32>
  // Erase both static dimensions for the caller.
  %erased = tensor.cast %fill : tensor<2048x2048xf32> to tensor<?x?xf32>
  util.return %erased : tensor<?x?xf32>
}
util.func private @model_params.ffn_norm_weight(%arg0: i32) -> tensor<?xf32> {
  // Unit (all-ones) FFN RMS-norm scale, statically 2048 elements;
  // %arg0 unused.
  %ones = util.unfoldable_constant dense<1.000000e+00> : tensor<2048xf32>
  // Return as a dynamically-sized 1-D tensor.
  %erased = tensor.cast %ones : tensor<2048xf32> to tensor<?xf32>
  util.return %erased : tensor<?xf32>
}
util.func private @model_params.attn_norm_weight(%arg0: i32) -> tensor<?xf32> {
  // Unit (all-ones) attention RMS-norm scale, statically 2048 elements;
  // %arg0 unused.
  %ones = util.unfoldable_constant dense<1.000000e+00> : tensor<2048xf32>
  // Return as a dynamically-sized 1-D tensor.
  %erased = tensor.cast %ones : tensor<2048xf32> to tensor<?xf32>
  util.return %erased : tensor<?xf32>
}
// Top-level Mixtral forward pass: token-embedding lookup, a loop of MoE
// transformer layers, final RMS norm, and the output (vocab) projection.
// %arg0: token ids, shape (batch, seq); %arg1: second i64 tensor threaded
// into every layer (presumably positions or an attention mask -- the layer
// implementation is not visible here; confirm against
// @transformer_layer_moe_components.transformer_layer_moe).
util.func public @forward(%arg0: tensor<?x?xi64>, %arg1: tensor<?x?xi64>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// Pull all hyperparameters from the hparams accessors defined earlier in
// this module (each returns a compile-time constant).
%0 = util.call @hparams.vocab_size() : () -> i64
%1 = util.call @hparams.block_count() : () -> i64
%2 = util.call @hparams.embedding_length() : () -> i64
%3 = util.call @hparams.attention_head_count() : () -> i64
%4 = util.call @hparams.attention_head_count_kv() : () -> i64
%5 = util.call @hparams.feed_forward_length() : () -> i64
%6 = util.call @hparams.expert_count() : () -> i64
%7 = util.call @hparams.expert_used_count() : () -> i64
%8 = util.call @hparams.rope_freq_base() : () -> f32
%9 = util.call @hparams.layer_norm_rms_epsilon() : () -> f32
// Cast each hparam to index and attach range assumptions for the compiler's
// integer-range / shape inference.
%10 = arith.index_cast %0 : i64 to index
%11 = util.assume.int %10<umin = 32000, umax = 200000> : index
%12 = arith.index_cast %1 : i64 to index
// NOTE(review): @hparams.block_count (top of file) returns the constant 2,
// but the assumption below claims umin = 16. The assumed range contradicts
// the actual value; if range inference trusts it, downstream analysis may be
// operating on impossible facts -- confirm whether this mismatch is intended.
%13 = util.assume.int %12<umin = 16, umax = 80> : index
%14 = arith.index_cast %2 : i64 to index
%15 = util.assume.int %14<umin = 2048, umax = 8192> : index
%16 = arith.index_cast %3 : i64 to index
%17 = util.assume.int %16<umin = 16, umax = 64> : index
%18 = arith.index_cast %4 : i64 to index
%19 = util.assume.int %18<umin = 4, umax = 16> : index
%20 = arith.index_cast %5 : i64 to index
%21 = util.assume.int %20<umin = 8192, umax = 32768> : index
%22 = arith.index_cast %6 : i64 to index
%23 = util.assume.int %22<umin = 4, umax = 64> : index
%24 = arith.index_cast %7 : i64 to index
%25 = util.assume.int %24<umin = 1, umax = 8> : index
// Runtime batch (%26) and sequence length (%27) with assumed bounds.
%dim = tensor.dim %arg0, %c0 : tensor<?x?xi64>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xi64>
%26 = util.assume.int %dim<umin = 1, umax = 32> : index
%27 = util.assume.int %dim_0<umin = 1, umax = 32768> : index
// Two boolean feature flags passed into every layer (both disabled).
%false = arith.constant false
%false_1 = arith.constant false
// Embed token ids -> (batch, seq, embed) activations.
%28 = util.call @model_params.token_embd_weight() : () -> tensor<?x?xf32>
%29 = util.call @embedding_components.embedding_lookup(%28, %arg0) : (tensor<?x?xf32>, tensor<?x?xi64>) -> tensor<?x?x?xf32>
%cst = arith.constant 1.000000e+00 : f32
// Run the transformer stack: %13 (block_count) iterations, threading the
// activations through iter_args; the layer index is passed as i32.
%30 = scf.for %arg2 = %c0 to %13 step %c1 iter_args(%arg3 = %29) -> (tensor<?x?x?xf32>) {
%37 = arith.index_cast %arg2 : index to i32
%38 = util.call @transformer_layer_moe_components.transformer_layer_moe(%arg3, %arg1, %37, %17, %19, %15, %21, %23, %25, %8, %cst, %9, %false, %false_1) : (tensor<?x?x?xf32>, tensor<?x?xi64>, i32, index, index, index, index, index, index, f32, f32, f32, i1, i1) -> tensor<?x?x?xf32>
scf.yield %38 : tensor<?x?x?xf32>
}
// Flatten (batch, seq, embed) -> (batch*seq, embed) for the 2-D RMS norm,
// then restore the 3-D shape afterwards using the assumed dims.
%collapsed = tensor.collapse_shape %30 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
%31 = util.call @model_params.output_norm_weight() : () -> tensor<?xf32>
%32 = util.call @rms_norm_components.rms_norm_linalg(%collapsed, %31, %9) : (tensor<?x?xf32>, tensor<?xf32>, f32) -> tensor<?x?xf32>
%expanded = tensor.expand_shape %32 [[0, 1], [2]] output_shape [%26, %27, %15] : tensor<?x?xf32> into tensor<?x?x?xf32>
// Vocab projection: expand the (embed, vocab) output weight to a rank-3
// tensor with batch dim 1 so it can feed linalg.batch_matmul.
%33 = util.call @model_params.output_weight() : () -> tensor<?x?xf32>
%expanded_2 = tensor.expand_shape %33 [[0, 1], [2]] output_shape [%c1, %15, %11] : tensor<?x?xf32> into tensor<1x?x?xf32>
%cast = tensor.cast %expanded_2 : tensor<1x?x?xf32> to tensor<?x?x?xf32>
%34 = tensor.empty(%26, %27, %11) : tensor<?x?x?xf32>
%cst_3 = arith.constant 0.000000e+00 : f32
%35 = linalg.fill ins(%cst_3 : f32) outs(%34 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// NOTE(review): linalg.batch_matmul requires equal batch dims, but the RHS
// batch is statically 1 (hidden behind the cast above) while the LHS batch
// is %26 (assumed 1..32). Unless %26 is 1 at runtime this pairing is
// invalid; given the gist title ("expand_shape dimension inference"), this
// 1-vs-%26 batch mismatch looks like a prime suspect -- confirm.
%36 = linalg.batch_matmul ins(%expanded, %cast : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%35 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
util.return %36 : tensor<?x?x?xf32>
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment