Dataset for Stable Diffusion


1. Dataset for Stable Diffusion

Note sources:
1. Flickr8k dataset processing
2. Processing the Flickr8k dataset
3. GitHub: pytorch-stable-diffusion
4. Flickr 8k Dataset
5. dataset_flickr8k.json
6. About Train, Validation and Test Sets in Machine Learning (Tarang Shah, Towards Data Science)
7. What are hyperparameters?
1.1 Dataset

We use the Flickr8k dataset, which consists of two parts: the first, Flicker8k_Dataset, contains only the images; the second, Flickr8k.token.txt, has two columns, image_id and caption. Each image_id corresponds to 5 captions (sentences).
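As a rough illustration of the two-column caption file, the sketch below groups captions by image. It assumes each line of Flickr8k.token.txt has the form `<filename>#<n><TAB><caption>`, where `#<n>` numbers the caption of that image; the function name `parse_flickr8k_tokens` and the sample lines are hypothetical. The rest of this post reads captions from dataset_flickr8k.json instead, so this parser is not part of the pipeline.

```python
from collections import defaultdict


def parse_flickr8k_tokens(text: str) -> dict:
    """Group captions by image file name.

    Assumption: each non-empty line looks like '<filename>#<n>\t<caption>'.
    """
    captions = defaultdict(list)
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        image_id, caption = line.split("\t", 1)
        filename = image_id.split("#", 1)[0]  # drop the '#n' caption number
        captions[filename].append(caption.strip())
    return dict(captions)


# Hypothetical sample lines, not real Flickr8k data
sample = (
    "img_a.jpg#0\tA dog runs on the grass .\n"
    "img_a.jpg#1\tA brown dog is running .\n"
    "img_b.jpg#0\tTwo children play in the sand .\n"
)
parsed = parse_flickr8k_tokens(sample)
print(len(parsed["img_a.jpg"]))  # 2
```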
1.2 Dataset description file

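Dataset text description file: dataset_flickr8k.json
The file format is as follows:
```
{"images": [
    {"sentids": [0, 1, 2, 3, 4],
     "imgid": 0,
     "sentences": [
         {"tokens": [ ], "raw": "…", "imgid": 0, "sentid": 0},
         …
     ],
     "split": "train",
     "filename": "…jpg"},
    {"sentids": …},
    …
 ],
 "dataset": "flickr8k"}
```
Field descriptions:
- "sentids": [0, 1, 2, 3, 4]: IDs of this image's captions (each image has 5 captions, so for image 0 the sentids run from 0 to 4)
- "imgid": 0: image ID (0 to 7999, 8000 images in total)
- "sentences": [ ]: the 5 captions of one image
- "tokens": [ ]: a caption split into individual words
- "raw": " ": the full caption text (the tokens come from splitting it into words)
- "imgid": 0: ID of the image this caption belongs to
- "sentid": 0: ID of this particular caption of image 0
- "split": " ": assigns the image and its captions to the training, validation, or test set
- "filename": "…jpg": file name of the image

[Figure: excerpt of dataset_flickr8k.json]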
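Before processing, it can help to sanity-check this structure: count the images per split and confirm that each caption's imgid and sentid agree with its parent image. This is a minimal sketch; `summarize_caption_json` and the tiny example dict are hypothetical and only mimic the field names listed above. For the real file, pass in the result of `json.load` on dataset_flickr8k.json.

```python
from collections import Counter


def summarize_caption_json(data: dict) -> Counter:
    """Count images per split and check caption/image ID consistency."""
    per_split = Counter()
    for img in data["images"]:
        per_split[img["split"]] += 1
        for sent in img["sentences"]:
            # Each caption must point back at the image that contains it
            assert sent["imgid"] == img["imgid"], "caption points at a different image"
            assert sent["sentid"] in img["sentids"], "sentid missing from sentids"
    return per_split


# Hand-made example in the same shape (hypothetical values, not real Flickr8k data)
example = {
    "dataset": "flickr8k",
    "images": [
        {"imgid": 0, "sentids": [0, 1], "split": "train", "filename": "a.jpg",
         "sentences": [
             {"tokens": ["a", "dog", "runs"], "raw": "A dog runs.", "imgid": 0, "sentid": 0},
             {"tokens": ["a", "dog"], "raw": "A dog.", "imgid": 0, "sentid": 1},
         ]},
        {"imgid": 1, "sentids": [2], "split": "test", "filename": "b.jpg",
         "sentences": [{"tokens": ["a", "cat"], "raw": "A cat.", "imgid": 1, "sentid": 2}]},
    ],
}
print(dict(summarize_caption_json(example)))  # {'train': 1, 'test': 1}
```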

1.3 Process Datasets

The code below is taken from Flickr8k dataset processing (for study purposes only).
```python
import json
import os
import random
from collections import Counter, defaultdict

from PIL import Image


def create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):
    """
    Parameters:
        dataset: Name of the dataset
        captions_per_image: Number of captions per image
        min_word_count: Only keep words that appear more than this many times in the dataset (excluding the test set)
        max_len: Maximum number of words in a caption. Longer captions are discarded.
    Output:
        A vocabulary file: vocab.json
        Three dataset files: train_data.json, val_data.json, test_data.json
    """
    # Paths for reading data and saving processed data
    # Path to the dataset JSON file
    flickr_json_path = ".../sd/data/dataset_flickr8k.json"
    # Folder containing images
    image_folder = ".../sd/data/Flicker8k_Dataset"
    # Folder to save processed results: "%s" is replaced by the dataset name,
    # e.g. ".../sd/data/flickr8k" when dataset is "flickr8k"
    output_folder = ".../sd/data/%s" % dataset
    # Ensure the output directory exists
    os.makedirs(output_folder, exist_ok=True)
    print(f"Output folder: {output_folder}")
    # Read the dataset JSON file
    with open(file=flickr_json_path, mode="r") as j:
        data = json.load(fp=j)
    # Initialize containers for image paths, captions, and vocabulary
    image_paths = defaultdict(list)  # split -> list of image paths
    image_captions = defaultdict(list)  # split -> list of caption lists
    vocab = Counter()  # word -> number of occurrences
    for img in data["images"]:  # Iterate over each image in the dataset
        split = img["split"]  # "train", "val" or "test"
        captions = []
        for c in img["sentences"]:  # Iterate over each caption of the image
            # Update word counts, excluding test-set captions
            if split != "test":
                # c["tokens"] is a list of words; each occurrence increments that word's count
                vocab.update(c["tokens"])
            # Keep only captions within the maximum length
            if len(c["tokens"]) <= max_len:
                captions.append(c["tokens"])
        if len(captions) == 0:  # Skip images with no valid captions
            continue
        # Full image path: image_folder + file name
        path = os.path.join(image_folder, img["filename"])
        # Save the image path and its captions under the image's split
        image_paths[split].append(path)
        image_captions[split].append(captions)
    # After the steps above:
    # - vocab (a Counter): keys are words, values are their counts
    # - image_paths (a dict): keys "train", "val", "test"; values are lists of absolute image paths
    # - image_captions (a dict): keys "train", "val", "test"; values are lists of captions
    ...  # continued in the next listing
```
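Using dataset_flickr8k.json, we convert the dataset into three dictionaries:

| dict | key | value |
|---|---|---|
| vocab | word | frequency of the word across all captions (train and val splits only) |
| image_paths | "train", "val", "test" | lists of absolute image paths |
| image_captions | "train", "val", "test" | lists of captions |

We print their contents while debugging: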
```python
# Inside create_dataset, after the loop above
print(vocab)
print(image_paths["train"][1])
print(image_captions["train"][1])
```

```python
def create_dataset(dataset='flickr8k', captions_per_image=5, min_word_count=5, max_len=30):
    # (docstring and the first part of the function as in the previous listing)
    ...
    # Create the vocabulary, adding placeholders for special tokens:
    # <pad> (padding), <unk> (unknown word), <start>/<end> (sentence boundaries)
    words = [w for w in vocab.keys() if vocab[w] > min_word_count]  # Keep words seen more than min_word_count times
    vocab = {k: v + 1 for v, k in enumerate(words)}  # Index 0 is reserved for <pad>, so words start at 1
    # Add special tokens to the vocabulary
    vocab['<pad>'] = 0
    vocab['<unk>'] = len(vocab)
    vocab['<start>'] = len(vocab)
    vocab['<end>'] = len(vocab)
    # Save the vocabulary to a file
    with open(os.path.join(output_folder, 'vocab.json'), "w") as fw:
        json.dump(vocab, fw)
    # Process each dataset split: "train", "val" and "test"
    for split in image_paths:
        imgpaths = image_paths[split]  # list of image paths for this split
        imcaps = image_captions[split]  # list of caption lists for this split
        # Captions with every word converted to its vocabulary index
        enc_captions = []
        for i, path in enumerate(imgpaths):
            # Make sure the image can be opened (raises if it is missing or unreadable)
            with Image.open(path):
                pass
            # Ensure each image has exactly captions_per_image captions
            if len(imcaps[i]) < captions_per_image:
                filled_num = captions_per_image - len(imcaps[i])
                # Too few: repeat randomly chosen captions
                captions = imcaps[i] + [random.choice(imcaps[i]) for _ in range(filled_num)]
            else:
                # Enough or too many: randomly sample captions_per_image of them
                captions = random.sample(imcaps[i], k=captions_per_image)
            assert len(captions) == captions_per_image
            for c in captions:
                # Encode each caption by converting words to their indices in the vocabulary
                enc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab['<end>']]
                enc_captions.append(enc_c)
        assert len(imgpaths) * captions_per_image == len(enc_captions)
        split_data = {"IMAGES": imgpaths,
                      "CAPTIONS": enc_captions}
        # Save the processed dataset for the current split (train, val, test)
        with open(os.path.join(output_folder, split + "_data.json"), 'w') as fw:
            json.dump(split_data, fw)


create_dataset()
```
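The caption-balancing step inside the loop above can be tried in isolation. `balance_captions` below is a hypothetical standalone restatement of that logic; unlike the original, which calls the module-level `random` functions, it accepts an optional `random.Random` instance so the result can be made reproducible with a seed (an assumption, not something the original code does).

```python
import random


def balance_captions(captions, k=5, rng=None):
    """Return exactly k captions: repeat random ones if too few, sample k if enough."""
    if rng is None:
        rng = random.Random()
    if len(captions) < k:
        # Too few: keep all originals and pad with randomly repeated captions
        return captions + [rng.choice(captions) for _ in range(k - len(captions))]
    # Enough or too many: draw k distinct captions
    return rng.sample(captions, k=k)


rng = random.Random(0)  # fixed seed so the result is reproducible
few = balance_captions([["a", "dog"], ["a", "puppy"]], k=5, rng=rng)
many = balance_captions([[str(i)] for i in range(7)], k=5, rng=rng)
print(len(few), len(many))  # 5 5
```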
Running create_dataset produces four files in the output folder: vocab.json, train_data.json, val_data.json and test_data.json.

The contents of the four files are shown in the figures below.
[Figures: the IMAGES and CAPTIONS keys of train_data.json, test_data.json and val_data.json; the beginning and end of vocab.json]
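Key code for generating vocab.json
First, collect the words that appear more than 5 times (min_word_count) across all train and val captions, then assign each of these words an index in turn.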
```python
# Create the vocabulary, adding placeholders for special tokens:
# <pad> (padding), <unk> (unknown word), <start>/<end> (sentence boundaries)
# Keep only words whose frequency is higher than min_word_count
# (the counts come from the train and val splits only; the test set is excluded)
words = [w for w in vocab.keys() if vocab[w] > min_word_count]
# Assign an index to each word starting from 1
# (enumerate starts at 0, and index 0 is reserved for <pad>)
vocab = {k: v + 1 for v, k in enumerate(words)}
```
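A toy run makes the indexing concrete. The sketch below restates the vocabulary code as a hypothetical `build_vocab` function and applies it to made-up counts. A word seen exactly 5 times ("grass") is dropped because the comparison is strictly greater than `min_word_count`, and the special tokens take the indices right after the real words.

```python
from collections import Counter


def build_vocab(counts: Counter, min_word_count: int = 5) -> dict:
    """Mirror the vocabulary logic above on an arbitrary word counter."""
    words = [w for w in counts if counts[w] > min_word_count]
    vocab = {w: i + 1 for i, w in enumerate(words)}  # real words start at 1
    vocab["<pad>"] = 0                # padding always uses index 0
    vocab["<unk>"] = len(vocab)       # next free index
    vocab["<start>"] = len(vocab)
    vocab["<end>"] = len(vocab)
    return vocab


toy = Counter({"dog": 9, "runs": 6, "grass": 5, "zebra": 1})  # made-up counts
vocab = build_vocab(toy)
print(vocab)
# {'dog': 1, 'runs': 2, '<pad>': 0, '<unk>': 3, '<start>': 4, '<end>': 5}
```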
This produces the final vocab.json.

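Key steps for generating [split]_data.json
Read dataset_flickr8k.json and build two dictionaries: the first holds the absolute path of each image, the second holds the captions describing each image. Each token is then replaced by its index in vocab and, according to each image's split field in dataset_flickr8k.json, the image paths and their encoded captions are saved to separate files (train_data.json, test_data.json, val_data.json).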
[Figure: dataset_flickr8k.json]

[Figure: train_data.json]

The encoded CAPTIONS are obtained by looking up the index of each token in vocab:
```python
for c in captions:
    # Encode each caption by converting words to their indices in the vocabulary
    enc_c = [vocab['<start>']] + [vocab.get(word, vocab['<unk>']) for word in c] + [vocab['<end>']]
    enc_captions.append(enc_c)
```
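The encoding step, together with its inverse, can be sketched as follows. `encode` mirrors the line above; `decode` is a hypothetical helper (not part of the original code) that maps indices back to words and drops `<pad>`, `<start>` and `<end>`, which is handy when printing captions. The toy vocabulary is made up.

```python
def encode(tokens, vocab):
    """Wrap a tokenized caption in <start>/<end>, mapping unknown words to <unk>."""
    return [vocab["<start>"]] + [vocab.get(w, vocab["<unk>"]) for w in tokens] + [vocab["<end>"]]


def decode(indices, vocab):
    """Invert encode(), skipping padding and the sentence-boundary tokens."""
    idx2word = {i: w for w, i in vocab.items()}
    special = {vocab["<pad>"], vocab["<start>"], vocab["<end>"]}
    return " ".join(idx2word[i] for i in indices if i not in special)


vocab = {"dog": 1, "runs": 2, "<pad>": 0, "<unk>": 3, "<start>": 4, "<end>": 5}
enc = encode(["a", "dog", "runs"], vocab)
print(enc)                 # [4, 3, 1, 2, 5]
print(decode(enc, vocab))  # <unk> dog runs
```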

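Next, we use the generated test-set file test_data.json and vocab.json to display one image and its captions.
The code below is taken from Flickr8k dataset processing (for study purposes only).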
```python
'''
Test:
1. Iterate over the 5 captions of the image at index 250.
2. Retrieve the word indices for each caption.
3. Convert the word indices to words using vocab_idx2word.
4. Join the words to form complete sentences.
5. Print each caption.
Source: https://blog.csdn.net/weixin_48981284/article/details/134676813 (for study purposes only)
'''
import json

from PIL import Image
from matplotlib import pyplot as plt

# Load the vocabulary (word -> index) from the JSON file
with open('.../sd/data/flickr8k/vocab.json', 'r') as f:
    vocab = json.load(f)
# Build the reverse mapping (index -> word)
vocab_idx2word = {idx: word for word, idx in vocab.items()}
# Load the test data from the JSON file
with open('.../sd/data/flickr8k/test_data.json', 'r') as f:
    data = json.load(f)
# Open and display the image at index 250 of the 'IMAGES' list
content_img = Image.open(data['IMAGES'][250])
plt.figure(figsize=(6, 6))
plt.imshow(content_img)
plt.title('Image')
plt.axis('off')
plt.show()
# Number of keys in the dictionary (should be 2: 'IMAGES' and 'CAPTIONS')
print(len(data))
print(len(data['IMAGES']))  # number of images
print(len(data['CAPTIONS']))  # number of captions
# Print the 5 captions of the image at index 250
for i in range(5):
    # Captions are stored in image order, 5 per image
    word_indices = data['CAPTIONS'][250 * 5 + i]
    # Convert indices to words and join them with spaces
    print(' '.join(vocab_idx2word[idx] for idx in word_indices))
```

data has two keys: IMAGES and CAPTIONS.
The test set has 1000 images with 5 captions each, 5000 captions in total.
[Figure: the 5 captions printed for the image at index 250]
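Both the lookup `data['CAPTIONS'][250 * 5 + i]` above and `i // self.cpi` in the Dataset class of the next section rely on the same layout: captions are stored in image order, 5 per image. A minimal sketch of that mapping (the helper names are hypothetical):

```python
CAPTIONS_PER_IMAGE = 5


def caption_indices(image_index: int, cpi: int = CAPTIONS_PER_IMAGE) -> range:
    """Positions in CAPTIONS that belong to the image at image_index."""
    return range(image_index * cpi, (image_index + 1) * cpi)


def image_index(caption_index: int, cpi: int = CAPTIONS_PER_IMAGE) -> int:
    """Inverse mapping: which image a caption position belongs to (i // cpi)."""
    return caption_index // cpi


print(list(caption_indices(250)))  # [1250, 1251, 1252, 1253, 1254]
print(image_index(1253))           # 250
print(1000 * CAPTIONS_PER_IMAGE)   # 5000 captions for 1000 test images
```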

1.4 Dataloader

The code below is taken from Flickr8k dataset processing (for study purposes only).
```python
import json
import os

import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImageTextDataset(Dataset):
    """
    PyTorch Dataset class to generate data batches using torch DataLoader
    """

    def __init__(self, dataset_path, vocab_path, split, captions_per_image=5, max_len=30, transform=None):
        """
        Parameters:
            dataset_path: Path to the JSON file containing the dataset
            vocab_path: Path to the JSON file containing the vocabulary
            split: The dataset split, which can be "train", "val", or "test"
            captions_per_image: Number of captions per image
            max_len: Maximum number of words per caption
            transform: Image transformation methods
        """
        self.split = split
        # Validate that the split is one of the allowed values
        assert self.split in {"train", "val", "test"}
        self.cpi = captions_per_image  # captions per image
        self.max_len = max_len  # maximum caption length in words
        # Load the dataset
        with open(dataset_path, "r") as f:
            self.data = json.load(f)
        # Load the vocabulary
        with open(vocab_path, "r") as f:
            self.vocab = json.load(f)
        # Store the image transformation methods
        self.transform = transform
        # One sample per caption, so the dataset size is the number of captions
        self.dataset_size = len(self.data["CAPTIONS"])

    def __getitem__(self, i):
        """
        Retrieve the i-th sample from the dataset
        """
        # Each image has self.cpi captions, so caption i belongs to image i // self.cpi
        img = Image.open(self.data["IMAGES"][i // self.cpi]).convert("RGB")
        # Apply image transformation if provided
        if self.transform is not None:
            img = self.transform(img)
        # Length of the encoded caption (including <start> and <end>)
        caplen = len(self.data["CAPTIONS"][i])
        # Pad to max_len + 2 tokens (max_len words plus <start> and <end>)
        pad_caps = [self.vocab["<pad>"]] * (self.max_len + 2 - caplen)
        # Convert the padded caption to a tensor
        caption = torch.LongTensor(self.data["CAPTIONS"][i] + pad_caps)
        return img, caption, caplen  # image, caption, and caption length

    def __len__(self):
        return self.dataset_size  # number of samples in the dataset


def make_train_val(data_dir, vocab_path, batch_size, workers=4):
    """
    Create DataLoader objects for training, validation, and testing sets.
    Parameters:
        data_dir: Directory where the dataset JSON files are located
        vocab_path: Path to the vocabulary JSON file
        batch_size: Number of samples per batch
        workers: Number of subprocesses to use for data loading (default is 4)
    Returns:
        train_loader: DataLoader for the training set
        val_loader: DataLoader for the validation set
        test_loader: DataLoader for the test set
    """
    # Transformation for the training set. Resizing to a fixed 256x256 (aspect ratio is
    # not preserved) guarantees that all images in a batch have the same shape.
    train_tx = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),  # Convert the image to a PyTorch tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean and std
    ])
    # Transformation for the validation and test sets
    val_tx = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Create dataset objects for the training, validation, and test sets
    train_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "train_data.json"), vocab_path=vocab_path,
                                 split="train", transform=train_tx)
    valid_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "val_data.json"), vocab_path=vocab_path,
                                 split="val", transform=val_tx)
    test_set = ImageTextDataset(dataset_path=os.path.join(data_dir, "test_data.json"), vocab_path=vocab_path,
                                split="test", transform=val_tx)
    # DataLoader for the training set, with shuffling
    train_loader = DataLoader(
        dataset=train_set, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True
    )
    # DataLoader for the validation set, without shuffling
    val_loader = DataLoader(
        dataset=valid_set, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True, drop_last=False
    )
    # DataLoader for the test set, without shuffling
    test_loader = DataLoader(
        dataset=test_set, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True, drop_last=False
    )
    return train_loader, val_loader, test_loader
```
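Why `__getitem__` pads to `max_len + 2`: each encoded caption holds at most max_len words plus `<start>` and `<end>`, and DataLoader's default collation stacks the caption tensors of a batch, which requires them to have equal length. The pure-Python sketch below mirrors that padding; `pad_caption` is hypothetical and, unlike the original (which simply adds no padding if a caption is too long), it raises an error in that case. The `caplen` value returned by the original lets downstream code tell real tokens from padding.

```python
def pad_caption(encoded, max_len=30, pad_idx=0):
    """Pad an encoded caption (already wrapped in <start>/<end>) to max_len + 2 tokens."""
    target = max_len + 2  # max_len words plus the <start> and <end> tokens
    if len(encoded) > target:
        # Assumption of this sketch: fail loudly instead of returning a longer row
        raise ValueError("caption longer than max_len + 2; create_dataset should have filtered it")
    return encoded + [pad_idx] * (target - len(encoded))


row = pad_caption([4, 3, 1, 2, 5], max_len=6)
print(row)  # [4, 3, 1, 2, 5, 0, 0, 0]
```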
Once train_loader is created, we can start training SD!
1.5 Training, Validation, and Test Sets

Understanding the roles of the training, test, and validation sets

Training set
Used in the model training stage
   Training Dataset: The sample of data used to fit the model.
  The actual dataset that we use to train the model (weights and biases in the case of a Neural Network). The model sees and learns from this data.
Validation set
Used in the hyperparameter-tuning stage
   Validation Dataset: The sample of data used to provide an unbiased evaluation of a model fit on the training dataset while tuning model hyperparameters. The evaluation becomes more biased as skill on the validation dataset is incorporated into the model configuration.
  The validation set is used to evaluate a given model, but this is for frequent evaluation. We, as machine learning engineers, use this data to fine-tune the model hyperparameters. Hence the model occasionally sees this data, but never does it “Learn” from this. We use the validation set results, and update higher level hyperparameters. So the validation set affects a model, but only indirectly. The validation set is also known as the Dev set or the Development set. This makes sense since this dataset helps during the “development” stage of the model.
Test set
Used to measure model performance after training and tuning are complete
   Test Dataset: The sample of data used to provide an unbiased evaluation of a final model fit on the training dataset.
  The Test dataset provides the gold standard used to evaluate the model. It is only used once a model is completely trained (using the train and validation sets). The test set is generally what is used to evaluate competing models (for example, on many Kaggle competitions the validation set is released initially along with the training set, and the actual test set is only released when the competition is about to close; it is the result of the model on the test set that decides the winner). Many times the validation set is used as the test set, but it is not good practice. The test set is generally well curated. It contains carefully sampled data that spans the various classes that the model would face when used in the real world.
Dataset split ratios
Now that you know what these datasets do, you might be looking for recommendations on how to split your dataset into Train, Validation and Test sets.
This mainly depends on 2 things. First, the total number of samples in your data and second, on the actual model you are training.
Some models need substantial data to train upon, so in this case you would optimize for the larger training sets. Models with very few hyperparameters will be easy to validate and tune, so you can probably reduce the size of your validation set, but if your model has many hyperparameters, you would want to have a large validation set as well (although you should also consider cross validation). Also, if you happen to have a model with no hyperparameters or ones that cannot be easily tuned, you probably don’t need a validation set either!
All in all, like many other things in machine learning, the train-test-validation split ratio is also quite specific to your use case, and it gets easier to make a judgement as you train and build more and more models.
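Whatever ratio you choose, the split should be reproducible. Below is a minimal sketch of a seeded random split over image indices; the function name and the 75/12.5/12.5 ratios are assumptions, picked so that 8000 images give 6000/1000/1000, consistent with the 1000 test images reported earlier. Splitting by image rather than by caption keeps all 5 captions of an image in the same set, which is what the per-image split field in dataset_flickr8k.json does.

```python
import random


def split_indices(n, train=0.75, val=0.125, seed=0):
    """Shuffle indices 0..n-1 with a fixed seed and cut them into train/val/test lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # same seed -> same split every run
    n_train = int(n * train)
    n_val = int(n * val)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


tr, va, te = split_indices(8000)
print(len(tr), len(va), len(te))  # 6000 1000 1000
```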
   What are hyperparameters?
Hyperparameters are external configuration variables that data scientists use to manage machine learning model training. Sometimes called model hyperparameters, the hyperparameters are manually set before training a model. They’re different from parameters, which are internal parameters automatically derived during the learning process and not set by data scientists.
Examples of hyperparameters include the number of nodes and layers in a neural network and the number of branches in a decision tree. Hyperparameters determine key features such as model architecture, learning rate, and model complexity.
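In this post, min_word_count, max_len, captions_per_image and batch_size are hyperparameters of the data pipeline: they are fixed by hand before training. To make the distinction concrete, here is a toy example (not from the original post) that fits a line by gradient descent on mean squared error: learning_rate and epochs are hyperparameters, while w and b are parameters learned from the data.

```python
def fit_line(xs, ys, learning_rate=0.05, epochs=2000):
    """Fit y = w*x + b by gradient descent on mean squared error.

    learning_rate and epochs are hyperparameters (set before training);
    w and b are parameters (learned from the data).
    """
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = 2 / n * sum(e * x for e, x in zip(errors, xs))  # d(MSE)/dw
        grad_b = 2 / n * sum(errors)                             # d(MSE)/db
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


w, b = fit_line([0, 1, 2, 3], [1, 3, 5, 7])  # points on y = 2x + 1
print(round(w, 3), round(b, 3))  # 2.0 1.0
```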

Posted by 莫张周刘王 on qidao123.com
