PyTorch torch.topk

打印 上一主题 下一主题

主题 861|帖子 861|积分 2585

torch
https://pytorch.org/docs/stable/torch.html


  • torch.topk (Python function, in torch.topk)
  • torch.Tensor.topk (Python method, in torch.Tensor.topk)
1. torch.topk

https://pytorch.org/docs/stable/generated/torch.topk.html
  1. torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
复制代码
Returns the k largest elements of the given input tensor along a given dimension.
返回给定 input 张量沿给定维度的前 k 个最大元素。
If dim is not given, the last dimension of the input is chosen.
如果未提供 dim,则选择 input 的末了一个维度。
If largest is False then the k smallest elements are returned.
如果 largest 为 False,则返回前 k 个最小元素。
A namedtuple of (values, indices) is returned with the values and indices of the largest k elements of each row of the input tensor in the given dimension dim.
返回一个元组 (values, indices),其中包罗 input 张量在给定维度 dim 上每行的前 k 个最大元素的 values 和 indices。
The boolean option sorted if True, will make sure that the returned k elements are themselves sorted.
如果布尔选项 sorted 为 True,则确保返回的 k 个元素本身已排序。


  • Parameters
input (Tensor) - the input tensor.
k (int) - the k in “top-k”
dim (int, optional) - the dimension to sort along
要排序的维度
largest (bool, optional) - controls whether to return largest or smallest elements
控制是否返回最大或最小元素
sorted (bool, optional) - controls whether to return the elements in sorted order
控制是否按排序次序返回元素


  • Keyword Arguments
out (tuple, optional) - the output tuple of (Tensor, LongTensor) that can be optionally given to be used as output buffers
2. Example

  1. (base) yongqiang@yongqiang:~$ python
  2. Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
  3. Type "help", "copyright", "credits" or "license" for more information.
  4. >>> import torch
  5. >>>
  6. >>> input = torch.arange(1., 9.)
  7. >>> input
  8. tensor([1., 2., 3., 4., 5., 6., 7., 8.])
  9. >>>
  10. >>> torch.topk(input, 3)
  11. torch.return_types.topk(
  12. values=tensor([8., 7., 6.]),
  13. indices=tensor([7, 6, 5]))
  14. >>>
  15. >>> values, indices = torch.topk(input, 4)
  16. >>> values
  17. tensor([8., 7., 6., 5.])
  18. >>> indices
  19. tensor([7, 6, 5, 4])
  20. >>> exit()
  21. (base) yongqiang@yongqiang:~$
复制代码
3. Example

https://github.com/karpathy/llama2.c/blob/master/model.py
  1. import torch
  2. logits = torch.arange(1., 11.)
  3. print("logits.shape:", logits.shape)
  4. print("logits:\n", logits)
  5. logits = logits.view(1, 10)
  6. print("\nlogits.shape:", logits.shape)
  7. print("logits:\n", logits)
  8. values, indices = torch.topk(logits, k=1, dim=-1)
  9. print("\nvalues:\n", values)
  10. print("indices:\n", indices)
  11. top_k = 5
  12. print("\nlogits.size(-1):", logits.size(-1))
  13. values, indices = torch.topk(logits, min(top_k, logits.size(-1)))
  14. print("values:\n", values)
  15. print("indices:\n", indices)
复制代码
  1. /home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py
  2. logits.shape: torch.Size([10])
  3. logits:
  4. tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
  5. logits.shape: torch.Size([1, 10])
  6. logits:
  7. tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])
  8. values:
  9. tensor([[10.]])
  10. indices:
  11. tensor([[9]])
  12. logits.size(-1): 10
  13. values:
  14. tensor([[10.,  9.,  8.,  7.,  6.]])
  15. indices:
  16. tensor([[9, 8, 7, 6, 5]])
  17. Process finished with exit code 0
复制代码
References

[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

用多少眼泪才能让你相信

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表