torch
https://pytorch.org/docs/stable/torch.html
- torch.topk (Python function, in torch.topk)
- torch.Tensor.topk (Python method, in torch.Tensor.topk)
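Both entries refer to the same operation: `Tensor.topk` is the method form of `torch.topk`, with the tensor itself passed as `input`. Here is a small check on an arbitrary example tensor (the values are illustrative only):

```python
import torch

# Arbitrary example values, chosen only for illustration.
x = torch.tensor([3.0, 1.0, 4.0, 1.5, 9.0])

# Function form and method form take the same remaining arguments.
a = torch.topk(x, 2)
b = x.topk(2)

print(a.values.tolist(), a.indices.tolist())  # [9.0, 4.0] [4, 2]
print(torch.equal(a.values, b.values) and torch.equal(a.indices, b.indices))  # True
```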
1. torch.topk
https://pytorch.org/docs/stable/generated/torch.topk.html
```python
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
```
Returns the k largest elements of the given input tensor along a given dimension.
If dim is not given, the last dimension of the input is chosen.
If largest is False, the k smallest elements are returned instead.
A namedtuple of (values, indices) is returned, holding the values and indices of the k largest (or smallest) elements of each row of the input tensor along the given dimension dim.
If the boolean option sorted is True, the returned k elements are themselves sorted (in descending order when largest is True, ascending when it is False).
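To make the dim and largest rules concrete, here is a sketch on a small 2-D tensor with arbitrary values. Omitting dim selects the top-k within each row (the last dimension), dim=0 works within each column, and largest=False returns the smallest elements in ascending order. Results are printed with tolist() so that the comments show the exact output.

```python
import torch

# A 2-D example tensor (values chosen arbitrarily for illustration).
x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# dim not given -> last dimension: top-2 within each row.
row_top = torch.topk(x, 2)
print(row_top.values.tolist())   # [[5.0, 3.0], [6.0, 4.0]]
print(row_top.indices.tolist())  # [[1, 2], [2, 0]]

# dim=0: top-1 within each column.
col_top = torch.topk(x, 1, dim=0)
print(col_top.values.tolist())   # [[4.0, 5.0, 6.0]]
print(col_top.indices.tolist())  # [[1, 0, 1]]

# largest=False: the 2 smallest elements per row, in ascending order.
low = torch.topk(x, 2, largest=False)
print(low.values.tolist())       # [[1.0, 3.0], [2.0, 4.0]]
print(low.indices.tolist())      # [[0, 2], [1, 0]]
```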
Parameters:
input (Tensor) - the input tensor.
k (int) - the k in "top-k".
dim (int, optional) - the dimension to sort along.
largest (bool, optional) - controls whether to return the largest or smallest elements.
sorted (bool, optional) - controls whether to return the elements in sorted order.
Keyword Arguments:
out (tuple, optional) - the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers.
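The sorted and out parameters are easiest to see on a small tensor. The following sketch uses an arbitrary 1-D example. With sorted=False, the same k elements are returned but their order is not guaranteed, so sorting them recovers the sorted=True result. out accepts preallocated buffers: a float tensor for the values and an int64 (LongTensor) tensor for the indices. The buffer names are illustrative.

```python
import torch

# Arbitrary example values (exactly representable in float32).
x = torch.tensor([3.0, 9.0, 1.0, 7.0, 5.0])

# sorted=False: the same 3 elements, but in no guaranteed order.
unsorted = torch.topk(x, 3, sorted=False)
print(sorted(unsorted.values.tolist(), reverse=True))  # [9.0, 7.0, 5.0]
print(sorted(unsorted.indices.tolist()))               # [1, 3, 4]

# out=(values, indices): write the result into preallocated buffers.
values_buf = torch.empty(3)                     # float32, same dtype as x
indices_buf = torch.empty(3, dtype=torch.long)  # indices are int64
torch.topk(x, 3, out=(values_buf, indices_buf))
print(values_buf.tolist(), indices_buf.tolist())  # [9.0, 7.0, 5.0] [1, 3, 4]
```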
2. Example
```
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>>
>>> input = torch.arange(1., 9.)
>>> input
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
>>>
>>> torch.topk(input, 3)
torch.return_types.topk(
values=tensor([8., 7., 6.]),
indices=tensor([7, 6, 5]))
>>>
>>> values, indices = torch.topk(input, 4)
>>> values
tensor([8., 7., 6., 5.])
>>> indices
tensor([7, 6, 5, 4])
>>> exit()
(base) yongqiang@yongqiang:~$
```
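As a cross-check on what values and indices mean, the next sketch uses the same tensor as the session. Because its elements are distinct (no ties), torch.topk(x, k) matches the first k entries of a descending torch.sort, and gathering x at indices gives back values. The variable is named x here rather than input to avoid shadowing Python's built-in.

```python
import torch

x = torch.arange(1., 9.)
values, indices = torch.topk(x, 4)

# Same as a full descending sort truncated to k elements
# (holds here because all elements are distinct, so there are no ties).
sorted_values, sorted_indices = torch.sort(x, descending=True)
print(torch.equal(values, sorted_values[:4]))    # True
print(torch.equal(indices, sorted_indices[:4]))  # True

# indices point back into x along dim 0: gathering them reproduces values.
print(torch.equal(torch.gather(x, 0, indices), values))  # True
```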
3. Example
https://github.com/karpathy/llama2.c/blob/master/model.py
```python
import torch

# 10 toy "logits": 1., 2., ..., 10.
logits = torch.arange(1., 11.)
print("logits.shape:", logits.shape)
print("logits:\n", logits)

# Add a batch dimension: shape (1, 10), i.e. (batch, vocab_size).
logits = logits.view(1, 10)
print("\nlogits.shape:", logits.shape)
print("logits:\n", logits)

# k=1 along the last dimension: the single largest logit per row.
values, indices = torch.topk(logits, k=1, dim=-1)
print("\nvalues:\n", values)
print("indices:\n", indices)

# Cap k at the size of the last dimension so topk never gets k > logits.size(-1).
top_k = 5
print("\nlogits.size(-1):", logits.size(-1))
values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
print("values:\n", values)
print("indices:\n", indices)
```
```
/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py
logits.shape: torch.Size([10])
logits:
 tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])

logits.shape: torch.Size([1, 10])
logits:
 tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])

values:
 tensor([[10.]])
indices:
 tensor([[9]])

logits.size(-1): 10
values:
 tensor([[10.,  9.,  8.,  7.,  6.]])
indices:
 tensor([[9, 8, 7, 6, 5]])

Process finished with exit code 0
```
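In the linked llama2.c model.py, these two calls appear in the sampling step. k=1 picks the single most likely token, and min(top_k, logits.size(-1)) keeps k from exceeding the vocabulary size, since torch.topk raises an error when k is larger than the chosen dimension. The following is a minimal sketch of top-k filtering followed by sampling, loosely following that pattern. The helper name sample_top_k and its temperature argument are hypothetical, and this version masks with the out-of-place masked_fill rather than an in-place assignment. Check the linked source for the repository's exact code.

```python
import torch
import torch.nn.functional as F


def sample_top_k(logits: torch.Tensor, top_k: int, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: sample one token id per row, restricted to the top_k logits.

    logits has shape (batch, vocab_size); the result has shape (batch, 1).
    """
    if temperature == 0.0:
        # Greedy choice: k=1 along the last dim returns the index of the largest logit.
        _, idx_next = torch.topk(logits, k=1, dim=-1)
        return idx_next
    logits = logits / temperature
    # Guard: never ask topk for more elements than the last dimension holds.
    k = min(top_k, logits.size(-1))
    v, _ = torch.topk(logits, k)
    # v[:, [-1]] is the k-th largest value per row, kept as shape (batch, 1)
    # so it broadcasts; logits strictly below it get probability 0 after softmax.
    logits = logits.masked_fill(logits < v[:, [-1]], float("-inf"))
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


logits = torch.arange(1., 11.).view(1, 10)
torch.manual_seed(0)
idx = sample_top_k(logits, top_k=5)
print(idx.shape)             # torch.Size([1, 1])
print(5 <= idx.item() <= 9)  # True: only the top-5 indices 5..9 can be drawn
print(sample_top_k(logits, top_k=5, temperature=0.0))  # tensor([[9]])
```

With temperature 0.0 the helper reduces to the k=1 call from the example above.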
References
[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/