Adds NSFW content filter option
This commit is contained in:
parent
fa8be8acd6
commit
fc18e2d483
@ -19,6 +19,14 @@ import modules.face_restoration
|
||||
import modules.images as images
|
||||
import modules.styles
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
# load safety model
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
safety_checker = None
|
||||
|
||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
@ -146,6 +154,28 @@ def fix_seed(p):
|
||||
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
|
||||
|
||||
|
||||
def numpy_to_pil(images):
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
# check and replace nsfw content
|
||||
def check_safety(x_image):
|
||||
global safety_feature_extractor, safety_checker
|
||||
if safety_feature_extractor is None:
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
return x_checked_image, has_nsfw_concept
|
||||
|
||||
|
||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
@ -248,6 +278,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if opts.filter_nsfw:
|
||||
x_samples_ddim_numpy = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
|
||||
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
|
||||
x_samples_ddim = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
||||
|
||||
for i, x_sample in enumerate(x_samples_ddim):
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
@ -111,6 +111,7 @@ class Options:
|
||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||
"samples_save": OptionInfo(True, "Save indiviual samples"),
|
||||
"samples_format": OptionInfo('png', 'File format for individual samples'),
|
||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||
"grid_save": OptionInfo(True, "Save image grids"),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||
|
Loading…
Reference in New Issue
Block a user