make existing script loading and new preload code use same code for loading modules
limit extension preload scripts to just one file named preload.py
This commit is contained in:
parent
e5690d0bf2
commit
a1a376331c
@ -1,7 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
import git
|
||||
|
||||
@ -85,23 +84,3 @@ def list_extensions():
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
||||
extensions.append(extension)
|
||||
|
||||
|
||||
def preload_extensions(parser):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
path = os.path.join(extensions_dir, dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
for file in os.listdir(path):
|
||||
if "preload.py" in file:
|
||||
full_file = os.path.join(path, file)
|
||||
print(f"Got preload file: {full_file}")
|
||||
|
||||
try:
|
||||
ext = SourceFileLoader("preload", full_file).load_module()
|
||||
parser = ext.preload(parser)
|
||||
except Exception as e:
|
||||
print(f"Exception preloading script: {e}")
|
||||
return parser
|
34
modules/script_loading.py
Normal file
34
modules/script_loading.py
Normal file
@ -0,0 +1,34 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def load_module(path):
|
||||
with open(path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
compiled = compile(text, path, 'exec')
|
||||
module = ModuleType(os.path.basename(path))
|
||||
exec(compiled, module.__dict__)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def preload_extensions(extensions_dir, parser):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
|
||||
if not os.path.isfile(preload_script):
|
||||
continue
|
||||
|
||||
try:
|
||||
module = load_module(preload_script)
|
||||
if hasattr(module, 'preload'):
|
||||
module.preload(parser)
|
||||
|
||||
except Exception:
|
||||
print(f"Error running preload() for {preload_script}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
@ -6,7 +6,7 @@ from collections import namedtuple
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks, extensions
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
@ -161,13 +161,7 @@ def load_scripts():
|
||||
sys.path = [scriptfile.basedir] + sys.path
|
||||
current_basedir = scriptfile.basedir
|
||||
|
||||
with open(scriptfile.path, "r", encoding="utf8") as file:
|
||||
text = file.read()
|
||||
|
||||
from types import ModuleType
|
||||
compiled = compile(text, scriptfile.path, 'exec')
|
||||
module = ModuleType(scriptfile.filename)
|
||||
exec(compiled, module.__dict__)
|
||||
module = script_loading.load_module(scriptfile.path)
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
if type(script_class) == type and issubclass(script_class, Script):
|
||||
@ -328,19 +322,13 @@ class ScriptRunner:
|
||||
|
||||
def reload_sources(self, cache):
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
with open(script.filename, "r", encoding="utf8") as file:
|
||||
args_from = script.args_from
|
||||
args_to = script.args_to
|
||||
filename = script.filename
|
||||
text = file.read()
|
||||
|
||||
from types import ModuleType
|
||||
|
||||
module = cache.get(filename, None)
|
||||
if module is None:
|
||||
compiled = compile(text, filename, 'exec')
|
||||
module = ModuleType(script.filename)
|
||||
exec(compiled, module.__dict__)
|
||||
module = script_loading.load_module(script.filename)
|
||||
cache[filename] = module
|
||||
|
||||
for key, script_class in module.__dict__.items():
|
||||
|
@ -3,7 +3,6 @@ import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
@ -15,7 +14,7 @@ import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import sd_samplers, sd_models, localization, sd_vae, extensions
|
||||
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
|
||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
|
||||
extensions.preload_extensions(parser)
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user