madiator-docker-runpod/official-templates/better-ai-launcher/app.py
2024-10-21 11:03:33 +02:00

574 lines
No EOL
23 KiB
Python

from gevent import monkey
monkey.patch_all()
import os
import threading
import time
from flask import Flask, render_template, jsonify, request
from flask_sock import Sock
import json
import signal
import shutil
import subprocess
import traceback
from utils.ssh_utils import setup_ssh, save_ssh_password, get_ssh_password, check_ssh_config, SSH_CONFIG_FILE
from utils.filebrowser_utils import configure_filebrowser, start_filebrowser, stop_filebrowser, get_filebrowser_status, FILEBROWSER_PORT
from utils.app_utils import (
run_app, update_process_status, check_app_directories, get_app_status,
force_kill_process_by_name, update_webui_user_sh, save_install_status,
get_install_status, download_and_unpack_venv, fix_custom_nodes, is_process_running, install_app, update_model_symlinks
)
from utils.websocket_utils import send_websocket_message, active_websockets
from utils.app_configs import get_app_configs, add_app_config, remove_app_config, app_configs
from utils.model_utils import download_model, check_civitai_url, check_huggingface_url, SHARED_MODELS_DIR, format_size
# Flask application object and the flask_sock wrapper used for the /ws endpoint.
app = Flask(__name__)
sock = Sock(app)
# Pod identifier read from the environment; 'localhost' when the variable is unset.
RUNPOD_POD_ID = os.environ.get('RUNPOD_POD_ID', 'localhost')
# Per-app process bookkeeping keyed by app name. The routes in this file read
# the 'pid', 'status' and 'log' keys of each entry.
running_processes = {}
# NOTE(review): this rebinds the `app_configs` name that is also imported from
# utils.app_configs above -- presumably both refer to the same mapping; confirm.
app_configs = get_app_configs()
S3_BASE_URL = "https://better.s3.madiator.com/"
# On-disk locations of the launcher settings and the stored Civitai API token.
SETTINGS_FILE = '/workspace/.app_settings.json'
CIVITAI_TOKEN_FILE = '/workspace/.civitai_token'
def load_settings(path=None):
    """Load the launcher settings from disk.

    Args:
        path: Optional settings file to read. Defaults to ``SETTINGS_FILE``.

    Returns:
        The parsed settings dict, or ``{'auto_generate_ssh_password': False}``
        when the file is missing, unreadable, not valid JSON, or does not
        hold a JSON object.
    """
    defaults = {'auto_generate_ssh_password': False}
    settings_path = SETTINGS_FILE if path is None else path
    try:
        with open(settings_path, 'r') as f:
            loaded = json.load(f)
    except (OSError, ValueError):
        # Missing file (first run) or a truncated/corrupt write: fall back to
        # the defaults instead of failing every caller (e.g. the index page).
        return defaults
    return loaded if isinstance(loaded, dict) else defaults
def save_settings(settings):
    """Write ``settings`` as JSON to ``SETTINGS_FILE``, replacing its contents."""
    with open(SETTINGS_FILE, mode='w') as settings_file:
        json.dump(settings, fp=settings_file)
def check_running_processes():
    """Poll every tracked app forever and broadcast its status.

    Each cycle refreshes every entry of ``running_processes``, sends one
    ``status_update`` WebSocket message per app, then sleeps for 5 seconds.
    Never returns; the ``__main__`` block starts it in a daemon thread.
    """
    poll_interval = 5
    while True:
        # Iterate over a snapshot of the keys so entries added or removed
        # elsewhere during the pass cannot break the iteration.
        tracked = list(running_processes.keys())
        for name in tracked:
            update_process_status(name, running_processes)
            status_now = get_app_status(name, running_processes)
            send_websocket_message('status_update', {name: status_now})
        time.sleep(poll_interval)
@app.route('/')
def index():
    """Render the launcher dashboard (``index.html``).

    Collects the saved settings, the SSH authentication mode and password
    state, the per-app directory/run/install status and the File Browser
    status, and hands them to the template.
    """
    settings = load_settings()

    # The SSH auth mode is inferred from the raw sshd config text: an explicit
    # 'PasswordAuthentication no' means key-based login.
    with open(SSH_CONFIG_FILE, 'r') as config_file:
        sshd_text = config_file.read()
    if 'PasswordAuthentication no' in sshd_text:
        current_auth_method = 'key'
    else:
        current_auth_method = 'password'

    ssh_password = get_ssh_password()
    ssh_password_status = 'set' if ssh_password else 'not_set'

    app_status = {}
    for key, cfg in app_configs.items():
        dirs_ok, message = check_app_directories(key, app_configs)
        run_state = get_app_status(key, running_processes)
        install_state = get_install_status(key)
        app_status[key] = dict(
            name=cfg['name'],
            dirs_ok=dirs_ok,
            message=message,
            port=cfg['port'],
            status=run_state,
            # 'installed' mirrors the directory check result.
            installed=dirs_ok,
            install_status=install_state,
            is_bcomfy=(key == 'bcomfy'),
        )

    filebrowser_status = get_filebrowser_status()

    context = {
        'apps': app_configs,
        'app_status': app_status,
        'pod_id': RUNPOD_POD_ID,
        'RUNPOD_PUBLIC_IP': os.environ.get('RUNPOD_PUBLIC_IP'),
        'RUNPOD_TCP_PORT_22': os.environ.get('RUNPOD_TCP_PORT_22'),
        'settings': settings,
        'current_auth_method': current_auth_method,
        'ssh_password': ssh_password,
        'ssh_password_status': ssh_password_status,
        'filebrowser_status': filebrowser_status,
    }
    return render_template('index.html', **context)
@app.route('/start/<app_name>')
def start_app(app_name):
    """Launch ``app_name`` in a background thread.

    Returns JSON with status ``error`` (directory check failed, with the
    check's message), ``started``, or ``already_running`` (unknown app or any
    state other than ``stopped``).
    """
    dirs_ok, message = check_app_directories(app_name, app_configs)
    if not dirs_ok:
        return jsonify({'status': 'error', 'message': message})

    is_known = app_name in app_configs
    if not is_known or get_app_status(app_name, running_processes) != 'stopped':
        return jsonify({'status': 'already_running'})

    # Forge and A1111 need webui-user.sh refreshed before every launch.
    if app_name in ('bforge', 'ba1111'):
        update_webui_user_sh(app_name, app_configs)

    launch_command = app_configs[app_name]['command']
    launcher = threading.Thread(
        target=run_app,
        args=(app_name, launch_command, running_processes),
    )
    launcher.start()
    return jsonify({'status': 'started'})
@app.route('/stop/<app_name>')
def stop_app(app_name):
    """Stop a running app by signalling its whole process group.

    Sends SIGTERM, waits up to 10 seconds (checking once per second), then
    sends SIGKILL if the process is still alive. Returns JSON with status
    ``stopped``, ``already_stopped`` (process vanished while stopping) or
    ``not_running``.
    """
    is_tracked = app_name in running_processes
    if not is_tracked or get_app_status(app_name, running_processes) != 'running':
        return jsonify({'status': 'not_running'})

    def tracked_pid():
        # Re-read the pid on every use, exactly as the bookkeeping dict holds it.
        return running_processes[app_name]['pid']

    try:
        group_id = os.getpgid(tracked_pid())
        os.killpg(group_id, signal.SIGTERM)

        # Grace period: at most 10 one-second waits while the process lives.
        waits_left = 10
        while waits_left > 0 and is_process_running(tracked_pid()):
            time.sleep(1)
            waits_left -= 1

        if is_process_running(tracked_pid()):
            os.killpg(group_id, signal.SIGKILL)

        running_processes[app_name]['status'] = 'stopped'
        return jsonify({'status': 'stopped'})
    except ProcessLookupError:
        # The process (group) disappeared between the status check and a signal.
        running_processes[app_name]['status'] = 'stopped'
        return jsonify({'status': 'already_stopped'})
@app.route('/status')
def get_status():
    """Return a JSON object mapping every configured app name to its status."""
    statuses = {}
    for name in app_configs:
        statuses[name] = get_app_status(name, running_processes)
    return jsonify(statuses)
@app.route('/logs/<app_name>')
def get_logs(app_name):
    """Return the last 100 log entries of ``app_name`` (empty list if untracked)."""
    if app_name not in running_processes:
        return jsonify({'logs': []})
    recent_lines = running_processes[app_name]['log'][-100:]
    return jsonify({'logs': recent_lines})
@app.route('/kill_all', methods=['POST'])
def kill_all():
    """Stop every configured app that currently reports ``running``.

    Returns JSON ``{'status': 'success'}``, or ``error`` with the exception
    text if anything raised along the way.
    """
    try:
        for app_key in app_configs:
            if get_app_status(app_key, running_processes) != 'running':
                continue
            stop_app(app_key)
        return jsonify({'status': 'success'})
    except Exception as err:
        return jsonify({'status': 'error', 'message': str(err)})
@app.route('/force_kill/<app_name>', methods=['POST'])
def force_kill_app(app_name):
    """Force-kill ``app_name`` via ``force_kill_process_by_name``.

    Returns JSON with status ``killed`` or ``error`` plus the helper's
    message; an exception is reported as ``error`` with its text.
    """
    try:
        killed, detail = force_kill_process_by_name(app_name, app_configs)
        outcome = 'killed' if killed else 'error'
        return jsonify({'status': outcome, 'message': detail})
    except Exception as err:
        return jsonify({'status': 'error', 'message': str(err)})
@sock.route('/ws')
def websocket(ws):
    """Serve one WebSocket client on ``/ws``.

    Registers the socket in ``active_websockets`` for the lifetime of the
    connection, echoes a ``heartbeat`` message for every heartbeat received,
    and ignores any other message. Any exception (including the client
    disconnecting) ends the loop and is printed.
    """
    active_websockets.add(ws)
    try:
        while True:
            message = ws.receive()
            data = json.loads(message)
            # Valid JSON that is not an object, or an object without 'type',
            # is ignored instead of raising and dropping the connection.
            message_type = data.get('type') if isinstance(data, dict) else None
            if message_type == 'heartbeat':
                ws.send(json.dumps({'type': 'heartbeat'}))
            # Other message types are currently not handled.
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
    finally:
        try:
            active_websockets.remove(ws)
        except KeyError:
            # Already removed elsewhere (e.g. after a failed broadcast);
            # nothing left to clean up.
            pass
def send_heartbeat():
    """Broadcast ``heartbeat`` WebSocket messages forever.

    The delay between messages starts at 5 seconds and grows by a factor of
    1.5 after each send (capped at 60 seconds) during the first minute; after
    that it is fixed at 60 seconds. Never returns.
    """
    ramp_window = 60   # seconds during which the delay is still growing
    ceiling = 60       # longest delay between heartbeats, in seconds
    delay = 5          # initial delay, in seconds
    began = time.time()
    while True:
        time.sleep(delay)
        send_websocket_message('heartbeat', {})
        still_ramping = (time.time() - began) < ramp_window
        delay = min(delay * 1.5, ceiling) if still_ramping else ceiling
# Start the heartbeat broadcaster in a daemon thread so it never blocks
# interpreter shutdown.
# NOTE(review): this statement is not under the `__main__` guard, so it appears
# to run whenever the module is imported -- confirm that is intended.
threading.Thread(target=send_heartbeat, daemon=True).start()
@app.route('/install/<app_name>', methods=['POST'])
def install_app_route(app_name):
    """Install ``app_name`` and report the result as JSON.

    Returns status ``success`` or ``error`` with the installer's message. An
    exception is logged and returned as ``error`` (including the traceback)
    with HTTP 500.
    """
    try:
        installed, detail = install_app(app_name, app_configs, send_websocket_message)
        outcome = 'success' if installed else 'error'
        return jsonify({'status': outcome, 'message': detail})
    except Exception as err:
        error_message = f"Installation error for {app_name}: {str(err)}\n{traceback.format_exc()}"
        app.logger.error(error_message)
        return jsonify({'status': 'error', 'message': error_message}), 500
@app.route('/fix_custom_nodes/<app_name>', methods=['POST'])
def fix_custom_nodes_route(app_name):
    """Run ``fix_custom_nodes`` for ``app_name`` and return its outcome as JSON."""
    fixed, detail = fix_custom_nodes(app_name)
    outcome = 'success' if fixed else 'error'
    return jsonify({'status': outcome, 'message': detail})
@app.route('/set_ssh_password', methods=['POST'])
def set_ssh_password():
    """Set the root password and enable password-based SSH login.

    Expects a JSON body ``{"password": "<new password>"}``. The password is
    applied with ``chpasswd``, stored via ``save_ssh_password``, the sshd
    config is edited to permit root/password login, and the SSH service is
    restarted.

    Returns:
        JSON with status ``success`` or ``error`` and a message. Failures of
        any step are reported as ``error`` together with the traceback.
    """
    try:
        # silent=True: a missing or malformed JSON body is treated like an
        # empty one, so the caller gets the plain "No password" error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        new_password = data.get('password')
        if not new_password:
            return jsonify({'status': 'error', 'message': 'No password provided'})
        # chpasswd reads one "user:password" record per line. A line break
        # inside the password would inject extra records and could change
        # other accounts, so reject it outright.
        if not isinstance(new_password, str) or '\n' in new_password or '\r' in new_password:
            return jsonify({'status': 'error', 'message': 'Password must be a string without line breaks'})
        print("Attempting to set new password...")
        # Use chpasswd to set the password
        process = subprocess.Popen(['chpasswd'], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        output, error = process.communicate(input=f"root:{new_password}\n")
        if process.returncode != 0:
            raise Exception(f"Failed to set password: {error}")
        # Save the new password
        save_ssh_password(new_password)
        # Configure SSH to allow root login with password
        print("Configuring SSH to allow root login with a password...")
        subprocess.run(["sed", "-i", 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/', "/etc/ssh/sshd_config"], check=True)
        subprocess.run(["sed", "-i", 's/#PasswordAuthentication no/PasswordAuthentication yes/', "/etc/ssh/sshd_config"], check=True)
        # Restart SSH service to apply changes
        print("Restarting SSH service...")
        subprocess.run(['service', 'ssh', 'restart'], check=True)
        print("SSH service restarted successfully.")
        print("SSH Configuration Updated and Password Set.")
        return jsonify({'status': 'success', 'message': 'SSH password set successfully. Note: Key-based authentication is more secure.'})
    except Exception as e:
        error_message = f"Error in set_ssh_password: {str(e)}\n{traceback.format_exc()}"
        print(error_message)
        return jsonify({'status': 'error', 'message': error_message})
@app.route('/start_filebrowser')
def start_filebrowser_route():
    """Start File Browser; report ``started`` or ``already_running`` as JSON."""
    outcome = 'started' if start_filebrowser() else 'already_running'
    return jsonify({'status': outcome})
@app.route('/stop_filebrowser')
def stop_filebrowser_route():
    """Stop File Browser; report ``stopped`` or ``already_stopped`` as JSON."""
    outcome = 'stopped' if stop_filebrowser() else 'already_stopped'
    return jsonify({'status': outcome})
@app.route('/filebrowser_status')
def filebrowser_status_route():
    """Return the current File Browser status as JSON."""
    current = get_filebrowser_status()
    return jsonify({'status': current})
@app.route('/add_app_config', methods=['POST'])
def add_new_app_config():
    """Register a new app configuration from the JSON body.

    Reads ``app_name`` and ``config`` from the request JSON; both must be
    truthy, otherwise an ``error`` status is returned.
    """
    payload = request.json
    app_name = payload.get('app_name')
    config = payload.get('config')
    if not (app_name and config):
        return jsonify({'status': 'error', 'message': 'Invalid data provided'})
    add_app_config(app_name, config)
    return jsonify({'status': 'success', 'message': f'App {app_name} added successfully'})
@app.route('/remove_app_config/<app_name>', methods=['POST'])
def remove_existing_app_config(app_name):
    """Remove the configuration of ``app_name``; ``error`` if it is unknown."""
    if app_name not in app_configs:
        return jsonify({'status': 'error', 'message': f'App {app_name} not found'})
    remove_app_config(app_name)
    return jsonify({'status': 'success', 'message': f'App {app_name} removed successfully'})
def setup_shared_models(shared_models_dir='/workspace/shared_models'):
    """Create the shared-models directory tree and its README.

    Args:
        shared_models_dir: Base directory to create. Defaults to
            ``/workspace/shared_models``.

    Returns:
        The base directory path that was set up.

    Side effects:
        Creates one subdirectory per model type (``Stable-diffusion``,
        ``VAE``, ``Lora``, ``ESRGAN``), writes ``README.txt`` if it does not
        exist yet, and prints two progress lines.
    """
    model_types = ['Stable-diffusion', 'VAE', 'Lora', 'ESRGAN']

    # makedirs on each subdirectory also creates the base directory.
    os.makedirs(shared_models_dir, exist_ok=True)
    for model_type in model_types:
        os.makedirs(os.path.join(shared_models_dir, model_type), exist_ok=True)

    # An existing README is never overwritten (the user may have edited it).
    readme_path = os.path.join(shared_models_dir, 'README.txt')
    if not os.path.exists(readme_path):
        readme_text = (
            "Upload your models to the appropriate folders:\n\n"
            "- Stable-diffusion: for Stable Diffusion models\n"
            "- VAE: for VAE models\n"
            "- Lora: for LoRA models\n"
            "- ESRGAN: for ESRGAN upscaling models\n\n"
            "These models will be automatically linked to all supported apps."
        )
        with open(readme_path, 'w') as f:
            f.write(readme_text)

    print(f"Shared models directory created at {shared_models_dir}")
    print("Shared models setup completed.")
    return shared_models_dir
def update_model_symlinks(shared_models_dir='/workspace/shared_models', apps=None):
    """Link every shared model file into each app's model directory.

    This definition replaces the ``update_model_symlinks`` name imported from
    ``utils.app_utils`` at the top of the file.

    Args:
        shared_models_dir: Directory holding one subdirectory per model type.
        apps: Mapping of app name to that app's ``models`` directory. Defaults
            to the A1111, Forge and ComfyUI locations under ``/workspace``.
            The key ``'ComfyUI'`` selects ComfyUI's folder naming.

    Side effects:
        Creates missing app model directories, creates a symlink for every
        regular file in the shared directories that has no counterpart yet,
        replaces dangling symlinks of the same name, and prints a status line.
        Existing files and valid links are left untouched.
    """
    if apps is None:
        apps = {
            'stable-diffusion-webui': '/workspace/stable-diffusion-webui/models',
            'stable-diffusion-webui-forge': '/workspace/stable-diffusion-webui-forge/models',
            'ComfyUI': '/workspace/ComfyUI/models'
        }
    # ComfyUI uses its own folder names; anything unlisted is lower-cased.
    comfy_folders = {
        'Stable-diffusion': 'checkpoints',
        'Lora': 'loras',
        'ESRGAN': 'upscale_models',
    }
    model_types = ['Stable-diffusion', 'VAE', 'Lora', 'ESRGAN']

    for model_type in model_types:
        shared_model_path = os.path.join(shared_models_dir, model_type)
        if not os.path.isdir(shared_model_path):
            continue

        # List the shared files once per model type instead of once per app.
        sources = []
        for filename in os.listdir(shared_model_path):
            src = os.path.join(shared_model_path, filename)
            if os.path.isfile(src):
                sources.append((filename, src))

        for app, app_models_dir in apps.items():
            if app == 'ComfyUI':
                folder = comfy_folders.get(model_type, model_type.lower())
            else:
                folder = model_type
            app_model_path = os.path.join(app_models_dir, folder)
            os.makedirs(app_model_path, exist_ok=True)

            for filename, src in sources:
                dst = os.path.join(app_model_path, filename)
                # os.path.exists() is False for a dangling symlink, which made
                # os.symlink() fail with FileExistsError. Drop the stale link
                # and test with lexists() instead.
                if os.path.islink(dst) and not os.path.exists(dst):
                    os.unlink(dst)
                if not os.path.lexists(dst):
                    os.symlink(src, dst)

    print("Model symlinks updated.")
def update_symlinks_periodically():
    """Refresh the model symlinks every 5 minutes, forever.

    Never returns. A failure of a single refresh (e.g. a filesystem error) is
    printed and the loop carries on, so one bad pass does not silently end
    the periodic updates for the rest of the process lifetime.
    """
    while True:
        try:
            update_model_symlinks()
        except Exception as e:
            print(f"Error updating model symlinks: {str(e)}\n{traceback.format_exc()}")
        time.sleep(300)  # Check every 5 minutes
def start_symlink_update_thread():
    """Run ``update_symlinks_periodically`` in a background daemon thread."""
    threading.Thread(target=update_symlinks_periodically, daemon=True).start()
def recreate_symlinks(shared_models_dir='/workspace/shared_models', apps=None):
    """Drop and rebuild the shared-model symlinks in every app directory.

    Args:
        shared_models_dir: Directory holding one subdirectory per model type.
        apps: Mapping of app name to that app's ``models`` directory. Defaults
            to the A1111, Forge and ComfyUI locations under ``/workspace``.
            The key ``'ComfyUI'`` selects ComfyUI's folder naming.

    Returns:
        The message ``"Symlinks recreated successfully."``.

    Only symlinks are removed: an app model path that is itself a symlink is
    unlinked, and symlink entries directly inside an app model directory are
    deleted. Regular files and subdirectories are kept, so models stored
    directly in an app directory are not lost.
    """
    if apps is None:
        apps = {
            'stable-diffusion-webui': '/workspace/stable-diffusion-webui/models',
            'stable-diffusion-webui-forge': '/workspace/stable-diffusion-webui-forge/models',
            'ComfyUI': '/workspace/ComfyUI/models'
        }
    # ComfyUI uses its own folder names; anything unlisted is lower-cased.
    comfy_folders = {
        'Stable-diffusion': 'checkpoints',
        'Lora': 'loras',
        'ESRGAN': 'upscale_models',
    }
    model_types = ['Stable-diffusion', 'VAE', 'Lora', 'ESRGAN']

    for model_type in model_types:
        shared_model_path = os.path.join(shared_models_dir, model_type)
        if not os.path.isdir(shared_model_path):
            continue

        # List the shared files once per model type instead of once per app.
        sources = []
        for filename in os.listdir(shared_model_path):
            src = os.path.join(shared_model_path, filename)
            if os.path.isfile(src):
                sources.append((filename, src))

        for app, app_models_dir in apps.items():
            if app == 'ComfyUI':
                folder = comfy_folders.get(model_type, model_type.lower())
            else:
                folder = model_type
            app_model_path = os.path.join(app_models_dir, folder)

            # Remove existing symlinks -- and nothing else. Wiping the whole
            # directory would also delete real model files kept in it.
            if os.path.islink(app_model_path):
                os.unlink(app_model_path)
            elif os.path.isdir(app_model_path):
                for entry in os.listdir(app_model_path):
                    entry_path = os.path.join(app_model_path, entry)
                    if os.path.islink(entry_path):
                        os.unlink(entry_path)

            os.makedirs(app_model_path, exist_ok=True)

            for filename, src in sources:
                dst = os.path.join(app_model_path, filename)
                # lexists(): never try to link over an existing entry, not
                # even a dangling link.
                if not os.path.lexists(dst):
                    os.symlink(src, dst)

    return "Symlinks recreated successfully."
@app.route('/recreate_symlinks', methods=['POST'])
def recreate_symlinks_route():
    """Rebuild all model symlinks and return the result as JSON."""
    try:
        return jsonify({'status': 'success', 'message': recreate_symlinks()})
    except Exception as err:
        return jsonify({'status': 'error', 'message': str(err)})
@app.route('/create_shared_folders', methods=['POST'])
def create_shared_folders():
    """Create the extended set of shared model folders plus a README.

    Creates ``/workspace/shared_models`` with one subdirectory per model type
    and writes ``README.txt`` there unless it already exists. Returns JSON
    with status ``success``, or ``error`` with the exception text.
    """
    try:
        base_dir = '/workspace/shared_models'
        folder_names = (
            'Stable-diffusion', 'Lora', 'embeddings', 'VAE', 'hypernetworks',
            'aesthetic_embeddings', 'controlnet', 'ESRGAN',
        )

        os.makedirs(base_dir, exist_ok=True)
        for folder_name in folder_names:
            os.makedirs(os.path.join(base_dir, folder_name), exist_ok=True)

        # An existing README is left as it is.
        readme_path = os.path.join(base_dir, 'README.txt')
        if not os.path.exists(readme_path):
            readme_parts = (
                "Upload your models to the appropriate folders:\n\n",
                "- Stable-diffusion: for Stable Diffusion checkpoints\n",
                "- Lora: for LoRA models\n",
                "- embeddings: for Textual Inversion embeddings\n",
                "- VAE: for VAE models\n",
                "- hypernetworks: for Hypernetwork models\n",
                "- aesthetic_embeddings: for Aesthetic Gradient embeddings\n",
                "- controlnet: for ControlNet models\n",
                "- ESRGAN: for ESRGAN upscaling models\n\n",
                "These models will be automatically linked to all supported apps.",
            )
            with open(readme_path, 'w') as readme:
                readme.write("".join(readme_parts))

        return jsonify({'status': 'success', 'message': 'Shared model folders created successfully.'})
    except Exception as err:
        return jsonify({'status': 'error', 'message': str(err)})
def save_civitai_token(token):
    """Store ``token`` as ``{"token": ...}`` JSON in ``CIVITAI_TOKEN_FILE``."""
    record = {'token': token}
    with open(CIVITAI_TOKEN_FILE, mode='w') as token_file:
        json.dump(record, fp=token_file)
def load_civitai_token(path=None):
    """Read the stored Civitai API token.

    Args:
        path: Optional token file to read. Defaults to ``CIVITAI_TOKEN_FILE``.

    Returns:
        The value of the ``token`` key, or ``None`` when the file is missing,
        unreadable, not valid JSON, not a JSON object, or has no ``token``.
    """
    token_path = CIVITAI_TOKEN_FILE if path is None else path
    try:
        with open(token_path, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # A missing or corrupt token file means "no token stored"; it should
        # not break the routes that call this.
        return None
    if not isinstance(data, dict):
        return None
    return data.get('token')
@app.route('/save_civitai_token', methods=['POST'])
def save_civitai_token_route():
    """Persist the Civitai token from the JSON body; HTTP 400 if it is missing."""
    token = request.json.get('token')
    if not token:
        return jsonify({'status': 'error', 'message': 'No token provided.'}), 400
    save_civitai_token(token)
    return jsonify({'status': 'success', 'message': 'Civitai token saved successfully.'})
@app.route('/get_civitai_token', methods=['GET'])
def get_civitai_token_route():
    """Return the stored Civitai token (or null) as ``{"token": ...}`` JSON."""
    return jsonify({'token': load_civitai_token()})
@app.route('/download_model', methods=['POST'])
def download_model_route():
    """Download a model from Civitai or Hugging Face as described by the JSON body.

    Body keys read: ``url``, ``model_name``, ``model_type``, ``civitai_token``
    (falls back to the stored token), ``hf_token``, ``version_id``,
    ``file_index``.

    Returns JSON with status ``success``, ``choice_required`` (with the
    choices under ``data``), or ``error``. Errors use HTTP 400 for an
    unsupported URL, a missing Civitai token or a failed download, and
    HTTP 500 (with the traceback in the message) for an exception.
    """
    payload = request.json
    url = payload.get('url')
    model_name = payload.get('model_name')
    model_type = payload.get('model_type')
    civitai_token = payload.get('civitai_token') or load_civitai_token()
    hf_token = payload.get('hf_token')
    version_id = payload.get('version_id')
    file_index = payload.get('file_index')

    # Only the leading "is this our kind of URL" flag of each checker is used.
    is_civitai = check_civitai_url(url)[0:4][0]
    is_huggingface = check_huggingface_url(url)[0:5][0]

    if not is_civitai and not is_huggingface:
        return jsonify({'status': 'error', 'message': 'Unsupported URL. Please use Civitai or Hugging Face URLs.'}), 400
    if is_civitai and not civitai_token:
        return jsonify({'status': 'error', 'message': 'Civitai token is required for downloading from Civitai.'}), 400

    try:
        success, message = download_model(url, model_name, model_type, send_websocket_message, civitai_token, hf_token, version_id, file_index)
        if not success:
            return jsonify({'status': 'error', 'message': message}), 400
        if isinstance(message, dict) and 'choice_required' in message:
            return jsonify({'status': 'choice_required', 'data': message['choice_required']})
        return jsonify({'status': 'success', 'message': message})
    except Exception as err:
        error_message = f"Model download error: {str(err)}\n{traceback.format_exc()}"
        app.logger.error(error_message)
        return jsonify({'status': 'error', 'message': error_message}), 500
@app.route('/get_model_folders')
def get_model_folders():
    """Report total size and file count of each existing shared model folder.

    Returns a JSON object keyed by folder name; folders that do not exist
    under ``SHARED_MODELS_DIR`` are omitted.
    """
    summary = {}
    for folder in ('Stable-diffusion', 'VAE', 'Lora', 'ESRGAN'):
        folder_path = os.path.join(SHARED_MODELS_DIR, folder)
        if not os.path.exists(folder_path):
            continue
        # One size per file found anywhere below the folder.
        file_sizes = [
            os.path.getsize(os.path.join(root, name))
            for root, _subdirs, names in os.walk(folder_path)
            for name in names
        ]
        summary[folder] = {
            'size': format_size(sum(file_sizes)),
            'file_count': len(file_sizes),
        }
    return jsonify(summary)
@app.route('/update_symlinks', methods=['POST'])
def update_symlinks_route():
    """Refresh the model symlinks on demand; HTTP 500 with the error text on failure."""
    try:
        update_model_symlinks()
    except Exception as err:
        return jsonify({'status': 'error', 'message': str(err)}), 500
    else:
        return jsonify({'status': 'success', 'message': 'Symlinks updated successfully'})
if __name__ == '__main__':
    # Startup sequence: shared model folders, SSH, File Browser, background
    # threads, then the web server on port 7223.
    shared_models_path = setup_shared_models()
    print(f"Shared models directory: {shared_models_path}")
    if setup_ssh():
        print("SSH setup completed successfully.")
    else:
        print("Failed to set up SSH. Please check the logs.")
    print("Configuring File Browser...")
    if configure_filebrowser():
        print("File Browser configuration completed successfully.")
        print("Attempting to start File Browser...")
        if start_filebrowser():
            print("File Browser started successfully.")
        else:
            print("Failed to start File Browser. Please check the logs.")
    else:
        print("Failed to configure File Browser. Please check the logs.")
    threading.Thread(target=check_running_processes, daemon=True).start()
    # Start the thread to periodically update model symlinks
    start_symlink_update_thread()
    # The server listens on all interfaces. Flask's debug mode enables the
    # interactive debugger, which lets anyone who can reach the port run
    # arbitrary code, so it is off unless explicitly requested with
    # FLASK_DEBUG=1 (or true/yes/on).
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, host='0.0.0.0', port=7223)