完整地实现了推荐系统的构建、实验和评估过程,为不同推荐算法在同一数据集 ...

打印 上一主题 下一主题

主题 1044|帖子 1044|积分 3132

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. {
  2.    
  3. "cells": [
  4.   {
  5.    
  6.    "cell_type": "markdown",
  7.    "metadata": {
  8.    },
  9.    "source": [
  10.     "# 基于用户的协同过滤算法"
  11.    ]
  12.   },
  13.   {
  14.    
  15.    "cell_type": "code",
  16.    "execution_count": 1,
  17.    "metadata": {
  18.    },
  19.    "outputs": [],
  20.    "source": [
  21.     "# 导入包\n",
  22.     "import random\n",
  23.     "import math\n",
  24.     "import time\n",
  25.     "from tqdm import tqdm"
  26.    ]
  27.   },
  28.   {
  29.    
  30.    "cell_type": "markdown",
  31.    "metadata": {
  32.    },
  33.    "source": [
  34.     "## 一. 通用函数定义"
  35.    ]
  36.   },
  37.   {
  38.    
  39.    "cell_type": "code",
  40.    "execution_count": 2,
  41.    "metadata": {
  42.    },
  43.    "outputs": [],
  44.    "source": [
  45.     "# 定义装饰器,监控运行时间\n",
  46.     "def timmer(func):\n",
  47.     "    def wrapper(*args, **kwargs):\n",
  48.     "        start_time = time.time()\n",
  49.     "        res = func(*args, **kwargs)\n",
  50.     "        stop_time = time.time()\n",
  51.     "        print('Func %s, run time: %s' % (func.__name__, stop_time - start_time))\n",
  52.     "        return res\n",
  53.     "    return wrapper"
  54.    ]
  55.   },
  56.   {
  57.    
  58.    "cell_type": "markdown",
  59.    "metadata": {
  60.    },
  61.    "source": [
  62.     "### 1. 数据处理相关\n",
  63.     "1. load data\n",
  64.     "2. split data"
  65.    ]
  66.   },
  67.   {
  68.    
  69.    "cell_type": "code",
  70.    "execution_count": 3,
  71.    "metadata": {
  72.    },
  73.    "outputs": [],
  74.    "source": [
  75.     "class Dataset():\n",
  76.     "    \n",
  77.     "    def __init__(self, fp):\n",
  78.     "        # fp: data file path\n",
  79.     "        self.data = self.loadData(fp)\n",
  80.     "    \n",
  81.     "    @timmer\n",
  82.     "    def loadData(self, fp):\n",
  83.     "        data = []\n",
  84.     "        for l in open(fp):\n",
  85.     "            data.append(tuple(map(int, l.strip().split('::')[:2])))\n",
  86.     "        return data\n",
  87.     "    \n",
  88.     "    @timmer\n",
  89.     "    def splitData(self, M, k, seed=1):\n",
  90.     "        '''\n",
  91.     "        :params: data, 加载的所有(user, item)数据条目\n",
  92.     "        :params: M, 划分的数目,最后需要取M折的平均\n",
  93.     "        :params: k, 本次是第几次划分,k~[0, M)\n",
  94.     "        :params: seed, random的种子数,对于不同的k应设置成一样的\n",
  95.     "        :return: train, test\n",
  96.     "        '''\n",
  97.     "        train, test = [], []\n",
  98.     "        random.seed(seed)\n",
  99.     "        for user, item in self.data:\n",
  100.     "            # 这里与书中的不一致,本人认为取M-1较为合理,因randint是左右都覆盖的\n",
  101.     "            if random.randint(0, M-1) == k:  \n",
  102.     "                test.append((user, item))\n",
  103.     "            else:\n",
  104.     "                train.append((user, item))\n",
  105.     "\n",
  106.     "        # 处理成字典的形式,user->set(items)\n",
  107.     "        def convert_dict(data):\n",
  108.     "            data_dict = {}\n",
  109.     "            for user, item in data:\n",
  110.     "                if user not in data_dict:\n",
  111.     "                    data_dict[user] = set()\n",
  112.     "                data_dict[user].add(item)\n",
  113.     "            data_dict = {k: list(data_dict[k]) for k in data_dict}\n",
  114.     "            return data_dict\n",
  115.     "\n",
  116.     "        return convert_dict(train), convert_dict(test)"
  117.    ]
  118.   },
  119.   {
  120.    
  121.    "cell_type": "markdown",
  122.    "metadata": {
  123.    },
  124.    "source": [
  125.     "### 2. 评价指标\n",
  126.     "1. Precision\n",
  127.     "2. Recall\n",
  128.     "3. Coverage\n",
  129.     "4. Popularity(Novelty)"
  130.    ]
  131.   },
  132.   {
  133.    
  134.    "cell_type": "code",
  135.    "execution_count": 4,
  136.    "metadata": {
  137.    },
  138.    "outputs": [],
  139.    "source": [
  140.     "class Metric():\n",
  141.     "    \n",
  142.     "    def __init__(self, train, test, GetRecommendation):\n",
  143.     "        '''\n",
  144.     "        :params: train, 训练数据\n",
  145.     "        :params: test, 测试数据\n",
  146.     "        :params: GetRecommendation, 为某个用户获取推荐物品的接口函数\n",
  147.     "        '''\n",
  148.     "        self.train = train\n",
  149.     "        self.test = test\n",
  150.     "        self.GetRecommendation = GetRecommendation\n",
  151.     "        self.recs = self.getRec()\n",
  152.     "        \n",
  153.     "    # 为test中的每个用户进行推荐\n",
  154.     "    def getRec(self):\n",
  155.     "        recs = {}\n",
  156.     "        for user in self.test:\n",
  157.     "            rank = self.GetRecommendation(user)\n",
  158.     "            recs[user] = rank\n",
  159.     "        return recs\n",
  160.     "        \n",
  161.     "    # 定义精确率指标计算方式\n",
  162.     "    def precision(self):\n",
  163.     "        all, hit = 0, 0\n",
  164.     "        for user in self.test:\n",
  165.     "            test_items = set(self.test[user])\n",
  166.     "            rank = self.recs[user]\n",
  167.     "            for item, score in rank:\n",
  168.     "                if item in test_items:\n",
  169.     "                    hit += 1\n",
  170.     "            all += len(rank)\n",
  171.     "        return round(hit / all * 100, 2)\n",
  172.     "    \n",
  173.     "    # 定义召回率指标计算方式\n",
  174.     "    def recall(self):\n",
  175.     "        all, hit = 0, 0\n",
  176.     "        for user in self.test:\n",
  177.     "            test_items = set(self.test[user])\n",
  178.     "            rank = self.recs[user]\n",
  179.     "            for item, score in rank:\n",
  180.     "                if item in test_items:\n",
  181.     "                    hit += 1\n",
  182.     "            all += len(test_items)\n",
  183.     "        return round(hit / all * 100, 2)\n",
  184.     "    \n",
  185.     "    # 定义覆盖率指标计算方式\n",
  186.     "    def coverage(self):\n",
  187.     "        all_item, recom_item = set(), set()\n",
  188.     "        for user in self.test:\n",
  189.     "            for item in self.train[user]:\n",
  190.     "                all_item.add(item)\n",
  191.     "            rank = self.recs[user]\n",
  192.     "            for item, score in rank:\n",
  193.     "                recom_item.add(item)\n",
  194.     "        return round(len(recom_item) / len(all_item) * 100, 2)\n",
  195.     "    \n",
  196.     "    # 定义新颖度指标计算方式\n",
  197.     "    def popularity(self):\n",
  198.     "        # 计算物品的流行度\n",
  199.     "        item_pop = {}\n",
  200.     "        for user in self.train:\n",
  201.     "            for item in self.train[user]:\n",
  202.     "                if item not in item_pop:\n",
  203.     "                    item_pop[item] = 0\n",
  204.     "                item_pop[item] += 1\n",
  205.     "\n",
  206.     "        num, pop = 0, 0\n",
  207.     "        for user in self.test:\n",
  208.     "            rank = self.recs[user]\n",
  209.     "            for item, score in rank:\n",
  210.     "                # 取对数,防止因长尾问题带来的被流行物品所主导\n",
  211.     "                pop += math.log(1 + item_pop[item])\n",
  212.     "                num += 1\n",
  213.     "        return round(pop / num, 6)\n",
  214.     "    \n",
  215.     "    def eval(self):\n",
  216.     "        metric = {'Precision': self.precision(),\n",
  217.     "                  'Recall': self.recall(),\n",
  218.     "                  'Coverage': self.coverage(),\n",
  219.     "                  'Popularity': self.popularity()}\n",
  220.     "        print('Metric:', metric)\n",
  221.     "        return metric"
  222.    ]
  223.   },
  224.   {
  225.    
  226.    "cell_type": "markdown",
  227.    "metadata": {
  228.    },
  229.    "source": [
  230.     "## 二. 算法实现\n",
  231.     "1. Random\n",
  232.     "2. MostPopular\n",
  233.     "3. UserCF\n",
  234.     "4. UserIIF"
  235.    ]
  236.   },
  237.   {
  238.    
  239.    "cell_type": "code",
  240.    "execution_count": 5,
  241.    "metadata": {
  242.    },
  243.    "outputs": [],
  244.    "source": [
  245.     "# 1. 随机推荐\n",
  246.     "def Random(train, K, N):\n",
  247.     "    '''\n",
  248.     "    :params: train, 训练数据集\n",
  249.     "    :params: K, 可忽略\n",
  250.     "    :params: N, 超参数,设置取TopN推荐物品数目\n",
  251.     "    :return: GetRecommendation,推荐接口函数\n",
  252.     "    '''\n",
  253.     "    items = {}\n",
  254.     "    for user in train:\n",
  255.     "        for item in train[user]:\n",
  256.     "            items[item] = 1\n",
  257.     "    \n",
  258.     "    def GetRecommendation(user):\n",
  259.     "        # 随机推荐N个未见过的\n",
  260.     "        user_items = set(train[user])\n",
  261.     "        rec_items = {k: items[k] for k in items if k not in user_items}\n",
  262.     "        rec_items = list(rec_items.items())\n",
  263.     "        random.shuffle(rec_items)\n",
  264.     "        return rec_items[:N]\n",
  265.     "    \n",
  266.     "    return GetRecommendation"
  267.    ]
  268.   },
  269.   {
  270.    
  271.    "cell_type": "code",
  272.    "execution_count": 6,
  273.    "metadata": {
  274.    },
  275.    "outputs": [],
  276.    "source": [
  277.     "# 2. 热门推荐\n",
  278.     "def MostPopular(train, K, N):\n",
  279.     "    '''\n",
  280.     "    :params: train, 训练数据集\n",
  281.     "    :params: K, 可忽略\n",
  282.     "    :params: N, 超参数,设置取TopN推荐物品数目\n",
  283.     "    :return: GetRecommendation, 推荐接口函数\n",
  284.     "    '''\n",
  285.     "    items = {}\n",
  286.     "    for user in train:\n",
  287.     "        for item in train[user]:\n",
  288.     "            if item not in items:\n",
  289.     "                items[item] = 0\n",
  290.     "            items[item] += 1\n",
  291.     "        \n",
  292.     "    def GetRecommendation(user):\n",
  293.     "        # 随机推荐N个没见过的最热门的\n",
  294.     "        user_items = set(train[user])\n",
  295.     "        rec_items = {k: items[k] for k in items if k not in user_items}\n",
  296.     "        rec_items = list(sorted(rec_items.items(), key=lambda x: x[1], reverse=True))\n",
  297.     "        return rec_items[:N]\n",
  298.     "    \n",
  299.     "    return GetRecommendation"
  300.    ]
  301.   },
  302.   {
  303.    
  304.    "cell_type": "code",
  305.    "execution_count": 7,
  306.    "metadata": {
  307.    },
  308.    "outputs": [],
  309.    "source": [
  310.     "# 3. 基于用户余弦相似度的推荐\n",
  311.     "def UserCF(train, K, N):\n",
  312.     "    '''\n",
  313.     "    :params: train, 训练数据集\n",
  314.     "    :params: K, 超参数,设置取TopK相似用户数目\n",
  315.     "    :params: N, 超参数,设置取TopN推荐物品数目\n",
  316.     "    :return: GetRecommendation, 推荐接口函数\n",
  317.     "    '''\n",
  318.     "    # 计算item->user的倒排索引\n",
  319.     "    item_users = {}\n",
  320.     "    for user in train:\n",
  321.     "        for item in train[user]:\n",
  322.     "            if item not in item_users:\n",
  323.     "                item_users[item] = []\n",
  324.     "            item_users[item].append(user)\n",
  325.     "    \n",
  326.     "    # 计算用户相似度矩阵\n",
  327.     "    sim = {}\n",
  328.     "    num = {}\n",
  329.     "    for item in item_users:\n",
  330.     "        users = item_users[item]\n",
  331.     "        for i in range(len(users)):\n",
  332.     "            u = users[i]\n",
  333.     "            if u not in num:\n",
  334.     "                num[u] = 0\n",
  335.     "            num[u] += 1\n",
  336.     "            if u not in sim:\n",
  337.     "                sim[u] = {}\n",
  338.     "            for j in range(len(users)):\n",
  339.     "                if j == i: continue\n",
  340.     "                v = users[j]\n",
  341.     "                if v not in sim[u]:\n",
  342.     "                    sim[u][v] = 0\n",
  343.     "                sim[u][v] += 1\n",
  344.     "    for u in sim:\n",
  345.     "        for v in sim[u]:\n",
  346.     "            sim[u][v] /= math.sqrt(num[u] * num[v])\n",
  347.     "    \n",
  348.     "    # 按照相似度排序\n",
  349.     "    sorted_user_sim = {k: list(sorted(v.items(), \\\n",
  350.     "                               key=lambda x: x[1], reverse=True)) \\\n",
  351.     "                       for k, v in sim.items()}\n",
  352.     "    \n",
  353.     "    # 获取接口函数\n",
  354.     "    def GetRecommendation(user):\n",
  355.     "        items = {}\n",
  356.     "        seen_items = set(train[user])\n",
  357.     "        for u, _ in sorted_user_sim[user][:K]:\n",
  358.     "            for item in train[u]:\n",
  359.     "                # 要去掉用户见过的\n",
  360.     "                if item not in seen_items:\n",
  361.     "                    if item not in items:\n",
  362.     "                        items[item] = 0\n",
  363.     "                    items[item] += sim[user][u]\n",
  364.     "        recs = list(sorted(items.items(), key=lambda x: x[1], reverse=True))[:N]\n",
  365.     "        return recs\n",
  366.     "    \n",
  367.     "    return GetRecommendation"
  368.    ]
  369.   },
  370.   {
  371.    
  372.    "cell_type": "code",
  373.    "execution_count": 8,
  374.    "metadata": {
  375.    },
  376.    "outputs": [],
  377.    "source": [
  378.     "# 4. 基于改进的用户余弦相似度的推荐\n",
  379.     "def UserIIF(train, K, N):\n",
  380.     "    '''\n",
  381.     "    :params: train, 训练数据集\n",
  382.     "    :params: K, 超参数,设置取TopK相似用户数目\n",
  383.     "    :params: N, 超参数,设置取TopN推荐物品数目\n",
  384.     "    :return: GetRecommendation, 推荐接口函数\n",
  385.     "    '''\n",
  386.     "    # 计算item->user的倒排索引\n",
  387.     "    item_users = {}\n",
  388.     "    for user in train:\n",
  389.     "        for item in train[user]:\n",
  390.     "            if item not in item_users:\n",
  391.     "                item_users[item] = []\n",
  392.     "            item_users[item].append(user)\n",
  393.     "    \n",
  394.     "    # 计算用户相似度矩阵\n",
  395.     "    sim = {}\n",
  396.     "    num = {}\n",
  397.     "    for item in item_users:\n",
  398.     "        users = item_users[item]\n",
  399.     "        for i in range(len(users)):\n",
  400.     "            u = users[i]\n",
  401.     "            if u not in num:\n",
  402.     "                num[u] = 0\n",
  403.     "            num[u] += 1\n",
  404.     "            if u not in sim:\n",
  405.     "                sim[u] = {}\n",
  406.     "            for j in range(len(users)):\n",
  407.     "                if j == i: continue\n",
  408.     "                v = users[j]\n",
  409.     "                if v not in sim[u]:\n",
  410.     "                    sim[u][v] = 0\n",
  411.     "                # 相比UserCF,主要是改进了这里\n",
  412.     "                sim[u][v] += 1 / math.log(1 + len(users))\n",
  413.     "    for u in sim:\n",
  414.     "        for v in sim[u]:\n",
  415.     "            sim[u][v] /&#
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

汕尾海湾

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表