mirror of
https://github.com/kodxana/madiator-docker-runpod.git
synced 2024-12-12 09:16:34 +01:00
753 lines
No EOL
31 KiB
Python
753 lines
No EOL
31 KiB
Python
# gevent must patch the blocking standard-library modules (socket, time,
# threading, ...) before any other module imports them, so this stays first.
from gevent import monkey
monkey.patch_all()
|
|
|
|
import os
|
|
import threading
|
|
import time
|
|
from flask import Flask, render_template, jsonify, request
|
|
from flask_sock import Sock
|
|
import re
|
|
import json
|
|
import signal
|
|
import shutil
|
|
import subprocess
|
|
import traceback
|
|
import logging
|
|
|
|
from utils.ssh_utils import setup_ssh, save_ssh_password, get_ssh_password, check_ssh_config, SSH_CONFIG_FILE
|
|
from utils.filebrowser_utils import configure_filebrowser, start_filebrowser, stop_filebrowser, get_filebrowser_status, FILEBROWSER_PORT
|
|
from utils.app_utils import (
|
|
run_app, run_bash_cmd, update_process_status, check_app_directories, get_app_status,
|
|
force_kill_process_by_name, find_and_kill_process_by_port, update_webui_user_sh,
|
|
fix_custom_nodes, is_process_running, install_app,
|
|
get_available_venvs, get_bkohya_launch_url, init_app_status, # lutzapps - support dynamic generated gradio url and venv_size checks
|
|
delete_app_installation, check_app_installation, refresh_app_installation # lutzapps - new app features for check and refresh app
|
|
)
|
|
from utils.websocket_utils import send_websocket_message, active_websockets
|
|
from utils.app_configs import get_app_configs, add_app_config, remove_app_config, app_configs, DEBUG_SETTINGS, APP_CONFIGS_MANIFEST_URL
|
|
from utils.model_utils import download_model, check_civitai_url, check_huggingface_url, format_size #, SHARED_MODELS_DIR # lutzapps - SHARED_MODELS_DIR is owned by shared_models module now
|
|
|
|
# lutzapps - support local browsing for development/debugging.
# NOTE: this is the raw environment string; the default is the (truthy) string 'False'.
LOCAL_DEBUG = os.getenv('LOCAL_DEBUG', 'False')
|
|
|
|
# use the new "utils.shared_models" module for app model sharing
|
|
from utils.shared_models import (
|
|
update_model_symlinks, # main WORKER function (file/folder symlinks, Fix/remove broken symlinks, pull back local app models into shared)
|
|
SHARED_MODELS_DIR, SHARED_MODEL_FOLDERS, SHARED_MODEL_FOLDERS_FILE, ensure_shared_models_folders,
|
|
APP_INSTALL_DIRS, APP_INSTALL_DIRS_FILE, init_app_install_dirs, # APP_INSTALL_DIRS dict/file/function
|
|
MAP_APPS, sync_with_app_configs_install_dirs, # internal MAP_APPS dict and sync function
|
|
SHARED_MODEL_APP_MAP, SHARED_MODEL_APP_MAP_FILE, init_shared_model_app_map # SHARED_MODEL_APP_MAP dict/file/function
|
|
)
|
|
# the "update_model_symlinks()" function replaces the app.py function with the same name
|
|
# and redirects to same function name "update_model_symlinks()" in the new "utils.shared_models" module
|
|
#
|
|
# this function does ALL the link management, including deleting "stale" symlinks,
|
|
# so the "recreate_symlinks()" function will be also re-routed to the
|
|
# "utils.shared_models.update_model_symlinks()" function (see CHANGE #3a and CHANGE #3b)
|
|
|
|
# the "ensure_shared_models_folders()" function will be called from app.py::create_shared_folders(),
|
|
# and replaces this function (see CHANGE #3)
|
|
|
|
# the "init_app_install_dirs()" function initializes the
|
|
# global module 'APP_INSTALL_DIRS' dict: { 'app_name': 'app_installdir' }
|
|
# which does a default mapping from app code or (if exists) from external JSON 'APP_INSTALL_DIRS_FILE' file
|
|
# NOTE: this APP_INSTALL_DIRS dict is temporarily synced with the 'app_configs' dict (see next)
|
|
|
|
# the "sync_with_app_configs_install_dirs()" function syncs the 'APP_INSTALL_DIRS' dict's 'app_installdir' entries
|
|
# from the 'app_configs' dict's 'app_path' entries and uses the MAP_APPS dict for this task
|
|
# NOTE: this syncing is a temporary solution, and needs to be better integrated later
|
|
|
|
# the "init_shared_model_app_map()" function initializes the
|
|
# global module 'SHARED_MODEL_APP_MAP' dict: 'model_type' -> 'app_name:app_model_dir' (relative path)
|
|
# which does a default mapping from app code or (if exists) from external JSON 'SHARED_MODEL_APP_MAP_FILE' file
|
|
|
|
# File logging is disabled; uncomment to enable it.
#logging.basicConfig(filename='better-ai-launcher.log', level=logging.INFO) # CRITICAL, ERROR, WARNING, INFO, DEBUG

# The Flask application and its WebSocket extension.
app = Flask(__name__)
sock = Sock(app)

# Pod identifier provided by RunPod; 'localhost' when the variable is not set.
RUNPOD_POD_ID = os.getenv('RUNPOD_POD_ID', 'localhost')

# app_name -> bookkeeping dict (the routes below read its 'pid', 'status' and 'log' entries).
running_processes = dict()

# NOTE: this rebinds the `app_configs` name imported from utils.app_configs.
app_configs = get_app_configs()

# Files persisted on the /workspace volume.
SETTINGS_FILE = '/workspace/.app_settings.json'
CIVITAI_TOKEN_FILE = '/workspace/.civitai_token'
HF_TOKEN_FILE = '/workspace/.hf_token' # lutzapps - added support for HF_TOKEN_FILE
|
|
|
|
|
|
def load_settings():
    """Load the launcher settings from SETTINGS_FILE.

    Returns:
        The parsed JSON content of SETTINGS_FILE. If the file does not exist,
        cannot be read, or does not contain valid JSON, the defaults
        ``{'auto_generate_ssh_password': False}`` are returned. This keeps a
        corrupt settings file from breaking the pages that call this.
    """
    defaults = {'auto_generate_ssh_password': False}
    try:
        with open(SETTINGS_FILE, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        return defaults
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses
        print(f"Error loading settings from {SETTINGS_FILE}: {str(e)}")
        return defaults
|
|
|
|
def save_settings(settings):
    """Write *settings* to SETTINGS_FILE as JSON, replacing the previous content.

    The JSON text is produced before the file is opened. A value that cannot
    be serialized therefore raises TypeError and leaves the existing settings
    file untouched, instead of leaving a truncated file behind.
    """
    payload = json.dumps(settings)
    with open(SETTINGS_FILE, 'w') as f:
        f.write(payload)
|
|
|
|
def check_running_processes():
    """Poll every tracked app forever and broadcast its status.

    Every 5 seconds, for each app in ``running_processes``:
    update_process_status() refreshes its entry, get_app_status() reads the
    result, and a 'status_update' WebSocket message is sent.

    An exception for one app is printed and skipped. It therefore neither
    blocks the updates of the other apps nor ends this loop, which would
    otherwise silently stop all status monitoring.
    """
    while True:
        # iterate over a snapshot because the dict can change while we poll
        for app_name in list(running_processes):
            try:
                update_process_status(app_name, running_processes)
                current_status = get_app_status(app_name, running_processes)
                send_websocket_message('status_update', {app_name: current_status})
            except Exception as e:
                print(f"Error updating status for '{app_name}': {str(e)}")
        time.sleep(5)
|
|
|
|
@app.route('/')
def index():
    """Render the launcher dashboard (index.html).

    Gathers the following and passes them to the template:
    - the saved settings,
    - the current SSH authentication method and password state,
    - the File Browser status,
    - the initial per-app status.
    """
    settings = load_settings()

    # Determine the current SSH authentication method. A missing or unreadable
    # SSH config is treated as empty (-> 'password') so the page still renders.
    try:
        with open(SSH_CONFIG_FILE, 'r') as f:
            ssh_config = f.read()
    except OSError as e:
        print(f"Could not read SSH config '{SSH_CONFIG_FILE}': {str(e)}")
        ssh_config = ''
    current_auth_method = 'key' if 'PasswordAuthentication no' in ssh_config else 'password'

    # Get the current SSH password if it exists
    ssh_password = get_ssh_password()
    ssh_password_status = 'set' if ssh_password else 'not_set'

    filebrowser_status = get_filebrowser_status()
    app_status = init_app_status(running_processes)

    return render_template(
        'index.html',
        apps=app_configs,
        app_status=app_status,
        pod_id=RUNPOD_POD_ID,
        RUNPOD_PUBLIC_IP=os.environ.get('RUNPOD_PUBLIC_IP'),
        RUNPOD_TCP_PORT_22=os.environ.get('RUNPOD_TCP_PORT_22'),

        # lutzapps - allow localhost Url for unsecure "http" and "ws" WebSockets protocol,
        # according to 'LOCAL_DEBUG' ENV var
        enable_unsecure_localhost=os.environ.get('LOCAL_DEBUG'),
        app_configs_manifest_url=APP_CONFIGS_MANIFEST_URL,

        settings=settings,
        current_auth_method=current_auth_method,
        ssh_password=ssh_password,
        ssh_password_status=ssh_password_status,
        filebrowser_status=filebrowser_status)
|
|
|
|
@app.route('/start/<app_name>')
def start_app(app_name):
    """Start *app_name* in a background thread.

    Responds with:
    - 'error' if the app is unknown or check_app_directories() fails,
    - 'already_running' if the app is not in the 'stopped' state,
    - 'started' once the launch thread has been spawned.

    For bkohya, Tensorboard is launched as well when
    DEBUG_SETTINGS['bkohya_run_tensorboard'] is set.
    """
    # Unknown apps get an explicit error instead of a misleading 'already_running'.
    if app_name not in app_configs:
        return jsonify({'status': 'error', 'message': f"App '{app_name}' not found"})

    dirs_ok, message = check_app_directories(app_name, app_configs)
    if not dirs_ok:
        return jsonify({'status': 'error', 'message': message})

    if get_app_status(app_name, running_processes) != 'stopped':
        return jsonify({'status': 'already_running'})

    # Update webui-user.sh for Forge and A1111
    if app_name in ['bforge', 'ba1111']:
        update_webui_user_sh(app_name, app_configs)

    app_config = app_configs[app_name]
    command = app_config['command']

    # bkohya enhancements
    # NOTE: the kohya_gui.py '--noverify' flag is currently not supported anymore,
    # so DEBUG_SETTINGS['bkohya_noverify'] is intentionally not applied here.
    if app_name == 'bkohya' and DEBUG_SETTINGS['bkohya_run_tensorboard']:  # default == True
        # auto-launch tensorboard together with the bkohya app; the launch command
        # is read from the app_config's 'run-tensorboard' cmd_key
        cmd_key = 'run-tensorboard'
        threading.Thread(target=run_bash_cmd, args=(app_config, app_config['app_path'], cmd_key)).start()

    threading.Thread(target=run_app, args=(app_name, command, running_processes)).start()
    return jsonify({'status': 'started'})
|
|
|
|
@app.route('/stop/<app_name>', methods=['GET'])
def stop_app(app_name):
    """Stop a running app by signalling its whole process group.

    SIGTERM is sent first. The app's pid is then polled once per second (at
    most 10 polls), and SIGKILL follows if the process is still alive.

    Responds with 'stopped', 'already_stopped' (the process disappeared before
    it could be signalled), or 'not_running'.
    """
    if app_name not in running_processes or get_app_status(app_name, running_processes) != 'running':
        return jsonify({'status': 'not_running'})

    def tracked_pid():
        # looked up afresh on every call
        return running_processes[app_name]['pid']

    try:
        group = os.getpgid(tracked_pid())
        os.killpg(group, signal.SIGTERM)

        # grace period for a clean shutdown
        for _poll in range(10):
            if not is_process_running(tracked_pid()):
                break
            time.sleep(1)

        if is_process_running(tracked_pid()):
            os.killpg(group, signal.SIGKILL)

        running_processes[app_name]['status'] = 'stopped'
        return jsonify({'status': 'stopped'})
    except ProcessLookupError:
        running_processes[app_name]['status'] = 'stopped'
        return jsonify({'status': 'already_stopped'})
|
|
|
|
@app.route('/status')
def get_status():
    """Return a JSON object mapping every configured app name to its current status."""
    statuses = {}
    for name in app_configs:
        statuses[name] = get_app_status(name, running_processes)
    return jsonify(statuses)
|
|
|
|
@app.route('/logs/<app_name>')
def get_logs(app_name):
    """Return the last 100 log entries of a tracked app, or an empty list if untracked."""
    if app_name not in running_processes:
        return jsonify({'logs': []})
    recent_entries = running_processes[app_name]['log'][-100:]
    return jsonify({'logs': recent_entries})
|
|
|
|
# lutzapps - support bkohya gradio url
@app.route('/get_bkohya_launch_url', methods=['GET'])
def get_bkohya_launch_url_route():
    """Report how the bkohya UI is reached and its current launch URL.

    'mode' is 'gradio' when the bkohya command contains '--share'
    (case-insensitive, gradio share mode), otherwise 'local'.
    'url' comes from get_bkohya_launch_url().
    Used by the index.html OpenApp() button click function.
    """
    bkohya_command = app_configs['bkohya']['command']
    mode = 'gradio' if '--share' in bkohya_command.lower() else 'local'
    # taken from the app_utils global BKOHYA_GRADIO_URL, which is polled from the kohya log
    launch_url = get_bkohya_launch_url()
    return jsonify({'mode': mode, 'url': launch_url})
|
|
|
|
@app.route('/kill_all', methods=['POST'])
def kill_all():
    """Stop every configured app whose status is currently 'running'.

    Returns {'status': 'success'}. If anything raises, returns
    {'status': 'error', 'message': <exception text>} instead.
    """
    try:
        for app_key in app_configs:
            if get_app_status(app_key, running_processes) != 'running':
                continue
            stop_app(app_key)
        return jsonify({'status': 'success'})
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})
|
|
|
|
@app.route('/force_kill/<app_name>', methods=['GET'])
def force_kill_app(app_name):
    """Forcefully kill an app's processes by name; responds 'killed' or 'error'."""
    try:
        success, message = force_kill_process_by_name(app_name, app_configs)
        outcome = 'killed' if success else 'error'
        return jsonify({'status': outcome, 'message': message})
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})
|
|
|
|
@app.route('/force_kill_by_port/<port>', methods=['GET'])
def force_kill_by_port_route(port:int):
    """Kill the process bound to *port*.

    The URL segment arrives as a string. It is validated and converted to an
    int (matching this parameter's annotation) before being passed to
    find_and_kill_process_by_port().

    Responds with 'killed' or 'error' JSON, each with a descriptive message.
    """
    try:
        port_number = int(port)
    except (TypeError, ValueError):
        return jsonify({'status': 'error', 'message': f"Invalid port '{port}'"})
    if not 0 < port_number < 65536:
        return jsonify({'status': 'error', 'message': f"Port {port_number} is out of range (1-65535)"})

    try:
        success = find_and_kill_process_by_port(port_number)
        if success:
            return jsonify({'status': 'killed', 'message': f"Killed process on port {port_number}"})
        return jsonify({'status': 'error', 'message': f"Could not kill a process on port {port_number}"})
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})
|
|
|
|
|
|
# lutzapps - added delete app feature
@app.route('/delete_app/<app_name>', methods=['GET'])
def delete_app_installation_route(app_name:str):
    """Delete an app's installation and stream the progress over WebSocket.

    Each successfully sent progress message is also printed.
    Responds with 'deleted' or 'error' JSON; any exception is reported as
    'error' with its message.
    """
    def progress_callback(message_type:str, message_data:str):
        try:
            send_websocket_message(message_type, message_data)
            print(message_data) # additionally print to output
        except Exception as e:
            # best effort: continue even if the websocket fails
            print(f"Error sending progress update: {str(e)}")

    try:
        success, message = delete_app_installation(app_name, app_configs, progress_callback)
        outcome = 'deleted' if success else 'error'
        return jsonify({'status': outcome, 'message': message})
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})
|
|
|
|
@app.route('/check_installation/<app_name>', methods=['GET'])
def check_app_installation_route(app_name:str):
    """Check an app's installation and stream the progress over WebSocket.

    Each successfully sent progress message is also printed.
    Responds with 'checked' or 'error' JSON; any exception is reported as
    'error' with its message.
    """
    def progress_callback(message_type, message_data):
        try:
            send_websocket_message(message_type, message_data)
            print(message_data) # additionally print to output
        except Exception as e:
            # best effort: continue even if the websocket fails
            print(f"Error sending progress update: {str(e)}")

    try:
        success, message = check_app_installation(app_name, app_configs, progress_callback)
        outcome = 'checked' if success else 'error'
        return jsonify({'status': outcome, 'message': message})
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})
|
|
|
|
# lutzapps - added refresh app feature
@app.route('/refresh_installation/<app_name>', methods=['GET'])
def refresh_app_installation_route(app_name:str):
    """Refresh an app's installation and stream the progress over WebSocket.

    Each successfully sent progress message is also printed.
    Responds with 'refreshed' or 'error' JSON; any exception is reported as
    'error' with its message.
    """
    def progress_callback(message_type, message_data):
        try:
            send_websocket_message(message_type, message_data)
            print(message_data) # additionally print to output
        except Exception as e:
            # best effort: continue even if the websocket fails
            print(f"Error sending progress update: {str(e)}")

    try:
        success, message = refresh_app_installation(app_name, app_configs, progress_callback)
        outcome = 'refreshed' if success else 'error'
        return jsonify({'status': outcome, 'message': message})
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)})
|
|
|
|
|
|
from gevent.lock import RLock

# Re-entrant gevent lock guarding `active_websockets`. It is taken by the /ws
# handler when registering/unregistering a socket, and by send_heartbeat()
# while it iterates over the connected sockets.
websocket_lock = RLock()
|
|
|
|
@sock.route('/ws')
def websocket(ws):
    """Serve one WebSocket client until it disconnects.

    For its whole lifetime the socket is registered in `active_websockets`,
    so that send_heartbeat() can reach it.

    Incoming JSON messages whose 'type' is 'heartbeat' are answered with a
    heartbeat; other message types are ignored. An error while handling a
    message is printed, and the loop continues unless the socket has
    disconnected.
    """
    with websocket_lock:
        active_websockets.add(ws)

    try:
        while ws.connected:
            try:
                # timeout slightly higher than the 60 second heartbeat interval
                raw = ws.receive(timeout=70)
                if not raw:
                    continue
                payload = json.loads(raw)
                if payload['type'] == 'heartbeat':
                    ws.send(json.dumps({'type': 'heartbeat'}))
                # other message types are currently not handled
            except Exception as e:
                if "timed out" in str(e).lower():
                    continue  # a receive timeout is not an error
                print(f"Error handling websocket message: {str(e)}")
                if not ws.connected:
                    break
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
    finally:
        with websocket_lock:
            active_websockets.discard(ws)
|
|
|
|
def send_heartbeat():
    """Send a heartbeat message to every connected WebSocket every 60 seconds, forever.

    A failure for a single socket, or for a whole round, is printed and the
    loop carries on.
    """
    beat = json.dumps({'type': 'heartbeat', 'data': {}})
    while True:
        try:
            time.sleep(60) # Fixed 60 second interval
            with websocket_lock:
                # iterate over a snapshot; the set may change while we send
                for client in list(active_websockets):
                    try:
                        if client.connected:
                            client.send(beat)
                    except Exception as e:
                        print(f"Error sending heartbeat: {str(e)}")
        except Exception as e:
            print(f"Error in heartbeat thread: {str(e)}")


# Start the heartbeat thread at import time; as a daemon it does not keep the process alive.
threading.Thread(target=send_heartbeat, daemon=True).start()
|
|
|
|
|
|
@app.route('/available_venvs/<app_name>', methods=['GET'])
def available_venvs_route(app_name):
    """List the virtual environments available for *app_name*.

    Success:
        {'status': 'success', 'available_venvs': <venvs>}.
    Failure:
        {'status': 'error', 'error': <reason>, 'message': <reason>}.
        'message' matches the error shape of the other routes; 'error' is kept
        for existing clients.
    Unexpected exceptions are logged and answered with HTTP 500 (the message
    includes the traceback).
    """
    try:
        success, venvs = get_available_venvs(app_name)
        if success:
            return jsonify({'status': 'success', 'available_venvs': venvs})
        return jsonify({'status': 'error', 'error': venvs, 'message': venvs})

    except Exception as e:
        error_message = f"Error for {app_name}: {str(e)}\n{traceback.format_exc()}"
        app.logger.error(error_message)
        return jsonify({'status': 'error', 'message': error_message}), 500
|
|
|
|
# lutzapps - added venv_version
@app.route('/install/<app_name>/<venv_version>', methods=['GET'])
def install_app_route(app_name, venv_version):
    """Install *app_name* using the *venv_version* environment.

    Progress messages are forwarded over WebSocket on a best-effort basis.
    Responds with 'success' or 'error' JSON. An unexpected exception is logged
    and answered with HTTP 500; its message includes the traceback.
    """
    def progress_callback(message_type, message_data):
        try:
            send_websocket_message(message_type, message_data)
        except Exception as e:
            # keep installing even if the websocket update fails
            print(f"Error sending progress update: {str(e)}")

    try:
        success, message = install_app(app_name, venv_version, progress_callback)
        outcome = 'success' if success else 'error'
        return jsonify({'status': outcome, 'message': message})
    except Exception as e:
        error_message = f"Installation error for {app_name}: {str(e)}\n{traceback.format_exc()}"
        app.logger.error(error_message)
        return jsonify({'status': 'error', 'message': error_message}), 500
|
|
|
|
@app.route('/fix_custom_nodes/<app_name>', methods=['GET'])
def fix_custom_nodes_route(app_name):
    """Run fix_custom_nodes() for *app_name*; responds 'success' or 'error' JSON."""
    success, message = fix_custom_nodes(app_name, app_configs)
    outcome = 'success' if success else 'error'
    return jsonify({'status': outcome, 'message': message})
|
|
|
|
@app.route('/set_ssh_password', methods=['POST'])
def set_ssh_password():
    """Set the root password and enable password-based root login over SSH.

    Expects a JSON body {'password': '...'}. The handler then:
    1. applies the password with chpasswd,
    2. stores it via save_ssh_password(),
    3. edits /etc/ssh/sshd_config to allow root login with a password,
    4. restarts the ssh service.

    Responds with 'success' or 'error' JSON; errors include the traceback.
    """
    try:
        data = request.get_json(silent=True) or {}
        new_password = data.get('password')

        if not new_password:
            return jsonify({'status': 'error', 'message': 'No password provided'})

        if not isinstance(new_password, str):
            return jsonify({'status': 'error', 'message': 'Password must be a string'})

        # chpasswd reads one "user:password" entry per line. A line break inside
        # the password would let the caller inject entries for other users.
        if '\n' in new_password or '\r' in new_password:
            return jsonify({'status': 'error', 'message': 'Password must not contain line breaks'})

        print("Attempting to set new password...")

        # Use chpasswd to set the password
        result = subprocess.run(['chpasswd'], input=f"root:{new_password}\n",
                                capture_output=True, text=True)
        if result.returncode != 0:
            raise Exception(f"Failed to set password: {result.stderr}")

        # Save the new password
        save_ssh_password(new_password)

        # Configure SSH to allow root login with password. Both the commented
        # defaults and active settings (e.g. "PasswordAuthentication no") are
        # rewritten. NOTE: nothing is appended if the directive is absent.
        print("Configuring SSH to allow root login with a password...")
        subprocess.run(["sed", "-i", "-E", r"s/^#?PermitRootLogin[[:space:]].*/PermitRootLogin yes/", "/etc/ssh/sshd_config"], check=True)
        subprocess.run(["sed", "-i", "-E", r"s/^#?PasswordAuthentication[[:space:]].*/PasswordAuthentication yes/", "/etc/ssh/sshd_config"], check=True)

        # Restart SSH service to apply changes
        print("Restarting SSH service...")
        subprocess.run(['service', 'ssh', 'restart'], check=True)
        print("SSH service restarted successfully.")

        print("SSH Configuration Updated and Password Set.")

        return jsonify({'status': 'success', 'message': 'SSH password set successfully. Note: Key-based authentication is more secure.'})
    except Exception as e:
        error_message = f"Error in set_ssh_password: {str(e)}\n{traceback.format_exc()}"
        print(error_message)
        return jsonify({'status': 'error', 'message': error_message})
|
|
|
|
@app.route('/start_filebrowser')
def start_filebrowser_route():
    """Start File Browser.

    Responds 'started' when start_filebrowser() returns a truthy value,
    otherwise 'already_running'.
    """
    status = 'started' if start_filebrowser() else 'already_running'
    return jsonify({'status': status})
|
|
|
|
@app.route('/stop_filebrowser')
def stop_filebrowser_route():
    """Stop File Browser.

    Responds 'stopped' when stop_filebrowser() returns a truthy value,
    otherwise 'already_stopped'.
    """
    status = 'stopped' if stop_filebrowser() else 'already_stopped'
    return jsonify({'status': status})
|
|
|
|
@app.route('/filebrowser_status')
def filebrowser_status_route():
    """Return File Browser's status ('unknown' when empty); HTTP 500 JSON on failure."""
    try:
        status = get_filebrowser_status()
        return jsonify({'status': status or 'unknown'})
    except Exception as e:
        app.logger.error(f"Error getting filebrowser status: {str(e)}")
        return jsonify({'status': 'error', 'message': str(e)}), 500
|
|
|
|
@app.route('/add_app_config', methods=['POST'])
def add_new_app_config():
    """Register a new app config from a JSON body {'app_name': ..., 'config': ...}.

    A missing, non-JSON, or non-object body, or one lacking either field, is
    answered with the 'Invalid data provided' error instead of crashing.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'status': 'error', 'message': 'Invalid data provided'})

    app_name = data.get('app_name')
    config = data.get('config')
    if app_name and config:
        add_app_config(app_name, config)
        return jsonify({'status': 'success', 'message': f'App {app_name} added successfully'})
    return jsonify({'status': 'error', 'message': 'Invalid data provided'})
|
|
|
|
@app.route('/remove_app_config/<app_name>', methods=['POST'])
def remove_existing_app_config(app_name):
    """Remove the config of *app_name*; error JSON if the app is not configured."""
    if app_name not in app_configs:
        return jsonify({'status': 'error', 'message': f'App {app_name} not found'})
    remove_app_config(app_name)
    return jsonify({'status': 'success', 'message': f'App {app_name} removed successfully'})
|
|
|
|
# modified function
def setup_shared_models():
    """Refresh the shared-model symlinks and return SHARED_MODELS_DIR.

    lutzapps - CHANGE #4: uses the new "shared_models" module for app model
    sharing. That module now owns and manages the shared models directory.
    The result of update_model_symlinks() is not needed here.
    """
    update_model_symlinks()
    return SHARED_MODELS_DIR
|
|
|
|
def update_symlinks_periodically():
    """Refresh the shared-model symlinks every 5 minutes, forever.

    Runs in the daemon thread started by start_symlink_update_thread(). A
    failing refresh is printed and retried on the next cycle, instead of
    ending the thread for the rest of the process lifetime.
    """
    while True:
        try:
            update_model_symlinks()
        except Exception as e:
            print(f"Error updating model symlinks: {str(e)}")
        time.sleep(300) # Check every 5 minutes
|
|
|
|
def start_symlink_update_thread():
    """Run update_symlinks_periodically() in a background daemon thread."""
    threading.Thread(target=update_symlinks_periodically, daemon=True).start()
|
|
|
|
# modified function
@app.route('/recreate_symlinks', methods=['GET'])
def recreate_symlinks_route():
    """Recreate the shared-model symlinks and return update_model_symlinks()'s result.

    lutzapps - delegates to the new "shared_models" module. Its
    update_model_symlinks() does all the link management, including deleting
    stale symlinks.
    """
    return update_model_symlinks()
|
|
|
|
# modified function
@app.route('/create_shared_folders', methods=['GET'])
def create_shared_folders():
    """Create the shared model folders and return ensure_shared_models_folders()'s result.

    lutzapps - delegates to the new "shared_models" module.
    """
    return ensure_shared_models_folders()
|
|
|
|
def save_civitai_token(token):
    """Store *token* in CIVITAI_TOKEN_FILE as the JSON object {"token": token}."""
    record = {'token': token}
    with open(CIVITAI_TOKEN_FILE, 'w') as token_file:
        json.dump(record, token_file)
|
|
|
|
# lutzapps - added function - 'HF_TOKEN' ENV var
def load_huggingface_token() -> "str | None":
    """Return the Hugging Face token, or None if none is configured.

    Lookup order:
    1. the 'HF_TOKEN' environment variable, if set and non-empty;
    2. the 'token' entry of the JSON file HF_TOKEN_FILE.

    None is returned when neither source provides a token, or when the token
    file cannot be read, is not valid JSON, or is not a JSON object.
    """
    hf_token = os.environ.get('HF_TOKEN', '')
    if hf_token != "":
        print("'HF_TOKEN' ENV var found")
        # NOTE: pushing the token into the WebUI "Models Downloader" 'hfToken'
        # password field via send_websocket_message('extend_ui_helper', ...) is
        # currently disabled.
        return hf_token

    # only if the 'HF_TOKEN' ENV var was not found, fall back to the local hidden HF_TOKEN_FILE
    if not os.path.exists(HF_TOKEN_FILE):
        return None
    try:
        with open(HF_TOKEN_FILE, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error reading Hugging Face token file '{HF_TOKEN_FILE}': {str(e)}")
        return None

    # valid JSON that is not an object has no 'token' entry
    return data.get('token') if isinstance(data, dict) else None
|
|
|
|
# lutzapps - modified function - support 'CIVITAI_API_TOKEN' ENV var
def load_civitai_token() -> "str | None":
    """Return the CivitAI API token, or None if none is configured.

    Lookup order:
    1. the 'CIVITAI_API_TOKEN' environment variable, if set and non-empty;
    2. the 'token' entry of the JSON file CIVITAI_TOKEN_FILE (the format
       written by save_civitai_token()).

    None is returned when neither source provides a token, or when the token
    file cannot be read, is not valid JSON, or is not a JSON object.
    """
    civitai_api_token = os.environ.get('CIVITAI_API_TOKEN', '')
    if civitai_api_token != "":
        print("'CIVITAI_API_TOKEN' ENV var found")
        # NOTE: pushing the token into the WebUI "Models Downloader" 'civitaiToken'
        # password field via send_websocket_message('extend_ui_helper', ...) is
        # currently disabled.
        return civitai_api_token

    # only if the 'CIVITAI_API_TOKEN' ENV var is not found, fall back to the local hidden CIVITAI_TOKEN_FILE
    if not os.path.exists(CIVITAI_TOKEN_FILE):
        return None
    try:
        with open(CIVITAI_TOKEN_FILE, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error reading CivitAI token file '{CIVITAI_TOKEN_FILE}': {str(e)}")
        return None

    # valid JSON that is not an object has no 'token' entry
    return data.get('token') if isinstance(data, dict) else None
|
|
|
|
@app.route('/save_civitai_token', methods=['POST'])
def save_civitai_token_route():
    """Persist the CivitAI token sent as JSON {'token': ...}.

    Responds with HTTP 400 when the token is missing. This includes a missing,
    non-JSON, or non-object request body, which previously raised.
    """
    data = request.get_json(silent=True)
    token = data.get('token') if isinstance(data, dict) else None
    if token:
        save_civitai_token(token)
        return jsonify({'status': 'success', 'message': 'Civitai token saved successfully.'})
    return jsonify({'status': 'error', 'message': 'No token provided.'}), 400
|
|
|
|
@app.route('/get_civitai_token', methods=['GET'])
def get_civitai_token_route():
    """Return the configured CivitAI token as JSON, or null if none is configured.

    Resolution order: 'CIVITAI_API_TOKEN' ENV var first, then the token file.
    """
    return jsonify({'token': load_civitai_token()})
|
|
|
|
# lutzapps - add support for passed in "HF_TOKEN" ENV var
@app.route('/get_huggingface_token', methods=['GET'])
def get_hugginface_token_route():
    """Return the Hugging Face token as JSON, or null if none is configured.

    Resolution order: 'HF_TOKEN' ENV var first, then the token file.
    """
    return jsonify({'token': load_huggingface_token()})
|
|
|
|
# lutzapps - CHANGE #9 - return model_types to populate the Download manager Select Option
# new function to support the "Model Downloader" with the 'SHARED_MODEL_FOLDERS' dictionary
@app.route('/get_model_types', methods=['GET'])
def get_model_types_route():
    """Return the model types for the Download Manager's select options.

    The result maps a running index (0, 1, ...) to
    {'modelfolder': <model type>, 'desc': <description>}, in the iteration
    order of SHARED_MODEL_FOLDERS.
    """
    # SHARED_MODELS_DIR only exists under /workspace AFTER the user clicked the
    # "Create Shared Folders" button on the Settings tab of the WebUI. Until then,
    # return an empty dict so the Download Manager does not get the in-memory,
    # code-generated default SHARED_MODEL_FOLDERS.
    if not os.path.exists(SHARED_MODELS_DIR):
        return {}

    return {
        position: {'modelfolder': model_type, 'desc': description}
        for position, (model_type, description) in enumerate(SHARED_MODEL_FOLDERS.items())
    }
|
|
|
|
@app.route('/download_model', methods=['POST'])
def download_model_route():
    """Download a CivitAI or Hugging Face model.

    Expects a JSON body with 'url' and, optionally, 'model_name',
    'model_type', 'civitai_token', 'hf_token', 'version_id' and 'file_index'.
    Tokens missing from the request fall back to load_civitai_token() and
    load_huggingface_token() (ENV var first, then token file).

    The model downloader first calls this with only the url. A CivitAI
    version id already contained in that url is honoured, so the user does
    not end up in the model-picker dialog just to select the same version
    again. The model-picker later calls back here with an explicit
    'version_id'.

    Responses:
    - 'success';
    - 'choice_required', with the choices in 'data';
    - 'error', with HTTP 400 for a bad request or failed download, or HTTP 500
      for an unexpected exception during the download.
    """
    try:
        data = request.get_json(silent=True) or {}
        url = data.get('url')
        if not url:
            return jsonify({'status': 'error', 'message': 'No URL provided.'}), 400

        model_name = data.get('model_name')
        model_type = data.get('model_type')
        civitai_token = data.get('civitai_token') or load_civitai_token() # If no token provided in request, try to read from ENV and last from file
        hf_token = data.get('hf_token') or load_huggingface_token() # If no token provided in request, try to read from ENV and last from file
        version_id = data.get('version_id')
        file_index = data.get('file_index')

        is_civitai, _, _, url_version_id = check_civitai_url(url)
        if version_id is None: # model-picker dialog not used already
            version_id = url_version_id # get a possible version_id from the copy-pasted url

        is_huggingface, _, _, _, _ = check_huggingface_url(url)

        # only CivitAI or Huggingface model downloads are supported for now
        if not (is_civitai or is_huggingface):
            return jsonify({'status': 'error', 'message': 'Unsupported URL. Please use Civitai or Hugging Face URLs.'}), 400

        # CivitAI downloads require an API token (e.g. for model variant downloads and private models)
        if is_civitai and not civitai_token:
            return jsonify({'status': 'error', 'message': 'Civitai token is required for downloading from Civitai.'}), 400

        try:
            success, message = download_model(url, model_name, model_type, civitai_token, hf_token, version_id, file_index)
            if success:
                if isinstance(message, dict) and 'choice_required' in message:
                    return jsonify({'status': 'choice_required', 'data': message['choice_required']})
                return jsonify({'status': 'success', 'message': message})
            return jsonify({'status': 'error', 'message': message}), 400
        except Exception as e:
            error_message = f"Model download error: {str(e)}\n{traceback.format_exc()}"
            app.logger.error(error_message)
            return jsonify({'status': 'error', 'message': error_message}), 500

    except Exception as e:
        error_message = f"Error processing request: {str(e)}\n{traceback.format_exc()}"
        app.logger.error(error_message)
        return jsonify({'status': 'error', 'message': error_message}), 400
|
|
|
|
@app.route('/get_model_folders')
def get_model_folders():
    """Report the total size and file count of each existing shared model folder.

    Returns JSON {folder: {'size': <format_size(total bytes)>, 'file_count': int}}
    for every SHARED_MODEL_FOLDERS entry whose directory exists under
    SHARED_MODELS_DIR.

    Files whose size cannot be read are skipped (neither sized nor counted)
    instead of failing the whole request. This covers broken symlinks and
    files that vanish during the walk.
    """
    folders = {}

    # lutzapps - replace the hard-coded model types
    for folder in SHARED_MODEL_FOLDERS:
        folder_path = os.path.join(SHARED_MODELS_DIR, folder)
        if not os.path.exists(folder_path):
            continue

        total_size = 0
        file_count = 0
        for dirpath, _dirnames, filenames in os.walk(folder_path):
            for filename in filenames:
                try:
                    # getsize follows symlinks, so a dangling link raises here
                    total_size += os.path.getsize(os.path.join(dirpath, filename))
                except OSError:
                    continue
                file_count += 1

        folders[folder] = {
            'size': format_size(total_size),
            'file_count': file_count
        }

    return jsonify(folders)
|
|
|
|
@app.route('/update_symlinks', methods=['POST'])
def update_symlinks_route():
    """Refresh the shared-model symlinks on demand; HTTP 500 JSON on failure."""
    try:
        update_model_symlinks()
        outcome = {'status': 'success', 'message': 'Symlinks updated successfully'}
        return jsonify(outcome)
    except Exception as e:
        return jsonify({'status': 'error', 'message': str(e)}), 500
|
|
|
|
if __name__ == '__main__':
    shared_models_path = setup_shared_models()
    print(f"Shared models directory: {shared_models_path}")

    if setup_ssh():
        print("SSH setup completed successfully.")
    else:
        print("Failed to set up SSH. Please check the logs.")

    print("Configuring File Browser...")
    if configure_filebrowser():
        print("File Browser configuration completed successfully.")
        print("Attempting to start File Browser...")
        if start_filebrowser():
            print("File Browser started successfully.")
        else:
            print("Failed to start File Browser. Please check the logs.")
    else:
        print("Failed to configure File Browser. Please check the logs.")

    threading.Thread(target=check_running_processes, daemon=True).start()

    # Start the thread to periodically update model symlinks
    start_symlink_update_thread()

    # debug=True also turns on Werkzeug's reloader by default. The reloader
    # re-executes this module in a child process, which would run the SSH /
    # File Browser setup and start the background threads a second time.
    # Keep debug mode, but disable the reloader.
    app.run(debug=True, host='0.0.0.0', port=7223, use_reloader=False)