KohyaSS/tools/prune.py

38 lines
964 B
Python
Raw Normal View History

2022-12-02 00:06:33 +00:00
import os
import argparse
import torch
from tqdm import tqdm
def select_model_weights(state_dict, keys=None):
    """Return the entries of *state_dict* whose key contains ``"model"``.

    Args:
        state_dict: Mapping of parameter name to value.
        keys: Optional iterable yielding the keys of *state_dict* (for
            example a progress-bar wrapper around it). Defaults to iterating
            *state_dict* itself.

    Returns:
        A new dict holding only the matching entries, in iteration order.
    """
    if keys is None:
        keys = state_dict
    return {key: state_dict[key] for key in keys if "model" in key}


def halve_floating_point(weights):
    """Return a copy of *weights* with floating-point tensors cast to fp16.

    Integer/bool tensors and non-tensor values are passed through unchanged:
    casting them to half precision would silently change their dtype (or
    raise for values that have no ``.half()`` method).
    """
    return {
        key: value.half()
        if torch.is_tensor(value) and value.is_floating_point()
        else value
        for key, value in weights.items()
    }


def main(argv=None):
    """Prune a checkpoint down to its ``model`` weights and save the result.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    parser = argparse.ArgumentParser(description="Prune a model")
    parser.add_argument("model_prune", type=str, help="Path to model to prune")
    parser.add_argument("prune_output", type=str, help="Path to pruned ckpt output")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    args = parser.parse_args(argv)

    print("Loading model...")
    # SECURITY NOTE: torch.load unpickles the file and can execute arbitrary
    # code; only run this tool on checkpoints from a trusted source.
    # map_location="cpu" lets checkpoints saved from a GPU load on machines
    # without CUDA and keeps the whole operation out of VRAM.
    checkpoint = torch.load(args.model_prune, map_location="cpu")
    full_state = checkpoint["state_dict"]

    print("Pruning model...")
    pruned = select_model_weights(
        full_state, tqdm(full_state.keys(), desc="Pruning keys")
    )
    # Drop *both* references to the full checkpoint so the discarded tensors
    # can actually be freed before the (optional) half-precision copy is made.
    del full_state, checkpoint

    if args.half:
        print("Halving model...")
        state_dict = halve_floating_point(pruned)
    else:
        state_dict = pruned
    del pruned

    print("Saving pruned model...")
    torch.save({"state_dict": state_dict}, args.prune_output)
    del state_dict
    print("Done pruning!")


if __name__ == "__main__":
    main()