diff --git a/predict.py b/predict.py index 89e5ece..8fcdd9f 100644 --- a/predict.py +++ b/predict.py @@ -84,10 +84,10 @@ class Predictor(BasePredictor): k: create_SUPIR_model("options/SUPIR_v0.yaml", SUPIR_sign=k).to( self.supir_device ) - for k in ["Q", "F"][1:] + for k in ["Q", "F"] } - for k in ["Q", "F"][1:]: + for k in ["Q", "F"]: self.models[k].ae_dtype = convert_dtype(ae_dtype) self.models[k].model.dtype = convert_dtype(diff_dtype)