20240205
This commit is contained in:
parent
f6a31e9563
commit
c1c5728601
2 changed files with 10 additions and 10 deletions
|
@ -45,9 +45,12 @@ class SUPIRModel(DiffusionEngine):
|
|||
return z
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_first_stage_with_denoise(self, x, use_sample=True):
|
||||
def encode_first_stage_with_denoise(self, x, use_sample=True, is_stage1=False):
|
||||
with torch.autocast("cuda", dtype=self.ae_dtype):
|
||||
h = self.first_stage_model.denoise_encoder(x)
|
||||
if is_stage1:
|
||||
h = self.first_stage_model.denoise_encoder_s1(x)
|
||||
else:
|
||||
h = self.first_stage_model.denoise_encoder(x)
|
||||
moments = self.first_stage_model.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
if use_sample:
|
||||
|
@ -65,11 +68,11 @@ class SUPIRModel(DiffusionEngine):
|
|||
return out.float()
|
||||
|
||||
@torch.no_grad()
|
||||
def batchify_denoise(self, x):
|
||||
def batchify_denoise(self, x, is_stage1=False):
|
||||
'''
|
||||
[N, C, H, W], [-1, 1], RGB
|
||||
'''
|
||||
x = self.encode_first_stage_with_denoise(x, use_sample=False)
|
||||
x = self.encode_first_stage_with_denoise(x, use_sample=False, is_stage1=is_stage1)
|
||||
return self.decode_first_stage(x)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -9,6 +9,7 @@ from PIL import Image
|
|||
from llava.llava_agent import LLavaAgent
|
||||
from CKPT_PTH import LLAVA_MODEL_PATH
|
||||
import einops
|
||||
import copy
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ip", type=str, default='0.0.0.0')
|
||||
|
@ -31,6 +32,7 @@ else:
|
|||
|
||||
# load SUPIR
|
||||
model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
|
||||
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
|
||||
model.current_model = 'v0-Q'
|
||||
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
|
||||
# load LLaVA
|
||||
|
@ -41,17 +43,12 @@ else:
|
|||
|
||||
|
||||
def stage1_process(input_image, gamma_correction):
|
||||
# force to v0-Q
|
||||
if model.current_model != 'v0-Q':
|
||||
print('load v0-Q')
|
||||
model.load_state_dict(ckpt_Q, strict=False)
|
||||
model.current_model = 'v0-Q'
|
||||
LQ = HWC3(input_image)
|
||||
LQ = fix_resize(LQ, 512)
|
||||
# stage1
|
||||
LQ = np.array(LQ) / 255 * 2 - 1
|
||||
LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
|
||||
LQ = model.batchify_denoise(LQ)
|
||||
LQ = model.batchify_denoise(LQ, is_stage1=True)
|
||||
LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
|
||||
# gamma correction
|
||||
LQ = LQ / 255.0
|
||||
|
|
Loading…
Add table
Reference in a new issue