Unverified Commit 05ef0b1c authored by lllyasviel's avatar lllyasviel Committed by GitHub
Browse files

safetensor ipadapters (#2241)

Showing with 16 additions and 1 deletion
+16 -1
......@@ -135,6 +135,21 @@ def build_model_by_guess(state_dict, unet, model_path):
prefix_replace["adapter."] = ""
state_dict = state_dict_prefix_replace(state_dict, prefix_replace)
if any('image_proj.' in x for x in state_dict.keys()) and any('ip_adapter.' in x for x in state_dict.keys()): # safetensor ipadapters
st_model = {"image_proj": {}, "ip_adapter": {}}
for key in state_dict.keys():
if key.startswith("image_proj."):
st_model["image_proj"][key.replace("image_proj.", "")] = state_dict[key]
elif key.startswith("ip_adapter."):
st_model["ip_adapter"][key.replace("ip_adapter.", "")] = state_dict[key]
# sort keys
model = {"image_proj": st_model["image_proj"], "ip_adapter": {}}
sorted_keys = sorted(st_model["ip_adapter"].keys(), key=lambda x: int(x.split(".")[0]))
for key in sorted_keys:
model["ip_adapter"][key] = st_model["ip_adapter"][key]
state_dict = model
del st_model
model_has_shuffle_in_filename = 'shuffle' in Path(os.path.abspath(model_path)).stem.lower()
state_dict = {k.replace("control_model.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("adapter.", ""): v for k, v in state_dict.items()}
......
version_flag = 'v1.1.415'
version_flag = 'v1.1.416'
from scripts.logging import logger
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment