r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. """ if resume_from_checkpoint is not None: raise ValueError("`resume_from_checkpoint` will be supported in the future version.") total_train_batch_size = ( self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.finetuning_args.ppo_buffer_size * self.args.world_size )
) -> "torch.optim.Optimizer": optimizer = create_custom_optimzer(model, training_args, finetuning_args) if optimizer is None: decay_params, nodecay_params = [], [] decay_param_names = self.get_decay_parameter_names(model) for name, param in model.named_parameters(): if param.requires_grad: if name in decay_param_names: decay_params.append(param) else:
        return lr_scheduler

    @torch.no_grad()
    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        r"""
        Generates the model's responses given queries.
        """
        if self.model_args.upcast_layernorm:
            layernorm_params = dump_layernorm(self.model)

        if batch["input_ids"].size(0) == 1:  # handle llama2 ppo with gradient accumulation > 1
            start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
            for k, v in batch.items():
                batch[k] = v[:, start_index:]

        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
        generate_output: torch.Tensor = unwrapped_model.generate(
            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
        )
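        # `generate_output` contains the (left-padded) prompt tokens followed by the generated
        # continuation; the response part is sliced off right after the prompt length below.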
        if self.model_args.upcast_layernorm:
            restore_layernorm(self.model, layernorm_params)

        query = batch["input_ids"].detach().cpu()
        response = generate_output[:, batch["input_ids"].size(-1) :].detach().cpu()
        queries, responses = [], []
        for i in range(len(query)):
            query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()

            if len(response_index) == 0:
                response_length = 1  # allow empty response
            else:
                response_length = response_index[-1].item() + 1

            queries.append(query[i, query_start_index:])  # remove padding from left
            responses.append(response[i, :response_length])  # remove padding from right

        return queries, responses

    @torch.no_grad()
    def get_rewards(
        self,
        queries: List[torch.Tensor],
        responses: List[torch.Tensor],
        unwrapped_model: "AutoModelForCausalLMWithValueHead",
    ) -> List[torch.Tensor]:
        r"""
        Computes scores using the given reward model.

        Both inputs and outputs are put on CPU.
        """
        if self.finetuning_args.reward_model_type == "api":
            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
            return get_rewards_from_server(self.reward_model, messages)

        if self.finetuning_args.reward_model_type == "lora":
            replace_model(unwrapped_model, target="reward")
            reward_model = self.model
        else:
            reward_model = self.reward_model

        batch = self.prepare_model_inputs(queries, responses)

        with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)

        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":  # assume same architecture
            values = torch.transpose(values, 0, 1)

        # take the value-head output at the last non-padding token as the sequence-level reward
        rewards = []
        for i in range(values.size(0)):
            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type

        if self.finetuning_args.reward_model_type == "lora":
            replace_model(unwrapped_model, target="default")

        return rewards