Add prunt tool
This commit is contained in:
parent
d037c1f429
commit
621dabcadf
38
tools/prune.py
Normal file
38
tools/prune.py
Normal file
@ -0,0 +1,38 @@
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
parser = argparse.ArgumentParser(description="Prune a model")
|
||||
parser.add_argument("model_prune", type=str, help="Path to model to prune")
|
||||
parser.add_argument("prune_output", type=str, help="Path to pruned ckpt output")
|
||||
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading model...")
|
||||
model_prune = torch.load(args.model_prune)
|
||||
theta_prune = model_prune["state_dict"]
|
||||
theta = {}
|
||||
|
||||
print("Pruning model...")
|
||||
for key in tqdm(theta_prune.keys(), desc="Pruning keys"):
|
||||
if "model" in key:
|
||||
theta.update({key: theta_prune[key]})
|
||||
|
||||
del theta_prune
|
||||
|
||||
if args.half:
|
||||
print("Halving model...")
|
||||
state_dict = {k: v.half() for k, v in theta.items()}
|
||||
else:
|
||||
state_dict = theta
|
||||
|
||||
del theta
|
||||
|
||||
print("Saving pruned model...")
|
||||
|
||||
torch.save({"state_dict": state_dict}, args.prune_output)
|
||||
|
||||
del state_dict
|
||||
|
||||
print("Done pruning!")
|
Loading…
Reference in New Issue
Block a user