Files
AareLC_train/experiments/pruning.py
T
2026-04-14 16:07:31 +02:00

17 lines
627 B
Python

import argparse
from pathlib import Path

import torch
import torch.nn.utils.prune as prune
from ultralytics import YOLO
# Load trained YOLO model
model_name = "best_yolo26n-seg-overlap-false_2026-04-12"
model = YOLO(f"/Users/duan_j/repos/aare_suite/AareLC/models/{model_name}.pt")
net = model.model # underlying PyTorch model
# Prune 30% of weights in Conv2d and Linear layers
for module in net.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
prune.l1_unstructured(module, name="weight", amount=0.3)
prune.remove(module, "weight")
# Save pruned model weights
torch.save(net.state_dict(), f"pruned_{model_name}_weights.pt")