# Python script (web-viewer paste header: 17 lines, 627 B)
"""Apply L1 (magnitude) unstructured pruning to a trained Ultralytics YOLO model.

Each ``torch.nn.Conv2d`` / ``torch.nn.Linear`` weight tensor is pruned layer by
layer: the ``amount`` fraction of its entries with the smallest absolute value
is set to zero. The pruning is then made permanent with ``prune.remove``.
Finally the network's ``state_dict`` is saved to
``<output-dir>/pruned_<model-name>_weights.pt``.

Note: unstructured pruning only zeroes entries. Tensors keep their dense shape,
so the saved file is not smaller and inference is not faster unless a
sparse-aware runtime or a compression step is used afterwards.
"""

import argparse
from pathlib import Path

import torch
import torch.nn.utils.prune as prune

# Defaults reproduce the original hard-coded behaviour; override them via CLI flags.
DEFAULT_MODEL_NAME = "best_yolo26n-seg-overlap-false_2026-04-12"
DEFAULT_MODELS_DIR = Path("/Users/duan_j/repos/aare_suite/AareLC/models")
DEFAULT_AMOUNT = 0.3

# Layer types whose ``weight`` tensor is pruned.
PRUNABLE_TYPES = (torch.nn.Conv2d, torch.nn.Linear)

# NOTE(review): Ultralytics detection heads that use DFL (e.g. YOLOv8/11) contain
# a ``dfl.conv`` whose weights are a fixed 0..reg_max-1 ramp used to decode box
# distances, not learned weights. Magnitude pruning would zero the low bins and
# skew the boxes. It is excluded by default; verify against the architecture in use.
DEFAULT_EXCLUDE = ("dfl",)


def prune_network(net, amount=DEFAULT_AMOUNT, exclude=DEFAULT_EXCLUDE):
    """Permanently L1-prune the weights of every Conv2d/Linear layer in *net*.

    Pruning is local: each layer independently loses ``amount`` of its own
    weights, rather than a global threshold being applied across layers.

    Args:
        net: Any ``torch.nn.Module``; it is modified in place.
        amount: Fraction in [0, 1] of each layer's weights to zero. An int is
            treated by ``torch.nn.utils.prune`` as an absolute count per layer.
        exclude: Module-name components to skip. A layer is skipped when any
            dot-separated part of its qualified name (as produced by
            ``named_modules``) is in this collection. ``None`` skips nothing.

    Returns:
        ``(zeroed, total)``: the number of zero-valued weights and the total
        number of weights across the pruned layers. Zeros that existed before
        pruning are included in ``zeroed``.
    """
    excluded = set(exclude or ())
    zeroed = 0
    total = 0
    for name, module in net.named_modules():
        if not isinstance(module, PRUNABLE_TYPES):
            continue
        if excluded.intersection(name.split(".")):
            continue
        prune.l1_unstructured(module, name="weight", amount=amount)
        # Fold the mask into the tensor and drop weight_orig/weight_mask, so the
        # state_dict keeps the plain "weight" keys of the unpruned model.
        prune.remove(module, "weight")
        zeroed += int((module.weight == 0).sum())
        total += module.weight.numel()
    return zeroed, total


def main(argv=None):
    """Load a YOLO checkpoint, prune it, and save the pruned state_dict."""
    parser = argparse.ArgumentParser(
        description="L1-prune Conv2d/Linear weights of a trained YOLO model."
    )
    parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME,
                        help="checkpoint file stem (without .pt)")
    parser.add_argument("--models-dir", type=Path, default=DEFAULT_MODELS_DIR,
                        help="directory containing <model-name>.pt")
    parser.add_argument("--amount", type=float, default=DEFAULT_AMOUNT,
                        help="fraction of each layer's weights to zero")
    parser.add_argument("--output-dir", type=Path, default=Path("."),
                        help="where to write the pruned weights")
    args = parser.parse_args(argv)

    # Imported lazily so prune_network() is usable without ultralytics installed.
    from ultralytics import YOLO

    weights_path = args.models_dir / f"{args.model_name}.pt"
    model = YOLO(str(weights_path))
    net = model.model  # underlying PyTorch model

    zeroed, total = prune_network(net, amount=args.amount)
    ratio = zeroed / total if total else 0.0
    print(f"Zeroed {zeroed:,} of {total:,} prunable weights ({ratio:.1%})")

    # This holds only parameter/buffer tensors, not the Ultralytics checkpoint
    # wrapper. Presumably it must be loaded into a freshly built model of the
    # same architecture rather than passed to YOLO(path); verify before relying on it.
    args.output_dir.mkdir(parents=True, exist_ok=True)
    out_path = args.output_dir / f"pruned_{args.model_name}_weights.pt"
    torch.save(net.state_dict(), out_path)


if __name__ == "__main__":
    main()