def sparsify_weight(
    module: torch.nn.Module,
    row_scalars_dict: Dict[torch.nn.Module, torch.Tensor],
    sparsity: float,
    prune_n: int,
    prune_m: int,
) -> torch.Tensor:
    """
    Run pruning on the layer up to the target sparsity value.

    Each weight is scored as ``|W| * sqrt(S)``, where ``S`` is the scalar
    tensor stored for ``module`` in ``row_scalars_dict`` (broadcast across
    rows, so it must hold one value per column of the 2-D weight). The
    lowest-scoring weights are set to zero.

    :param module: layer whose ``weight`` is pruned. ``module.weight`` itself
        is not modified; a pruned copy is returned.
    :param row_scalars_dict: mapping from module to its scalar tensor. The
        entry for ``module`` is removed from the dict by this call (raises
        ``KeyError`` if it is missing).
    :param sparsity: target fraction of weights to zero in each row; only used
        when ``prune_n == 0`` (unstructured pruning)
    :param prune_n: N for N:M pruning; 0 selects unstructured pruning
    :param prune_m: M for N:M pruning
    :return: pruned weight with the same shape and dtype as ``module.weight``
    """
    final_shape = module.weight.shape
    final_dtype = module.weight.dtype
    W = module.weight.data.clone()
    if isinstance(module, torch.nn.Conv2d):
        # collapse every dim after the first into columns -> 2-D (out, -1)
        W = W.flatten(1)
    if isinstance(module, transformers.Conv1D):
        # presumably because HF Conv1D stores its weight as (in, out); transpose
        # so columns line up with S the same way as for other layers
        W = W.t()
    W = W.to(dtype=WANDA_PRECISION)

    # Python has no `move`, so pop the entry to drop the dict's reference
    S = row_scalars_dict.pop(module)
    W_metric = torch.abs(W) * torch.sqrt(S.reshape((1, -1)))

    # True marks a weight to prune; start with nothing pruned
    W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
    num_cols = W_metric.shape[1]
    if prune_n != 0:
        # structured N:M sparsity: within every group of prune_m consecutive
        # columns of each row, prune the prune_n lowest-scoring weights.
        # NOTE: if num_cols is not a multiple of prune_m, the last group is
        # narrower; torch.topk raises if it has fewer than prune_n columns.
        for start in range(0, num_cols, prune_m):
            group = W_metric[:, start : start + prune_m].float()
            lowest = torch.topk(group, prune_n, dim=1, largest=False).indices
            W_mask.scatter_(1, start + lowest, True)
    else:
        # unstructured: prune the lowest-scoring int(num_cols * sparsity)
        # weights of each row; the stable sort resolves ties toward the
        # lower column index
        num_prune = int(num_cols * sparsity)
        order = torch.sort(W_metric, dim=-1, stable=True).indices
        W_mask.scatter_(1, order[:, :num_prune], True)

    W[W_mask] = 0.0  # set weights to zero

    if isinstance(module, transformers.Conv1D):
        W = W.t()  # undo the transpose applied above
    return W.reshape(final_shape).to(final_dtype)