区分stable diffusion中的通道数与张量维度

打印 上一主题 下一主题

主题 685|帖子 685|积分 2055

前言:通道数与张量外形都在数值3和4之间变换,容易混淆。
1.通道数:

1.1 channel = 3

RGB 图像具有 3 个通道(红色、绿色和蓝色)。
1.2 channel = 4

Stable Diffusion has 4 latent channels。
怎样理解卷积神经网络中的通道(channel)
2.张量外形

2.1 3D 张量

外形为 (C, H, W),此中 C 是通道数,H 是高度,W 是宽度。这实用于单个图像。
2.2 4D 张量

2.2.1 通常

外形为 (B, C, H, W),此中 B 是批次大小,C 是通道数,H 是高度,W 是宽度。这实用于多个图像(比方,批量处理)。
2.2.2 stable diffusion

在img2img中,将image用vae编码并按照timestep加噪:
  1.                 # This code copyed from diffusers.pipline_controlnet_img2img.py
  2.         # 6. Prepare latent variables
  3.         latents = self.prepare_latents(
  4.             image,
  5.             latent_timestep,
  6.             batch_size,
  7.             num_images_per_prompt,
  8.             prompt_embeds.dtype,
  9.             device,
  10.             generator,
  11.         )
复制代码
image的dim(维度)是3,而latents的dim为4。
让我们先看text2img的prepare_latents函数:
  1.     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
  2.     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
  3.         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
  4.         if isinstance(generator, list) and len(generator) != batch_size:
  5.             raise ValueError(
  6.                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
  7.                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
  8.             )
  9.         if latents is None:
  10.             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
  11.         else:
  12.             latents = latents.to(device)
  13.         # scale the initial noise by the standard deviation required by the scheduler
  14.         latents = latents * self.scheduler.init_noise_sigma
  15.         return latents
复制代码
显然,shape已经规定了latents的dim(4)和排列次序。
在img2img中:
  1.     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
  2.     def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
  3.         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
  4.             raise ValueError(
  5.                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
  6.             )
  7.         image = image.to(device=device, dtype=dtype)
  8.         batch_size = batch_size * num_images_per_prompt
  9.         if image.shape[1] == 4:
  10.             init_latents = image
  11.         else:
  12.             if isinstance(generator, list) and len(generator) != batch_size:
  13.                 raise ValueError(
  14.                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
  15.                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
  16.                 )
  17.             elif isinstance(generator, list):
  18.                 init_latents = [
  19.                     self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
  20.                 ]
  21.                 init_latents = torch.cat(init_latents, dim=0)
  22.             else:
  23.                 init_latents = self.vae.encode(image).latent_dist.sample(generator)
  24.             init_latents = self.vae.config.scaling_factor * init_latents
  25.         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
  26.             # expand init_latents for batch_size
  27.             deprecation_message = (
  28.                 f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
  29.                 " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
  30.                 " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
  31.                 " your script to pass as many initial images as text prompts to suppress this warning."
  32.             )
  33.             deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
  34.             additional_image_per_prompt = batch_size // init_latents.shape[0]
  35.             
  36.             init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
  37.         elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
  38.             raise ValueError(
  39.                 f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
  40.             )
  41.         else:
  42.             init_latents = torch.cat([init_latents], dim=0)
  43.         shape = init_latents.shape
  44.         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
  45.         # get latents
  46.         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
  47.         latents = init_latents
  48.         return latents
复制代码
3.应用

3.1 问题

  1. new_map = texture.permute(1, 2, 0)
  2. RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3
复制代码
该问题是张量外形的问题,跟通道数毫无关系。
3.2 举例

问:4D 张量:外形为 (B, C, H, W),此中C可以为3吗?
答:4D 张量的外形为 (B,C,H,W),此中 C 表示通道数。通常环境下,C 可以为 3,这对应于 RGB 图像的三个颜色通道(红色、绿色和蓝色)。
3.3 张量可以理解为多维可变数组

  1. print("sample:", sample.shape)
  2. print("sample:", sample[0].shape)
  3. print("sample:", sample[0][0].shape)
复制代码
  1. >>
  2. sample: torch.Size([10, 4, 96, 96])
  3. sample: torch.Size([4, 96, 96])
  4. sample: torch.Size([96, 96])
复制代码
由此可见,可以将张量外形为torch.size([10, 4, 96, 96])理解为一个4维可变数组。
3.4 将张量化为list

3.4.1

  1. # sample: torch.Size([10, 4, 96, 96])
  2. views = [view for view in sample]
  3. print("views:", views.shape)
复制代码
  1. >>AttributeError: 'list' object has no attribute 'shape'
复制代码
此时应该:
  1. print("views:", views[0].shape)
复制代码
  1. >>views: torch.Size([4, 96, 96])
复制代码
3.4.2

  1. # 方法二
  2. for i, view in enumerate(prev_views):
  3.         pred_prev_sample[i] = view
复制代码
3.5 将list化为张量

3.5.1

  1. # 定义一个Python列表
  2. my_list = [1, 2, 3, 4, 5]
  3. # 将Python列表转换为PyTorch张量
  4. my_tensor = torch.tensor(my_list)
  5. print(my_tensor)
复制代码
  1. >>tensor([1, 2, 3, 4, 5])
复制代码
3.5.2

  1. # 假设你有一个包含多个张量的列表
  2. tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
  3. # 使用torch.stack将它们堆叠成一个新的张量
  4. stacked_tensor = torch.stack(tensor_list)
  5. print(stacked_tensor)
复制代码
  1. >>tensor([[1, 2, 3],
  2.           [4, 5, 6],
  3.           [7, 8, 9]])
复制代码
张量运算时对轴参数的设定非经常见,在 Numpy 中一般是参数axis,在 Pytorch 中一般是参数dim,但它们寄义是一样的。
深度学习中的轴/axis/dim全解
  1. # 默认情况下,它在新的维度(即0维)上堆叠这些张量。
  2. # views is a list,and views[0].shape is ([4, 96, 96]).
  3. views = torch.stack(views, axis=0) # ([10, 4, 96, 96])
复制代码
3.5.3 沿着现有维度拼接/在新的维度上增长维度

  1. import torch
  2. # 假设你有一个包含多个张量的列表
  3. tensor_list = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
  4. # 使用 torch.cat 将张量沿着现有维度拼接
  5. concatenated_tensor = torch.cat(tensor_list, dim=0)
  6. # 使用 torch.unsqueeze 在新的维度上增加维度
  7. stacked_tensor = torch.unsqueeze(concatenated_tensor, dim=0)
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

反转基因福娃

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表