This commit is contained in:
Fanghua-Yu 2024-02-05 20:47:41 +08:00
parent f6a31e9563
commit c1c5728601
2 changed files with 10 additions and 10 deletions

View file

@ -45,8 +45,11 @@ class SUPIRModel(DiffusionEngine):
return z
@torch.no_grad()
def encode_first_stage_with_denoise(self, x, use_sample=True):
def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
with torch.autocast("cuda", dtype=self.ae_dtype):
if is_stage1:
h = self.first_stage_model.denoise_encoder_s1(x)
else:
h = self.first_stage_model.denoise_encoder(x)
moments = self.first_stage_model.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
@ -65,11 +68,11 @@ class SUPIRModel(DiffusionEngine):
return out.float()
@torch.no_grad()
def batchify_denoise(self, x):
def batchify_denoise(self, x, is_stage1=False):
'''
[N, C, H, W], [-1, 1], RGB
'''
x = self.encode_first_stage_with_denoise(x, use_sample=False)
x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
return self.decode_first_stage(x)
@torch.no_grad()

View file

@ -9,6 +9,7 @@ from PIL import Image
from llava.llava_agent import LLavaAgent
from CKPT_PTH import LLAVA_MODEL_PATH
import einops
import copy
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='0.0.0.0')
@ -31,6 +32,7 @@ else:
# load SUPIR
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
model.current_model = 'v0-Q'
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
# load LLaVA
@ -41,17 +43,12 @@ else:
def stage1_process(input_image, gamma_correction):
# force to v0-Q
if model.current_model != 'v0-Q':
print('load v0-Q')
model.load_state_dict(ckpt_Q, strict=False)
model.current_model = 'v0-Q'
LQ = HWC3(input_image)
LQ = fix_resize(LQ, 512)
# stage1
LQ = np.array(LQ) / 255 * 2 - 1
LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
LQ = model.batchify_denoise(LQ)
LQ = model.batchify_denoise(LQ, is_stage1=True)
LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
# gamma correction
LQ = LQ / 255.0