# Mirror of https://github.com/kodxana/madiator-docker-runpod.git
# (synced 2024-12-04 23:40:13 +01:00; 607 lines, 27 KiB, Python)
import os
|
|
import requests
|
|
from urllib.parse import unquote, urlparse
|
|
from tqdm import tqdm
|
|
import json
|
|
import re
|
|
import time
|
|
import math
|
|
### model_utils-v0.2 by lutzapps, Oct 30th 2024 ###
|
|
# lutzapps - modify for new shared_models module and overwrite for this module
|
|
from utils.shared_models import (ensure_shared_models_folders, update_model_symlinks, SHARED_MODELS_DIR)
|
|
from utils.websocket_utils import send_websocket_message, active_websockets
|
|
|
|
#SHARED_MODELS_DIR = '/workspace/shared_models' # this global var is now owned by the 'shared_models' module
|
|
|
|
# lutzapps - maps the CivitAI model "type" to a folder name of the new SHARED_MODEL_FOLDERS map
MODEL_TYPE_MAPPING = {
    # CivitAI model type -> shared models folder
    'Checkpoint': 'ckpt',  # formerly 'Stable-diffusion', which was not a clear name for the type
    'LORA': 'loras',  # formerly 'Lora'; folder names are now lowercase and plural
    'LoCon': 'loras',  # LoCon files share the LoRA folder
    'TextualInversion': 'embeddings',
    'VAE': 'vae',  # formerly 'VAE'; now lowercase
    'Hypernetwork': 'hypernetworks',
    'AestheticGradient': 'embeddings',  # formerly 'aesthetic_embeddings'; stored together with embeddings
    'ControlNet': 'controlnet',
    'Upscaler': 'upscale_models',  # formerly 'ESRGAN'; not every upscaler is ESRGAN-based
}
|
|
|
|
def ensure_shared_folder_exists():
    """Make sure the shared model folders exist.

    Thin wrapper kept for existing callers; the folder layout is handled by
    ``utils.shared_models.ensure_shared_models_folders()``.
    """
    ensure_shared_models_folders()
|
|
|
|
def check_civitai_url(url:str) -> tuple[bool, bool, str, str]:
|
|
# https://education.civitai.com/civitais-guide-to-downloading-via-api/
|
|
|
|
civitai_domain = "civitai.com"
|
|
|
|
try:
|
|
url = url.lower() # convert to lcase
|
|
|
|
### sample url for normal civitai_url
|
|
# url = "https://civitai.com/models/618692?modelVersionId=691639"
|
|
url_pattern_models = r"https://civitai\.com/models/(\d+)(?:\?modelversionid=(\d+))?"
|
|
|
|
### sample url for civitai_api_url
|
|
# api_url = "https://civitai.com/api/download/models/12345?type=Model&format=SafeTensor&size=pruned&fp=fp16&token=YOUR_TOKEN_HERE"
|
|
url_pattern_api_models = r"https://civitai\.com/api/download/models/(\d+)"
|
|
|
|
is_civitai = (civitai_domain in url) # any civitai url
|
|
is_civitai_api = (is_civitai and ("/api/" in url)) # only civitai_api_url
|
|
# refine the is_civitai to be only true, if NOT already a civit_api_url
|
|
is_civitai = (is_civitai and not is_civitai_api)
|
|
|
|
model_id = None
|
|
version_id = None
|
|
|
|
if not is_civitai:
|
|
return is_civitai, is_civitai_api, model_id, version_id # False, False, None, None
|
|
|
|
if is_civitai_api:
|
|
match = re.match(url_pattern_api_models, url)
|
|
if match:
|
|
model_id = match.group(1) # e.g., "619777"
|
|
return is_civitai, is_civitai_api, model_id, version_id # False, True, model_id, None
|
|
|
|
# only case left (is_civitai == True)
|
|
match = re.match(url_pattern_models, url)
|
|
if match:
|
|
model_id = match.group(1) # e.g., "618692"
|
|
version_id = match.group(2) # e.g., "691639" or None if not present
|
|
return is_civitai, is_civitai_api, model_id, version_id # True, False, model_id, version_id
|
|
|
|
except Exception as e:
|
|
print(f"ERROR in Url parsing for CivitAI url: {url}: {str(e)}")
|
|
|
|
return False, False, None, None
|
|
|
|
# old version - disabled
def check_civitai_url_v0(url):
    """Legacy CivitAI URL parser, superseded by check_civitai_url() and kept for reference.

    Returns (is_civitai, is_civitai_api, model_id, version_id) with integer ids, or
    (False, False, None, None) for non-CivitAI URLs and unparsable ones.
    """
    domain = "civitai.com"
    try:
        if domain in url:
            if "civitai.com/api/download" in url:
                last_segment = url.strip("/").split("/")[-1]
                return False, True, None, int(last_segment)

            after_domain = url[url.find(domain) + len(domain):].strip("/")
            pieces = after_domain.split("?")
            model_segment = pieces[0].split("/")[1]
            if len(pieces) == 1:
                return True, False, int(model_segment), None
            version_segment = pieces[1].split("=")[1]
            return True, False, int(model_segment), int(version_segment)
    except (ValueError, IndexError):
        print("Error parsing Civitai model URL")
    return False, False, None, None
|
|
|
|
|
|
def check_huggingface_url(url:str) -> tuple[bool, str, str, str, str]:
|
|
huggingface_domain = "huggingface.co" # matches both ".com" and ".co"
|
|
|
|
try:
|
|
url = url.lower() # convert to lcase
|
|
|
|
### sample url for normal huggingface_url
|
|
# url1 = "https://huggingface.com/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
|
# url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/subfolder1/subfolder2/vae-ft-mse-840000-ema-pruned.safetensors"
|
|
url_pattern = r"https://huggingface\.(?:co|com)/([\w.-]+/[\w.-]+)/(resolve|blob)/(main|tree)(?:/([\w./-]+))?/([\w.-]+)$"
|
|
|
|
is_huggingface = (huggingface_domain in url)
|
|
|
|
repo_id = None
|
|
filename = None
|
|
folder_name = None
|
|
branch_name = None
|
|
|
|
match = re.match(url_pattern, url)
|
|
if match:
|
|
repo_id = match.group(1) # e.g., "stabilityai/sd-vae-ft-mse-original"
|
|
folder_name = match.group(2) # e.g., "resolve" or "blob" (normally only the "resolve" folder is the download link of the model)
|
|
branch_name = match.group(3) # e.g., "main" or "tree"
|
|
folders = match.group(4) # e.g., "subfolder1/subfolder2" or None if not present
|
|
filename = match.group(5) # e.g., "vae-ft-mse-840000-ema-pruned.safetensors"
|
|
|
|
except Exception as e:
|
|
error_msg = f"ERROR in Url parsing for HuggingFace url: {url}: {str(e)}"
|
|
print(error_msg)
|
|
|
|
return is_huggingface, repo_id, filename, folder_name, branch_name
|
|
|
|
|
|
# old version - disabled
def check_huggingface_url_v0(url):
    """Legacy HuggingFace URL parser, superseded by check_huggingface_url() and kept for reference.

    Returns (is_huggingface, repo_id, filename, folder_name, branch_name); everything is
    False/None unless the URL has the form ``/<owner>/<repo>/(resolve|blob)/<branch>/<path>``
    on huggingface.co or huggingface.com.
    """
    rejected = (False, None, None, None, None)

    parsed = urlparse(url)
    if parsed.netloc not in ("huggingface.co", "huggingface.com"):
        return rejected

    segments = list(filter(None, parsed.path.split("/")))
    if len(segments) < 5 or segments[2] not in ("resolve", "blob"):
        return rejected

    owner, repo, _, branch_name, *file_segments = segments
    file_path = "/".join(file_segments)

    folder_name = None
    if "/" in file_path:
        folder_name = os.path.dirname(file_path)

    return True, f"{owner}/{repo}", unquote(os.path.basename(file_path)), folder_name, branch_name
|
|
|
|
def download_model(url, model_name, model_type, civitai_token=None, hf_token=None, version_id=None, file_index=None) -> tuple[bool, str]:
    """Download a model from a CivitAI or HuggingFace URL into the shared models folder.

    Args:
        url: CivitAI (model page or API download) or HuggingFace file URL.
        model_name: target file name, or empty to use the platform's file name.
        model_type: shared models folder (HuggingFace); CivitAI derives its own.
        civitai_token: required for CivitAI URLs.
        hf_token: optional HuggingFace token.
        version_id, file_index: CivitAI version / file choice made by the user.

    Returns:
        (success, message). ``message`` may be a ``{'choice_required': ...}`` dict when
        the user still has to pick a CivitAI version or file.
    """
    ensure_shared_folder_exists()

    is_civitai, is_civitai_api, _, _ = check_civitai_url(url)
    # TODO: the url is parsed again inside the platform specific download functions
    is_huggingface, repo_id, hf_filename, hf_folder_name, hf_branch_name = check_huggingface_url(url)

    if is_civitai or is_civitai_api:
        if not civitai_token:
            return False, "Civitai token is required for downloading from Civitai"
        success, message = download_civitai_model(url, model_name, model_type, civitai_token, version_id, file_index)
    elif is_huggingface:
        success, message = download_huggingface_model(url, model_name, model_type, repo_id, hf_filename, hf_folder_name, hf_branch_name, hf_token)
    else:
        return False, "Unsupported URL"

    # A dict message is a pending version/file choice: nothing has been downloaded yet,
    # so the UI must not be told that the download is complete.
    if success and not isinstance(message, dict):
        send_websocket_message('model_download_progress', {
            'percentage': 100,
            'stage': 'Complete',
            'message': 'Download complete and symlinks updated'
        })

    return success, message
|
|
|
|
# lutzapps - added SHA256 checks for already existing ident and downloaded CivitAI model
def download_civitai_model(url, model_name, model_type, civitai_token, version_id=None, file_index=None) -> tuple[bool, str]:
    """Download a model file from CivitAI into SHARED_MODELS_DIR/<folder>/<model_name>.

    The target folder is derived from the CivitAI model type via MODEL_TYPE_MAPPING
    (the passed ``model_type`` is overwritten); unknown types go to 'ckpt'.

    Example model page url (Flux Dev): https://civitai.com/models/618692?modelVersionId=691639

    Args:
        url: CivitAI model page URL or API download URL.
        model_name: target file name; defaults to the CivitAI file name.
        model_type: overwritten by the mapped CivitAI type (kept for API compatibility).
        civitai_token: API token; when empty, ``app.load_civitai_token()`` is tried.
        version_id: model version to download (model page URLs only); defaults to the
            version found in the URL.
        file_index: index into the version's file list, when it offers several files.

    Returns:
        (success, message). While the user still has to choose, ``(True, {'choice_required': ...})``
        is returned with the available versions or files instead of a message string.
    """
    try:
        # examine the url, and extract model_id and url_version_id
        _, is_civitai_api, model_id, url_version_id = check_civitai_url(url)

        from app import (load_civitai_token)
        # use provided token or try to read the token from ENV var or stored file
        if not civitai_token:
            civitai_token = load_civitai_token()

        # use the civitai token for Authorization
        headers = {"Authorization": f"Bearer {civitai_token}"} if civitai_token else {}

        if is_civitai_api:
            # The number in an /api/download/models/<id> url is a model *version* id.
            # Older check_civitai_url() versions reported it in the model_id slot.
            api_version_id = url_version_id or model_id
            api_url = f"https://civitai.com/api/v1/model-versions/{api_version_id}"
        else:
            api_url = f"https://civitai.com/api/v1/models/{model_id}"

        # get the model data from the civitai repository
        response = requests.get(api_url, headers=headers, timeout=60)
        if response.status_code != 200:
            raise Exception(f"Failed to get model info: {response.text}")

        model_data = response.json()

        if is_civitai_api:
            # the model-versions endpoint returns the version itself, with its parent model
            # (including the model 'type') nested under 'model'
            version_data = model_data
            model_data = version_data.get('model') or {}
        else:
            versions = model_data.get('modelVersions', [])

            if not version_id:
                version_id = url_version_id

            if version_id is None:  # no version_id specified, let the user pick from available versions
                return True, {
                    'choice_required': {
                        'type': 'version',
                        'model_id': model_id,
                        'versions': versions
                    }
                }

            # Get the selected version (ids may arrive as int or str)
            version_data = next(
                (v for v in versions if str(v['id']) == str(version_id)), None)

            if not version_data:
                return False, f"Version ID {version_id or url_version_id} not found for this model."

        # map the civitai model type to our internal SHARED_MODEL_FOLDERS model_type
        civitai_model_type = model_data.get('type')
        model_type = MODEL_TYPE_MAPPING.get(civitai_model_type, 'ckpt')  # default is 'ckpt'

        files = version_data.get('files') or []
        if not files:
            return False, f"No files found for version ID {version_data.get('id')} of this model."

        if file_index is not None and 0 <= file_index < len(files):
            civitai_file = files[file_index]
        elif len(files) > 1:
            # If there are multiple files and no specific file was chosen, ask the user to choose.
            # Some files lack (parts of) the metadata, so every field is read defensively.
            file_options = []
            for f in files:
                metadata = f.get('metadata') or {}
                file_options.append({
                    'name': f.get('name'),
                    'sizeKB': f.get('sizeKB'),
                    'type': f.get('type'),
                    'format': metadata.get('format'),
                    'size': metadata.get('size'),
                    'fp': metadata.get('fp'),
                })
            return True, {
                'choice_required': {
                    'type': 'file',
                    'model_id': model_id,
                    'version_id': version_data['id'],
                    'files': file_options
                }
            }
        else:
            civitai_file = files[0]  # the metadata dictionary of the only file

        download_url = civitai_file['downloadUrl']
        if not model_name:
            model_name = civitai_file['name']

        model_path = os.path.join(SHARED_MODELS_DIR, model_type, model_name)

        platformInfo = {
            "platform_name": 'civitai',
            "civitai_file": civitai_file  # civitai_file metadata dictionary
        }
        # call shared function for "huggingface" and "civitai" for SHA256 support and "Model Downloader UI" extended support
        download_sha256_hash, found_ident_local_model, message = get_modelfile_hash_and_ident_existing_modelfile_exists(
            model_name, model_type, model_path,  # pass local workspace vars, then platform specific vars as dictionary
            platformInfo)  # [str, bool, str]

        if found_ident_local_model:
            return True, message

        # no identical local model - download it
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        # the SHA256 check against download_sha256_hash is done in download_file() after the download finished
        return download_file(download_url, download_sha256_hash, model_path, headers)  # [bool, str]

    except Exception as e:  # requests.RequestException as e:
        return False, f"Exception downloading from CivitAI: {str(e)}"
|
|
|
|
|
|
# lutzapps - calculate the SHA256 hash string of a file
def get_sha256_hash_from_file(file_path:str) -> tuple[bool, str]:
    """Compute the SHA256 hex digest of a file.

    Args:
        file_path: path of the file to hash.

    Returns:
        (True, lowercase hex digest) on success, or (False, exception text) when the
        file cannot be hashed (e.g. it does not exist).
    """
    import hashlib  # support SHA256 checks

    # Model files are often several GB; 1 MiB reads cut the per-call overhead of the
    # former 4 KiB blocks without a noticeable memory cost.
    chunk_size = 1024 * 1024

    try:
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            while chunk := f.read(chunk_size):
                sha256_hash.update(chunk)
        return True, sha256_hash.hexdigest().lower()

    except Exception as e:
        return False, str(e)
|
|
|
|
|
|
# lutzapps - support SHA256 Hash check of already locally existing modelfile against its metadata hash before downloading is needed
|
|
# shared function for "huggingface" and "civitai" called by download_huggingface_model() and download_civitai_model()
|
|
def get_modelfile_hash_and_ident_existing_modelfile_exists(model_name:str, model_type:str, model_path:str, platformInfo:dict) -> tuple[bool, str, str]:
|
|
try:
|
|
# update (and remember) the selected index of the modelType select list of the "Model Downloader"
|
|
message = f"Select the ModelType '{model_type}' to download"
|
|
print(message)
|
|
|
|
send_websocket_message('extend_ui_helper', {
|
|
'cmd': 'selectModelType',
|
|
'model_type': f'{model_type}', # e.g. "loras" or "vae"
|
|
'message': message
|
|
} )
|
|
|
|
# get the SHA256 hash - used for compare against existing or downloaded model
|
|
platform_name = platformInfo['platform_name'].lower() # currently "civitai" or "huggingface", but could be extendend
|
|
print(f"\nPlatform: {platform_name}")
|
|
|
|
match platform_name:
|
|
case "huggingface":
|
|
# get the platform-specific passed variables for "huggingface"
|
|
hf_token = platformInfo['hf_token']
|
|
repo_id = platformInfo['repo_id']
|
|
hf_filename = platformInfo['hf_filename']
|
|
|
|
#from huggingface_hub import hf_hub_download
|
|
# lutzapps - to get SHA256 hash from model
|
|
from huggingface_hub import (
|
|
# HfApi, # optional when not calling globally
|
|
get_paths_info #list_files_info #DEPRECATED/MISSING: list_files_info => get_paths_info
|
|
)
|
|
from huggingface_hub.hf_api import (
|
|
RepoFile, RepoFolder, BlobLfsInfo
|
|
)
|
|
|
|
## optionally configure a HfApi client instead of calling globally
|
|
# hf_api = HfApi(
|
|
# endpoint = "https://huggingface.co", # can be a Private Hub endpoint
|
|
# token = hf_token, # token is not persisted on the machine
|
|
# )
|
|
|
|
print(f"getting SHA256 Hash for '{model_name}' from repo {repo_id}/{hf_filename}")
|
|
# HfApi.list_files_info deprecated -> HfApi.get_paths_info (runs into exception, as connot be imported as missing)
|
|
#files_info = hf_api.list_files_info(repo_id, hf_filename, expand=True)
|
|
#paths_info = hf_api.get_paths_info(repo_id, hf_filename, expand=True) # use via HfApi
|
|
paths_info = get_paths_info(repo_id, hf_filename, expand=True) # use global (works fine)
|
|
|
|
repo_file = paths_info[0] # RepoFile or RepoFolder class instance
|
|
# check for RepoFolder or NON-LFS
|
|
if isinstance(repo_file, RepoFolder):
|
|
raise NotImplementedError("Downloading a folder is not implemented.")
|
|
if not repo_file.lfs:
|
|
raise NotImplementedError("Copying a non-LFS file is not implemented.")
|
|
|
|
lfs = repo_file.lfs # BlobLfsInfo class instance
|
|
download_sha256_hash = lfs.sha256.lower()
|
|
|
|
print(f"Metadata from RepoFile LFS '{repo_file.rfilename}'")
|
|
print(f"SHA256: {download_sha256_hash}")
|
|
|
|
case "civitai":
|
|
# get the platform-specific passed variables for "civitai"
|
|
civitai_file = platformInfo['civitai_file'] # civitai_file metadata dictionary
|
|
|
|
# get the SHA256 hash - used for compare against existing or downloaded model
|
|
download_sha256_hash = civitai_file['hashes']['SHA256'] # civitai_file = passed file
|
|
|
|
### END platform specific code
|
|
|
|
# check if model file already exists
|
|
if not os.path.exists(model_path):
|
|
message = f"No local model '{os.path.basename(model_path)}' installed"
|
|
print(message)
|
|
|
|
return download_sha256_hash, False, message
|
|
|
|
message = f"Model already exists: {os.path.basename(model_path)}, SHA256 check..."
|
|
print(message)
|
|
|
|
send_websocket_message('model_download_progress', {
|
|
'percentage': 0, # ugly
|
|
'stage': 'Downloading',
|
|
'message': message
|
|
})
|
|
|
|
# check if existing model is ident with model to download
|
|
# this can *take a while* for big models, but even better than to unnecessarily redownload the model
|
|
successfull_HashGeneration, model_sha256_hash = get_sha256_hash_from_file(model_path)
|
|
# if NOT successful, the hash contains the Exception
|
|
print(f"SHA256 hash generated from local file: '{model_path}'\n{model_sha256_hash}")
|
|
|
|
if successfull_HashGeneration and model_sha256_hash.lower() == download_sha256_hash.lower():
|
|
message = f"Existing and ident model already found for '{os.path.basename(model_path)}'"
|
|
print(message)
|
|
|
|
send_websocket_message('model_download_progress', {
|
|
'percentage': 100,
|
|
'stage': 'Complete',
|
|
'message': message
|
|
})
|
|
|
|
return download_sha256_hash, successfull_HashGeneration, message
|
|
|
|
else:
|
|
if successfull_HashGeneration: # the generated SHA256 file model Hash did not match against the metadata hash
|
|
message = f"Local installed model '{os.path.basename(model_path)}' has DIFFERENT \nSHA256: {model_sha256_hash}"
|
|
print(message)
|
|
|
|
return download_sha256_hash, False, message
|
|
|
|
|
|
else: # NOT successful, the hash contains the Exception
|
|
error_msg = model_sha256_hash
|
|
error_msg = f"Exception occured while generating the SHA256 hash for '{model_path}':\n{error_msg}"
|
|
print(error_msg)
|
|
|
|
except Exception as e:
|
|
error_msg = f"Exception when downloading from {platform_name}: {str(e)}"
|
|
|
|
return "", False, error_msg # hash, identfile, message
|
|
|
|
|
|
# lutzapps - added SHA256 checks for already existing ident and downloaded HuggingFace model
def download_huggingface_model(url, model_name, model_type, repo_id, hf_filename, hf_folder_name, hf_branch_name, hf_token=None) -> tuple[bool, str]:
    """Download a file from a HuggingFace repo into SHARED_MODELS_DIR/<model_type>/<model_name>.

    Args:
        url: the original HuggingFace URL (not used; the parsed parts below are).
        model_name: target file name; defaults to ``hf_filename``.
        model_type: shared models folder, e.g. "loras" or "vae".
        repo_id: e.g. "stabilityai/sd-vae-ft-mse-original".
        hf_filename: file name inside the repo.
        hf_folder_name: subfolder inside the repo, or None.
        hf_branch_name: revision, or None for the default branch.
        hf_token: optional token for gated/private repos.

    Returns:
        (success, message); the download is verified against the Hub's SHA256.
    """
    try:
        import shutil
        from huggingface_hub import hf_hub_download

        if not model_name:
            model_name = hf_filename

        model_path = os.path.join(SHARED_MODELS_DIR, model_type, model_name)

        platformInfo = {
            "platform_name": 'huggingface',
            "hf_token": hf_token,
            "repo_id": repo_id,
            "hf_filename": hf_filename,
            # allow the SHA256 lookup to address files in subfolders and on other branches
            "hf_folder_name": hf_folder_name,
            "hf_branch_name": hf_branch_name
        }
        # call shared function for "huggingface" and "civitai" for SHA256 support and "Model Downloader UI" extended support
        download_sha256_hash, found_ident_local_model, message = get_modelfile_hash_and_ident_existing_modelfile_exists(
            model_name, model_type, model_path,  # pass local workspace vars, then platform specific vars as dictionary
            platformInfo)  # [str, bool, str]

        if found_ident_local_model:
            return True, message

        # no identical local model - download it
        model_dir = os.path.dirname(model_path)
        os.makedirs(model_dir, exist_ok=True)

        send_websocket_message('model_download_progress', {
            'percentage': 0,
            'stage': 'Downloading',
            'message': f'Starting download from Hugging Face: {repo_id}'
        })

        kwargs = {
            'repo_id': repo_id,
            'filename': hf_filename,
            'subfolder': hf_folder_name,
            'revision': hf_branch_name,
            'local_dir': model_dir
        }
        if hf_token:
            kwargs['token'] = hf_token

        file_path = hf_hub_download(**kwargs)  ### HF_DOWNLOAD_START / COMPLETE

        # hf_hub_download() stores the file as <local_dir>/<subfolder>/<hf_filename>; move it to
        # model_path so custom model names and subfolder files end up where the existence check looks
        if os.path.abspath(file_path) != os.path.abspath(model_path):
            shutil.move(file_path, model_path)
            file_path = model_path

        # SHA256 Hash checks of downloaded modelfile against its metadata hash
        return check_downloaded_modelfile(file_path, download_sha256_hash, "huggingface")  # [bool, str]

    except Exception as e:
        return False, f"Exception when downloading from 'HuggingFace': {str(e)}"
|
|
|
|
|
|
# lutzapps - added SHA256 check for downloaded CivitAI model
def download_file(url, download_sha256_hash, file_path, headers=None) -> tuple[bool, str]:
    """Stream *url* into *file_path* with websocket progress, then verify its SHA256.

    Args:
        url: direct download URL (currently a CivitAI 'downloadUrl').
        download_sha256_hash: expected SHA256 hex digest ("" when unknown).
        file_path: destination file; overwritten if it already exists.
        headers: optional request headers (e.g. the CivitAI Authorization header).

    Returns:
        The (success, message) result of check_downloaded_modelfile(), or (False, error)
        when the download itself failed. A partially written file is removed on failure.
    """
    block_size = 8192
    progress_interval = 0.5  # seconds; one websocket message per 8 KiB chunk would flood the UI
    file_opened = False

    try:
        # timeout = (connect, read) so a stalled connection cannot hang the download forever;
        # the context manager releases the connection when done
        with requests.get(url, stream=True, headers=headers, timeout=(30, 300)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded_size = 0
            start_time = time.time()
            last_report_time = 0.0

            with open(file_path, 'wb') as file:  ### CIVITAI_DOWNLOAD
                file_opened = True
                for data in response.iter_content(block_size):
                    downloaded_size += file.write(data)

                    current_time = time.time()
                    elapsed_time = current_time - start_time
                    if elapsed_time <= 0 or current_time - last_report_time < progress_interval:
                        continue
                    last_report_time = current_time

                    speed = downloaded_size / elapsed_time
                    percentage = (downloaded_size / total_size) * 100 if total_size > 0 else 0
                    eta = (total_size - downloaded_size) / speed if speed > 0 else 0

                    send_websocket_message('model_download_progress', {
                        'percentage': round(percentage, 2),
                        'speed': f"{speed / (1024 * 1024):.2f} MB/s",
                        'eta': int(eta),
                        'stage': 'Downloading',
                        'message': f'Downloaded {format_size(downloaded_size)} / {format_size(total_size)}'
                    })
        ### CIVITAI_DOWNLOAD COMPLETE

    except Exception as e:
        if file_opened:
            # a truncated file would only fail the SHA256 check of the next attempt
            try:
                os.remove(file_path)
            except OSError:
                pass
        return False, f"Exception when downloading from CivitAI: {str(e)}"

    # SHA256 Hash checks of downloaded modelfile against its metadata hash
    return check_downloaded_modelfile(file_path, download_sha256_hash, "civitai")  # [bool, str]
|
|
|
|
# lutzapps - SHA256 Hash checks of downloaded modelfile against its metadata hash
# shared function for "huggingface" and "civitai" for SHA256 support and "Model Downloader UI" extended support
def check_downloaded_modelfile(model_path:str, download_sha256_hash:str, platform_name:str) -> tuple[bool, str]:
    """Verify a downloaded model file against the platform's SHA256 hash.

    On a match, the UI is notified and update_model_symlinks() links the model into all
    installed apps. On a mismatch the (possibly corrupted) file is deleted.

    Args:
        model_path: path of the downloaded file.
        download_sha256_hash: expected SHA256 hex digest; empty/None means unknown.
        platform_name: "huggingface" or "civitai" (used in messages only).

    Returns:
        (success, message) - always a 2-tuple, as callers unpack exactly two values.
    """
    try:
        if not download_sha256_hash:
            return False, f"Downloaded model could not be verified with Metadata, no SHA256 hash found on '{platform_name}'"

        # this can take a while for big models, but is still better than keeping a corrupted model
        send_websocket_message('model_download_progress', {
            'percentage': 90,  # change back from 100 to 90 (ugly)
            'stage': 'Complete',  # leave it as 'Complete' as this "clears" SPEED/ETA Divs
            'message': f'SHA256 Check for Model: {os.path.basename(model_path)}'
        })

        successfull_HashGeneration, model_sha256_hash = get_sha256_hash_from_file(model_path)
        if not successfull_HashGeneration:
            # NOT successful, model_sha256_hash contains the exception text
            error_msg = f"Exception occured while generating the SHA256 hash for '{model_path}':\n{model_sha256_hash}"
            print(error_msg)
            return False, error_msg

        if model_sha256_hash.lower() != download_sha256_hash.lower():
            message = f"The downloaded model '{os.path.basename(model_path)}' has DIFFERENT \nSHA256: {model_sha256_hash} as stored on {platform_name}\nFile is possibly corrupted and was DELETED!"
            print(message)
            os.remove(model_path)  # delete corrupted, downloaded file
            return False, message

        send_websocket_message('model_download_progress', {
            'percentage': 100,
            'stage': 'Complete',
            'message': f'Download complete: {os.path.basename(model_path)}'
        })

        update_model_symlinks()  # create symlinks for this new downloaded model for all installed apps

        return True, f"Successfully downloaded (SHA256 checked, and symlinked) '{os.path.basename(model_path)}' from {platform_name}"

    except Exception as e:
        error_msg = f"Exception when downloading from {platform_name}: {str(e)}"
        print(error_msg)
        return False, error_msg
|
|
|
|
|
|
# smaller helper functions
|
|
def get_civitai_file_size(url, token):
    """Return the size in bytes reported for *url* by a HEAD request (redirects followed).

    Args:
        url: CivitAI download URL.
        token: CivitAI API token, or None/empty for anonymous requests.

    Returns:
        The content-length as int (0 when the header is missing), or None when the
        request fails.
    """
    # don't send "Bearer None" for anonymous requests
    headers = {'Authorization': f'Bearer {token}'} if token else {}
    try:
        response = requests.head(url, headers=headers, allow_redirects=True, timeout=30)
        return int(response.headers.get('content-length', 0))
    except (requests.RequestException, ValueError):
        return None
|
|
|
|
def get_huggingface_file_size(repo_id, filename, folder_name, branch_name, token):
    """Return the size in bytes of a file in a HuggingFace repo, via a HEAD request.

    Args:
        repo_id: e.g. "stabilityai/sd-vae-ft-mse-original".
        filename: file name inside the repo.
        folder_name: subfolder inside the repo, or None.
        branch_name: revision, or None for the default branch.
        token: optional token for gated/private repos.

    Returns:
        The content-length as int (0 when the header is missing), or None on failure.
    """
    # hf_hub_url is a module level function (HfApi has no hf_hub_url method)
    from huggingface_hub import hf_hub_url

    try:
        file_url = hf_hub_url(repo_id, filename, subfolder=folder_name, revision=branch_name)
        headers = {'Authorization': f'Bearer {token}'} if token else None
        # LFS files answer with a redirect to the storage backend; follow it to get the real size
        response = requests.head(file_url, headers=headers, allow_redirects=True, timeout=30)
        return int(response.headers.get('content-length', 0))
    except (requests.RequestException, ValueError):
        return None
|
|
|
|
def format_size(size_in_bytes):
    """Format a byte count as a human readable string with 1024-based units.

    The value is rounded to 2 decimals, e.g. ``format_size(1536) == "1.5 KB"`` and
    ``format_size(500) == "500.0 B"``. Values beyond the largest unit stay in "YB".
    """
    if size_in_bytes == 0:
        return "0 B"

    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

    # Pick the unit with exact integer comparisons: floor(log(x, 1024)) can land one
    # unit too low for exact powers, would index past "YB", and maps values < 1 to "YB".
    magnitude = abs(size_in_bytes)
    i = 0
    while i < len(size_name) - 1 and magnitude >= 1024 ** (i + 1):
        i += 1

    s = round(size_in_bytes / math.pow(1024, i), 2)
    return f"{s} {size_name[i]}"
|