JAVA实现从零实现扩散模子stable diffusion系列(一)

打印 上一主题 下一主题

主题 1972|帖子 1972|积分 5916

概要

在上一个文章咱们已经实现了一个基于llama3.1的大语言模子(LLM模子)。本日咱们继承来利用Omega-AI深度学习引擎从零实现一个stable diffusion模子,并实现文生图场景应用。
Omega-AI深度学习引擎

Omega-AI:基于java打造的深度学习框架,资助你快速搭建神经网络,实现训练或测试模子,支持多卡训练,框架现在支持BP神经网络、卷积神经网络、循环神经网络、vgg16、resnet、yolo、lstm、transformer、diffusion、gpt、llama、llava等模子的构建,现在引擎最新版本支持CUDA和CUDNN两种GPU加速方式,关于GPU加速的环境配置与jcuda版本jar包的对应依靠。
Omega-AI简介
JAVA实现从零大语言模子llama3
结果展示

基于stable diffusion模子实现文生图

文生图演示图

文本1图片1文本2图片2a highly detailed anime landscape,big tree on the water, epic sky,golden grass,detailed.
3d art of a golden tree in the river,with intricate flora and flowing water,detailed.
a vibrant anime mountain lands
a dark warrior in epic armor stands among glowing crimson leaves in a mystical forest.
cute fluffy panda, anime, ghibli style, pastel colors, soft shadows, detailed fur, vibrant eyes, fantasy setting, digital art, 3d, by kazuo oga
a epic city,3d,detailed.
Quick Start

环境配置



  • JDK1.8以上
  • CUDA11.X/12.X
  1. // 检查当前安装的CUDA版本
  2. nvcc --version
复制代码
安装CUDA与CUDNN
https://developer.nvidia.com/cuda-toolkit-archive
下载与配置Omega-AI深度学习引擎



  • 下载Omega-AI深度学习引擎
  1. git clone https://github.com/dromara/Omega-AI.git
  2. git clone https://gitee.com/dromara/omega-ai.git
复制代码


  • 根据当前CUDA版本配置JCUDA依靠
    打开Omega-AI pom.xml文件,根据当前CUDA版本修改依靠
      提示:如您安装的cuda版本为12.x,请利用jcuda12.0.0版本
  1.         <properties>
  2.             <java.version>1.8</java.version>
  3.             <!--当前cuda版本为11.8.x,对应jcuda版本为11.8.0-->
  4.                 <jcuda.version>11.8.0</jcuda.version>
  5.                 <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  6.                 <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
  7.                 <resource.delimiter>@</resource.delimiter>
  8.             <maven.compiler.source>${java.version}</maven.compiler.source>
  9.             <maven.compiler.target>${java.version}</maven.compiler.target>
  10.         </properties>
复制代码
stable diffusion架构

传统的扩散模子有两大限定:1.输入图片尺寸与计算量巨细限定导致效率低下,2.只能输入随机噪声导致无法控制输出结果。而stable diffsuion引入了latent space概念,使得在其可以在较少的内存占用完成高清的图片天生。在解决只能输入随机噪声的题目上,stable diffusion利用了clip text的text encoder把文本信息作为条件输入到text conditioned lantent unet当中,并利用cross attention把text条件与图像融合。总结以上内容,stable diffsuion总共分为三大组件:VAE(变分自编码器)负责把图片编码成相对较小的latent space数据和解码latent space还原成正常巨细的图片。CLIP TEXT当中的text encoder负责把文本内容编码成77*512的 token embeddings。lantent unet负责结合条件天生latent space,与传统的diffusion模子的unet相比,stable diffusion的unet利用的是cross attention机制,目标就是为了融合条件信息。以下是stable diffusion流程图:

1 STEP 训练VQ-VAE(变分自编码器)

1.1下载与预处理训练数据



  • 本次使命将利用开源动画风格的图文对数据集【rapidata】点击下载
  • 处理图片巨细统一为256 * 256或者512 * 512
  • 制作元数据并存储为json文件,数据格式为:[{“id”: “0”, “en”: “cinematic bokeh: ironcat, sharp focus on the cat’s eyes, blurred background, dramatic chiaroscuro lighting, deep shadows, high contrast, rich textures, high resolution”}]
      提示:可下载已经处理好的数据集
    点击下载已处理后的数据集
  • 利用数据加载器读取训练数据,代码如下:
  1.         int batchSize = 2;
  2.            int imageSize = 256;
  3.            float[] mean = new float[] {0.5f, 0.5f, 0.5f};
  4.     float[] std = new float[] {0.5f, 0.5f, 0.5f};
  5.            String imgDirPath = "I:\\dataset\\sd-anime\\anime_op\\256\";
  6.            DiffusionImageDataLoader dataLoader = new DiffusionImageDataLoader(imgDirPath, imageSize, imageSize, batchSize, true, false, mean, std);
复制代码
1.2 创建VQ-VAE模子

  1.                 /**
  2.                     * LossType lossType: 损失函数
  3.                     * UpdaterType updater, 参数更新方法
  4.                     * int z_dims, 输出latent space维度
  5.                     * int latendDim, 输出latent space通道数
  6.                     * latent space形状为[batchSize, latendDim, z_dims, z_dims]
  7.                     * int num_res_blocks, 每个采样层所包含的residual层数
  8.                     * int imageSize, 输入图片大小
  9.                     * int[] ch_mult, unet上下采样层通道倍数
  10.                     * int ch, unet上下采样层通道基数
  11.                     */
  12.                    VQVAE2 network = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imageSize, ch_mult, ch, num_res_blocks);
复制代码
1.3 创建LPIPS模子

为了增强vae的还原图片的清晰度,在训练vae模子的过程中添加lpips(感知丧失),该模子用于度量两张图片之间的差别。
  1.   /**
  2.    * LossType lossType: 损失函数(均方差损失函数)
  3.    * UpdaterType updater, 参数更新方法
  4.    * int imageSize, 输入图片大小
  5.    */
  6.   LPIPS lpips = new LPIPS(LossType.MSE, UpdaterType.adamw, imageSize);
复制代码
完整训练代码如下:

  1.   public static void anime_vqvae2_lpips_gandisc_32_nogan() {
  2.           try {
  3.                   nt batchSize = 16;
  4.                   int imageSize = 256;
  5.                   int z_dims = 32;
  6.                   int latendDim = 4;
  7.                   int num_vq_embeddings = 512;
  8.                   int num_res_blocks = 1;
  9.                   int[] ch_mult = new int[] {1, 2, 2, 4};
  10.                   int ch = 32;
  11.                   float[] mean = new float[] {0.5f, 0.5f, 0.5f};
  12.                   float[] std = new float[] {0.5f, 0.5f, 0.5f};
  13.                  
  14.                   String imgDirPath = "I:\\dataset\\sd-anime\\anime_op\\256\";
  15.                   DiffusionImageDataLoader dataLoader = new DiffusionImageDataLoader(imgDirPath, imageSize, imageSize, batchSize, true, false, mean, std);
  16.                   /**
  17.                    * LossType lossType: 损失函数
  18.                    * UpdaterType updater, 参数更新方法
  19.                    * int z_dims, 输出latent space维度
  20.                    * int latendDim, 输出latent space通道数
  21.                    * latent space形状为[batchSize, latendDim, z_dims, z_dims]
  22.                    * int num_res_blocks, 每个采样层所包含的residual层数
  23.                    * int imageSize, 输入图片大小
  24.                    * int[] ch_mult, unet上下采样层通道倍数
  25.                    * int ch, unet上下采样层通道基数
  26.                    */
  27.                   VQVAE2 network = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imageSize, ch_mult, ch, num_res_blocks);
  28.                   network.CUDNN = true;
  29.                   network.learnRate = 0.001f;
  30.                  
  31.                   LPIPS lpips = new LPIPS(LossType.MSE, UpdaterType.adamw, imageSize);
  32.                   //加载权重
  33.                   String lpipsWeight = "H:\\model\\lpips.json";
  34.                   LPIPSTest.loadLPIPSWeight(LagJsonReader.readJsonFileSmallWeight(lpipsWeight), lpips, false);
  35.                   lpips.CUDNN = true;
  36.                  
  37.                   MBSGDOptimizer optimizer = new MBSGDOptimizer(network, 200, 0.00001f, batchSize, LearnRateUpdate.CONSTANT, false);
  38.                   optimizer.trainVQVAE2_lpips_nogan(dataLoader, lpips);
  39.                   String save_model_path = "/omega/models/anime_vqvae2_256.model";
  40.                   ModelUtils.saveModel(network, save_model_path);
  41.           } catch (Exception e) {
  42.                   // TODO: handle exception
  43.                   e.printStackTrace();
  44.           }
  45.   }
复制代码
VQ-VAE演示图

原图VQ-VAE原图VQ-VAE
2 STEP 训练diffusion unet cond(条件扩散模子)

2.1 创建与加载Clip Text Encoder

本次使命利用clip-vit-base-patch32的encoder部分作为text encoder。
  1.         /**
  2.          * clipText shape[batchSize, 77, 512]
  3.          */
  4.         int time = maxContextLen;  //文本最大token长度
  5.         int maxPositionEmbeddingsSize = 77;  //文本最大token长度
  6.         int vocabSize = 49408;  //tokenizer词表长度
  7.         int headNum = 8;  //多头注意力头数
  8.         int n_layers = 12;  //CLIPEncoderLayer编码层层数
  9.         int textEmbedDim = 512;  //文本嵌入输出维度
  10.         ClipTextModel clip = new ClipTextModel(LossType.MSE, UpdaterType.adamw, headNum, time, vocabSize, textEmbedDim, maxPositionEmbeddingsSize, n_layers);
  11.         clip.CUDNN = true;
  12.         clip.time = time;
  13.         clip.RUN_MODEL = RunModel.EVAL; //设置推理模式
  14.        
  15.         String clipWeight = "H:\\model\\clip-vit-base-patch32.json";
  16.         ClipModelUtils.loadWeight(LagJsonReader.readJsonFileSmallWeight(clipWeight), clip, true);
复制代码
2.2 创建与加载VQ-VAE模子

  1.         /**
  2.          * LossType lossType: 损失函数
  3.          * UpdaterType updater, 参数更新方法
  4.          * int z_dims, 输出latent space维度
  5.          * int latendDim, 输出latent space通道数
  6.          * latent space形状为[batchSize, latendDim, z_dims, z_dims]
  7.          * int num_res_blocks, 每个采样层所包含的residual层数
  8.          * int imageSize, 输入图片大小
  9.          * int[] ch_mult, unet上下采样层通道倍数
  10.          * int ch, unet上下采样层通道基数
  11.          */
  12.         VQVAE2 vae = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imageSize, ch_mult, ch, num_res_blocks);
  13.         vae.RUN_MODEL = RunModel.EVAL;  //设置推理模式
  14.         //加载已训练好的vae模型权重
  15.         String vaeModel = "anime_vqvae2_256.model";
  16.         ModelUtils.loadModel(vae, vaeModel);
复制代码
2.3 创建Diffusion UNet Cond模子(条件扩散模子)

  1.         int unetHeadNum = 8;  //unet多头注意力头数
  2.         int[] downChannels = new int[] {128, 256, 512, 768};  //下采样通道数
  3.         int numLayer = 2;  //每层采样层的ResidualBlock个数
  4.         int timeSteps = 1000;  //扩散时间步数
  5.         int tEmbDim = 512;  //时序嵌入维度
  6.         int latentSize = 32;  //latent space维度
  7.         int groupNum = 32;  //group norm分组数
  8.                
  9.         DiffusionUNetCond2 unet = new DiffusionUNetCond2(LossType.MSE, UpdaterType.adamw, latendDim, latentSize, latentSize, downChannels, unetHeadNum, numLayer, timeSteps, tEmbDim, maxContextLen, textEmbedDim, groupNum);
  10.         unet.CUDNN = true;
  11.         unet.learnRate = 0.0001f;
复制代码
完整训练代码如下:

  1.         public static void tiny_sd_train_anime_32() throws Exception {
  2.                 String labelPath = "I:\\dataset\\sd-anime\\anime_op\\data.json";
  3.                 String imgDirPath = "I:\\dataset\\sd-anime\\anime_op\\256\";
  4.                 boolean horizontalFilp = true;
  5.                 int imgSize = 256;
  6.                 int maxContextLen = 77;
  7.                 int batchSize = 8;
  8.                 float[] mean = new float[] {0.5f, 0.5f,0.5f};
  9.                 float[] std = new float[] {0.5f, 0.5f,0.5f};
  10.                 //加载bpe tokenizer分词器
  11.                 String vocabPath = "H:\\model\\bpe_tokenizer\\vocab.json";
  12.                 String mergesPath = "H:\\model\\bpe_tokenizer\\merges.txt";
  13.                 BPETokenizerEN bpe = new BPETokenizerEN(vocabPath, mergesPath, 49406, 49407);
  14.                
  15.                 SDImageDataLoaderEN dataLoader = new SDImageDataLoaderEN(bpe, labelPath, imgDirPath, imgSize, imgSize, maxContextLen, batchSize, horizontalFilp, mean, std);
  16.                
  17.                 /**
  18.                  * clipText shape[batchSize, 77, 512]
  19.                  */
  20.                 int time = maxContextLen;  //文本最大token长度
  21.                 int maxPositionEmbeddingsSize = 77;  //文本最大token长度
  22.                 int vocabSize = 49408;  //tokenizer词表长度
  23.                 int headNum = 8;  //多头注意力头数
  24.                 int n_layers = 12;  //CLIPEncoderLayer编码层层数
  25.                 int textEmbedDim = 512;  //文本嵌入输出维度
  26.                 ClipTextModel clip = new ClipTextModel(LossType.MSE, UpdaterType.adamw, headNum, time, vocabSize, textEmbedDim, maxPositionEmbeddingsSize, n_layers);
  27.                 clip.CUDNN = true;
  28.                 clip.time = time;
  29.                 clip.RUN_MODEL = RunModel.EVAL;
  30.                
  31.                 String clipWeight = "H:\\model\\clip-vit-base-patch32.json";
  32.                 ClipModelUtils.loadWeight(LagJsonReader.readJsonFileSmallWeight(clipWeight), clip, true);
  33.                
  34.                 int z_dims = 128;
  35.                 int latendDim = 4;
  36.                 int num_vq_embeddings = 512;
  37.                 int num_res_blocks = 2;
  38.                 int[] ch_mult = new int[] {1, 2, 2, 4};
  39.                 int ch = 128;
  40.                
  41.                 VQVAE2 vae = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imgSize, ch_mult, ch, num_res_blocks);
  42.                 vae.CUDNN = true;
  43.                 vae.learnRate = 0.001f;
  44.                 vae.RUN_MODEL = RunModel.EVAL;
  45.                 String vaeModel = "anime_vqvae2_256.model";
  46.                 ModelUtils.loadModel(vae, vaeModel);
  47.                
  48.                 int unetHeadNum = 8;  //unet多头注意力头数
  49.                 int[] downChannels = new int[] {128, 256, 512, 768};  //下采样通道数
  50.                 int numLayer = 2;  //每层采样层的ResidualBlock个数
  51.                 int timeSteps = 1000;  //扩散时间步数
  52.                 int tEmbDim = 512;  //时序嵌入维度
  53.                 int latentSize = 32;  //latent space维度
  54.                 int groupNum = 32;  //group norm分组数
  55.                
  56.                 DiffusionUNetCond2 unet = new DiffusionUNetCond2(LossType.MSE, UpdaterType.adamw, latendDim, latentSize, latentSize, downChannels, unetHeadNum, numLayer, timeSteps, tEmbDim, maxContextLen, textEmbedDim, groupNum);
  57.                 unet.CUDNN = true;
  58.                 unet.learnRate = 0.0001f;
  59.                
  60.                 MBSGDOptimizer optimizer = new MBSGDOptimizer(unet, 500, 0.00001f, batchSize, LearnRateUpdate.CONSTANT, false);
  61.                 optimizer.trainTinySD_Anime(dataLoader, vae, clip);
  62.                 //保存训练完成的权重文件
  63.                 String save_model_path = "/omega/models/sd_anime256.model";
  64.                 ModelUtils.saveModel(unet, save_model_path);
  65.         }
复制代码
推理代码如下:

  1.         public static void tiny_sd_predict_anime_32() throws Exception {
  2.                
  3.                 int imgSize = 256;
  4.                 int maxContextLen = 77;
  5.                 String vocabPath = "H:\\model\\bpe_tokenizer\\vocab.json";
  6.                 String mergesPath = "H:\\model\\bpe_tokenizer\\merges.txt";
  7.                 BPETokenizerEN tokenizer = new BPETokenizerEN(vocabPath, mergesPath, 49406, 49407);
  8.                
  9.                 int time = maxContextLen;
  10.                 int maxPositionEmbeddingsSize = 77;
  11.                 int vocabSize = 49408;
  12.                 int headNum = 8;
  13.                 int n_layers = 12;
  14.                 int textEmbedDim = 512;
  15.                
  16.                 ClipTextModel clip = new ClipTextModel(LossType.MSE, UpdaterType.adamw, headNum, time, vocabSize, textEmbedDim, maxPositionEmbeddingsSize, n_layers);
  17.                 clip.CUDNN = true;
  18.                 clip.time = time;
  19.                 clip.RUN_MODEL = RunModel.EVAL;
  20.                
  21.                 String clipWeight = "H:\\model\\clip-vit-base-patch32.json";
  22.                 ClipModelUtils.loadWeight(LagJsonReader.readJsonFileSmallWeight(clipWeight), clip, true);
  23.                
  24.                 int z_dims = 128;
  25.                 int latendDim = 4;
  26.                 int num_vq_embeddings = 512;
  27.                 int num_res_blocks = 2;
  28.                 int[] ch_mult = new int[] {1, 2, 2, 4};
  29.                 int ch = 128;
  30.                
  31.                 VQVAE2 vae = new VQVAE2(LossType.MSE, UpdaterType.adamw, z_dims, latendDim, num_vq_embeddings, imgSize, ch_mult, ch, num_res_blocks);
  32.                 vae.CUDNN = true;
  33.                 vae.learnRate = 0.001f;
  34.                 vae.RUN_MODEL = RunModel.EVAL;
  35.                
  36.                 String vaeModel = "H:\\model\\anime_vqvae2_256.model";
  37.                 ModelUtils.loadModel(vae, vaeModel);
  38.                
  39.                 int unetHeadNum = 8;
  40.                 int[] downChannels = new int[] {64, 128, 256, 512};
  41.                 int numLayer = 2;
  42.                 int timeSteps = 1000;
  43.                 int tEmbDim = 512;
  44.                 int latendSize = 32;
  45.                 int groupNum = 32;
  46.                 int batchSize = 1;
  47.                
  48.                 DiffusionUNetCond2 unet = new DiffusionUNetCond2(LossType.MSE, UpdaterType.adamw, latendDim, latendSize, latendSize, downChannels, unetHeadNum, numLayer, timeSteps, tEmbDim, maxContextLen, textEmbedDim, groupNum);
  49.                 unet.CUDNN = true;
  50.                 unet.learnRate = 0.0001f;
  51.                 unet.RUN_MODEL = RunModel.TEST;
  52.                
  53.                 String model_path = "H:\\model\\sd_anime256.model";
  54.                 ModelUtils.loadModel(unet, model_path);
  55.                
  56.                 Scanner scanner = new Scanner(System.in);
  57.                
  58. //                Tensor latent = new Tensor(batchSize, latendDim, latendSize, latendSize, true);
  59.                 Tensor t = new Tensor(batchSize, 1, 1, 1, true);
  60.                 Tensor label = new Tensor(batchSize * unet.maxContextLen, 1, 1, 1, true);
  61.                
  62.                 Tensor input = new Tensor(batchSize, 3, imgSize, imgSize, true);
  63.                 Tensor latent = vae.encode(input);
  64.                
  65.                 while (true) {
  66.                         System.out.println("请输入英文:");
  67.                         String input_txt = scanner.nextLine();
  68.                         if(input_txt.equals("exit")){
  69.                                 break;
  70.                         }
  71.                         input_txt = input_txt.toLowerCase();
  72.                        
  73.                         loadLabels(input_txt, label, tokenizer, unet.maxContextLen);
  74.                         Tensor condInput = clip.forward(label);
  75.                         String[] labels = new String[] {input_txt, input_txt};
  76.                         MBSGDOptimizer.testSD(input_txt, latent, t, condInput, unet, vae, labels, "H:\\vae_dataset\\anime_test256\");
  77.                 }
  78.                 scanner.close();
  79.         }
复制代码
以上代码所需的文件请移步到百度云盘下载 点击下载

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

麻花痒

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表