Python项目-基于Python的网络爬虫与数据可视化系统

打印 上一主题 下一主题

主题 1623|帖子 1623|积分 4869

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
1. 项目简介

在当今数据驱动的期间,网络爬虫和数据可视化已成为获取、分析和展示信息的重要工具。本文将详细介绍如何使用Python构建一个完备的网络爬虫与数据可视化系统,该系统能够自动从互联网收集数据,进行处置处罚分析,并通过直观的图表展示结果。
2. 技术栈



  • Python 3.8+:重要编程语言
  • 网络爬虫:Requests、BeautifulSoup4、Scrapy、Selenium
  • 数据处置处罚:Pandas、NumPy
  • 数据可视化:Matplotlib、Seaborn、Plotly、Dash
  • 数据存储:SQLite、MongoDB
  • 其他工具:Jupyter Notebook、Flask
3. 系统架构

  1. 网络爬虫与数据可视化系统
  2. ├── 爬虫模块
  3. │   ├── 数据采集器
  4. │   ├── 解析器
  5. │   └── 数据清洗器
  6. ├── 数据存储模块
  7. │   ├── 关系型数据库接口
  8. │   └── NoSQL数据库接口
  9. ├── 数据分析模块
  10. │   ├── 统计分析
  11. │   └── 数据挖掘
  12. └── 可视化模块
  13.     ├── 静态图表生成器
  14.     ├── 交互式图表生成器
  15.     └── Web展示界面
复制代码
4. 爬虫模块实现

4.1 根本爬虫实现

首先,我们使用Requests和BeautifulSoup构建一个简单的爬虫:
  1. import requests
  2. from bs4 import BeautifulSoup
  3. import pandas as pd
  4. class BasicScraper:
  5.     """基础网页爬虫类"""
  6.    
  7.     def __init__(self, user_agent=None):
  8.         """初始化爬虫"""
  9.         self.session = requests.Session()
  10.         default_ua = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
  11.         self.headers = {'User-Agent': user_agent if user_agent else default_ua}
  12.    
  13.     def fetch_page(self, url, params=None):
  14.         """获取网页内容"""
  15.         try:
  16.             response = self.session.get(url, headers=self.headers, params=params)
  17.             response.raise_for_status()  # 检查请求是否成功
  18.             return response.text
  19.         except requests.exceptions.RequestException as e:
  20.             print(f"请求错误: {e}")
  21.             return None
  22.    
  23.     def parse_html(self, html, parser='html.parser'):
  24.         """解析HTML内容"""
  25.         if html:
  26.             return BeautifulSoup(html, parser)
  27.         return None
  28.    
  29.     def extract_data(self, soup, selectors):
  30.         """提取数据
  31.         
  32.         参数:
  33.             soup: BeautifulSoup对象
  34.             selectors: 字典,键为数据名称,值为CSS选择器
  35.             
  36.         返回:
  37.             pandas.DataFrame: 提取的数据
  38.         """
  39.         data = {}
  40.         for key, selector in selectors.items():
  41.             elements = soup.select(selector)
  42.             data[key] = [element.text.strip() for element in elements]
  43.         
  44.         # 确保所有列的长度一致
  45.         max_length = max([len(v) for v in data.values()]) if data else 0
  46.         for key in data:
  47.             if len(data[key]) < max_length:
  48.                 data[key].extend([None] * (max_length - len(data[key])))
  49.         
  50.         return pd.DataFrame(data)
  51.    
  52.     def scrape(self, url, selectors, params=None):
  53.         """执行完整的爬取过程"""
  54.         html = self.fetch_page(url, params)
  55.         if not html:
  56.             return pd.DataFrame()
  57.         
  58.         soup = self.parse_html(html)
  59.         if not soup:
  60.             return pd.DataFrame()
  61.         
  62.         return self.extract_data(soup, selectors)
  63. # 使用示例
  64. def scrape_books_example():
  65.     scraper = BasicScraper()
  66.     url = "http://books.toscrape.com/"
  67.     selectors = {
  68.         "title": ".product_pod h3 a",
  69.         "price": ".price_color",
  70.         "rating": ".star-rating",
  71.         "availability": ".availability"
  72.     }
  73.    
  74.     # 爬取数据
  75.     books_data = scraper.scrape(url, selectors)
  76.    
  77.     # 数据清洗
  78.     if not books_data.empty:
  79.         # 处理价格 - 移除货币符号并转换为浮点数
  80.         books_data['price'] = books_data['price'].str.replace('£', '').astype(float)
  81.         
  82.         # 处理评分 - 从类名中提取星级
  83.         books_data['rating'] = books_data['rating'].apply(lambda x: x.split()[1] + ' stars' if x else None)
  84.         
  85.         # 处理库存状态
  86.         books_data['availability'] = books_data['availability'].str.strip()
  87.    
  88.     return books_data
  89. # 执行爬取
  90. if __name__ == "__main__":
  91.     books = scrape_books_example()
  92.     print(f"爬取到 {len(books)} 本书的信息")
  93.     print(books.head())
复制代码
4.2 使用Scrapy框架构建爬虫

对于更复杂的爬虫需求,我们可以使用Scrapy框架:
  1. # 文件结构:
  2. # my_scraper/
  3. # ├── scrapy.cfg
  4. # └── my_scraper/
  5. #     ├── __init__.py
  6. #     ├── items.py
  7. #     ├── middlewares.py
  8. #     ├── pipelines.py
  9. #     ├── settings.py
  10. #     └── spiders/
  11. #         ├── __init__.py
  12. #         └── book_spider.py
  13. # items.py
  14. import scrapy
  15. class BookItem(scrapy.Item):
  16.     """定义爬取的图书项目"""
  17.     title = scrapy.Field()
  18.     price = scrapy.Field()
  19.     rating = scrapy.Field()
  20.     availability = scrapy.Field()
  21.     category = scrapy.Field()
  22.     description = scrapy.Field()
  23.     upc = scrapy.Field()
  24.     image_url = scrapy.Field()
  25.     url = scrapy.Field()
  26. # book_spider.py
  27. import scrapy
  28. from ..items import BookItem
  29. class BookSpider(scrapy.Spider):
  30.     """图书爬虫"""
  31.     name = 'bookspider'
  32.     allowed_domains = ['books.toscrape.com']
  33.     start_urls = ['http://books.toscrape.com/']
  34.    
  35.     def parse(self, response):
  36.         """解析图书列表页面"""
  37.         # 提取当前页面的所有图书
  38.         books = response.css('article.product_pod')
  39.         
  40.         for book in books:
  41.             # 获取图书详情页链接
  42.             book_url = book.css('h3 a::attr(href)').get()
  43.             if book_url:
  44.                 if 'catalogue/' not in book_url:
  45.                     book_url = 'catalogue/' + book_url
  46.                 book_url = response.urljoin(book_url)
  47.                 yield scrapy.Request(book_url, callback=self.parse_book)
  48.         
  49.         # 处理分页
  50.         next_page = response.css('li.next a::attr(href)').get()
  51.         if next_page:
  52.             yield response.follow(next_page, self.parse)
  53.    
  54.     def parse_book(self, response):
  55.         """解析图书详情页面"""
  56.         book = BookItem()
  57.         
  58.         # 提取基本信息
  59.         book['title'] = response.css('div.product_main h1::text').get()
  60.         book['price'] = response.css('p.price_color::text').get()
  61.         book['availability'] = response.css('p.availability::text').extract()[1].strip()
  62.         
  63.         # 提取评分
  64.         rating_class = response.css('p.star-rating::attr(class)').get()
  65.         if rating_class:
  66.             book['rating'] = rating_class.split()[1]
  67.         
  68.         # 提取产品信息表格
  69.         rows = response.css('table.table-striped tr')
  70.         for row in rows:
  71.             header = row.css('th::text').get()
  72.             if header == 'UPC':
  73.                 book['upc'] = row.css('td::text').get()
  74.             elif header == 'Product Type':
  75.                 book['category'] = row.css('td::text').get()
  76.         
  77.         # 提取描述
  78.         book['description'] = response.css('div#product_description + p::text').get()
  79.         
  80.         # 提取图片URL
  81.         image_url = response.css('div.item.active img::attr(src)').get()
  82.         if image_url:
  83.             book['image_url'] = response.urljoin(image_url)
  84.         
  85.         book['url'] = response.url
  86.         
  87.         yield book
  88. # pipelines.py (数据处理管道)
  89. import re
  90. from itemadapter import ItemAdapter
  91. class BookPipeline:
  92.     """图书数据处理管道"""
  93.    
  94.     def process_item(self, item, spider):
  95.         adapter = ItemAdapter(item)
  96.         
  97.         # 清洗价格字段
  98.         if adapter.get('price'):
  99.             price_str = adapter['price']
  100.             # 提取数字并转换为浮点数
  101.             price_match = re.search(r'(\d+\.\d+)', price_str)
  102.             if price_match:
  103.                 adapter['price'] = float(price_match.group(1))
  104.         
  105.         # 标准化评分
  106.         rating_map = {
  107.             'One': 1,
  108.             'Two': 2,
  109.             'Three': 3,
  110.             'Four': 4,
  111.             'Five': 5
  112.         }
  113.         if adapter.get('rating'):
  114.             adapter['rating'] = rating_map.get(adapter['rating'], 0)
  115.         
  116.         # 处理库存信息
  117.         if adapter.get('availability'):
  118.             if 'In stock' in adapter['availability']:
  119.                 # 提取库存数量
  120.                 stock_match = re.search(r'(\d+)', adapter['availability'])
  121.                 if stock_match:
  122.                     adapter['availability'] = int(stock_match.group(1))
  123.                 else:
  124.                     adapter['availability'] = 'In stock'
  125.             else:
  126.                 adapter['availability'] = 'Out of stock'
  127.         
  128.         return item
  129. # 运行爬虫的脚本 (run_spider.py)
  130. from scrapy.crawler import CrawlerProcess
  131. from scrapy.utils.project import get_project_settings
  132. def run_spider():
  133.     """运行Scrapy爬虫"""
  134.     process = CrawlerProcess(get_project_settings())
  135.     process.crawl('bookspider')
  136.     process.start()
  137. if __name__ == '__main__':
  138.     run_spider()
复制代码
4.3 处置处罚动态网页的爬虫

对于JavaScript渲染的网页,我们必要使用Selenium:
  1. from selenium import webdriver
  2. from selenium.webdriver.chrome.options import Options
  3. from selenium.webdriver.chrome.service import Service
  4. from selenium.webdriver.common.by import By
  5. from selenium.webdriver.support.ui import WebDriverWait
  6. from selenium.webdriver.support import expected_conditions as EC
  7. from webdriver_manager.chrome import ChromeDriverManager
  8. import pandas as pd
  9. import time
  10. import logging
  11. class DynamicScraper:
  12.     """动态网页爬虫类"""
  13.    
  14.     def __init__(self, headless=True, wait_time=10):
  15.         """初始化爬虫
  16.         
  17.         参数:
  18.             headless: 是否使用无头模式
  19.             wait_time: 等待元素出现的最大时间(秒)
  20.         """
  21.         self.wait_time = wait_time
  22.         self.logger = self._setup_logger()
  23.         self.driver = self._setup_driver(headless)
  24.    
  25.     def _setup_logger(self):
  26.         """设置日志记录器"""
  27.         logger = logging.getLogger('DynamicScraper')
  28.         logger.setLevel(logging.INFO)
  29.         
  30.         if not logger.handlers:
  31.             handler = logging.StreamHandler()
  32.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  33.             handler.setFormatter(formatter)
  34.             logger.addHandler(handler)
  35.         
  36.         return logger
  37.    
  38.     def _setup_driver(self, headless):
  39.         """设置WebDriver"""
  40.         try:
  41.             chrome_options = Options()
  42.             if headless:
  43.                 chrome_options.add_argument("--headless")
  44.             
  45.             # 添加其他有用的选项
  46.             chrome_options.add_argument("--disable-gpu")
  47.             chrome_options.add_argument("--no-sandbox")
  48.             chrome_options.add_argument("--disable-dev-shm-usage")
  49.             chrome_options.add_argument("--window-size=1920,1080")
  50.             
  51.             # 使用webdriver_manager自动管理ChromeDriver
  52.             service = Service(ChromeDriverManager().install())
  53.             driver = webdriver.Chrome(service=service, options=chrome_options)
  54.             
  55.             return driver
  56.         except Exception as e:
  57.             self.logger.error(f"设置WebDriver时出错: {e}")
  58.             raise
  59.    
  60.     def navigate_to(self, url):
  61.         """导航到指定URL"""
  62.         try:
  63.             self.logger.info(f"正在导航到: {url}")
  64.             self.driver.get(url)
  65.             return True
  66.         except Exception as e:
  67.             self.logger.error(f"导航到 {url} 时出错: {e}")
  68.             return False
  69.    
  70.     def wait_for_element(self, by, value):
  71.         """等待元素出现
  72.         
  73.         参数:
  74.             by: 定位方式 (By.ID, By.CSS_SELECTOR 等)
  75.             value: 定位值
  76.             
  77.         返回:
  78.             找到的元素或None
  79.         """
  80.         try:
  81.             element = WebDriverWait(self.driver, self.wait_time).until(
  82.                 EC.presence_of_element_located((by, value))
  83.             )
  84.             return element
  85.         except Exception as e:
  86.             self.logger.warning(f"等待元素 {value} 超时: {e}")
  87.             return None
  88.    
  89.     def wait_for_elements(self, by, value):
  90.         """等待多个元素出现"""
  91.         try:
  92.             elements = WebDriverWait(self.driver, self.wait_time).until(
  93.                 EC.presence_of_all_elements_located((by, value))
  94.             )
  95.             return elements
  96.         except Exception as e:
  97.             self.logger.warning(f"等待元素 {value} 超时: {e}")
  98.             return []
  99.    
  100.     def extract_data(self, selectors):
  101.         """从当前页面提取数据
  102.         
  103.         参数:
  104.             selectors: 字典,键为数据名称,值为(定位方式, 定位值)元组
  105.             
  106.         返回:
  107.             pandas.DataFrame: 提取的数据
  108.         """
  109.         data = {}
  110.         
  111.         for key, (by, value) in selectors.items():
  112.             try:
  113.                 elements = self.driver.find_elements(by, value)
  114.                 data[key] = [element.text for element in elements]
  115.                 self.logger.info(f"提取了 {len(elements)} 个 '{key}' 元素")
  116.             except Exception as e:
  117.                 self.logger.error(f"提取 '{key}' 数据时出错: {e}")
  118.                 data[key] = []
  119.         
  120.         # 确保所有列的长度一致
  121.         max_length = max([len(v) for v in data.values()]) if data else 0
  122.         for key in data:
  123.             if len(data[key]) < max_length:
  124.                 data[key].extend([None] * (max_length - len(data[key])))
  125.         
  126.         return pd.DataFrame(data)
  127.    
  128.     def scroll_to_bottom(self, scroll_pause_time=1.0):
  129.         """滚动到页面底部以加载更多内容"""
  130.         self.logger.info("开始滚动页面以加载更多内容")
  131.         
  132.         # 获取初始页面高度
  133.         last_height = self.driver.execute_script("return document.body.scrollHeight")
  134.         
  135.         while True:
  136.             # 滚动到底部
  137.             self.driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
  138.             
  139.             # 等待页面加载
  140.             time.sleep(scroll_pause_time)
  141.             
  142.             # 计算新的页面高度并与上一个高度比较
  143.             new_height = self.driver.execute_script("return document.body.scrollHeight")
  144.             if new_height == last_height:
  145.                 # 如果高度没有变化,说明已经到底了
  146.                 break
  147.             last_height = new_height
  148.         
  149.         self.logger.info("页面滚动完成")
  150.    
  151.     def click_element(self, by, value):
  152.         """点击元素"""
  153.         try:
  154.             element = self.wait_for_element(by, value)
  155.             if element:
  156.                 element.click()
  157.                 return True
  158.             return False
  159.         except Exception as e:
  160.             self.logger.error(f"点击元素 {value} 时出错: {e}")
  161.             return False
  162.    
  163.     def close(self):
  164.         """关闭浏览器"""
  165.         if self.driver:
  166.             self.driver.quit()
  167.             self.logger.info("浏览器已关闭")
  168. # 使用示例
  169. def scrape_dynamic_website_example():
  170.     """爬取动态网站示例"""
  171.     # 创建爬虫实例
  172.     scraper = DynamicScraper(headless=True)
  173.    
  174.     try:
  175.         # 导航到目标网站 (以SPA电商网站为例)
  176.         url = "https://www.example-dynamic-site.com/products"
  177.         if not scraper.navigate_to(url):
  178.             return pd.DataFrame()
  179.         
  180.         # 等待页面加载完成
  181.         scraper.wait_for_element(By.CSS_SELECTOR, ".product-grid")
  182.         
  183.         # 滚动页面以加载更多产品
  184.         scraper.scroll_to_bottom(scroll_pause_time=2.0)
  185.         
  186.         # 定义要提取的数据选择器
  187.         selectors = {
  188.             "product_name": (By.CSS_SELECTOR, ".product-item .product-name"),
  189.             "price": (By.CSS_SELECTOR, ".product-item .product-price"),
  190.             "rating": (By.CSS_SELECTOR, ".product-item .product-rating"),
  191.             "reviews_count": (By.CSS_SELECTOR, ".product-item .reviews-count")
  192.         }
  193.         
  194.         # 提取数据
  195.         products_data = scraper.extract_data(selectors)
  196.         
  197.         # 数据清洗
  198.         if not products_data.empty:
  199.             # 处理价格 - 移除货币符号并转换为浮点数
  200.             products_data['price'] = products_data['price'].str.replace('$', '').str.replace(',', '').astype(float)
  201.             
  202.             # 处理评分 - 提取数值
  203.             products_data['rating'] = products_data['rating'].str.extract(r'(\d\.\d)').astype(float)
  204.             
  205.             # 处理评论数 - 提取数值
  206.             products_data['reviews_count'] = products_data['reviews_count'].str.extract(r'(\d+)').astype(int)
  207.         
  208.         return products_data
  209.    
  210.     finally:
  211.         # 确保浏览器关闭
  212.         scraper.close()
  213. # 执行爬取
  214. if __name__ == "__main__":
  215.     products = scrape_dynamic_website_example()
  216.     print(f"爬取到 {len(products)} 个产品的信息")
  217.     print(products.head())
复制代码
4.4 爬虫管理器

创建一个爬虫管理器来统一调用不同类型的爬虫:
  1. class ScraperManager:
  2.     """爬虫管理器,用于管理不同类型的爬虫"""
  3.    
  4.     def __init__(self):
  5.         self.scrapers = {}
  6.    
  7.     def register_scraper(self, name, scraper_class, **kwargs):
  8.         """注册爬虫
  9.         
  10.         参数:
  11.             name: 爬虫名称
  12.             scraper_class: 爬虫类
  13.             kwargs: 传递给爬虫构造函数的参数
  14.         """
  15.         self.scrapers[name] = (scraper_class, kwargs)
  16.         print(f"已注册爬虫: {name}")
  17.    
  18.     def get_scraper(self, name):
  19.         """获取爬虫实例"""
  20.         if name not in self.scrapers:
  21.             raise ValueError(f"未找到名为 '{name}' 的爬虫")
  22.         
  23.         scraper_class, kwargs = self.scrapers[name]
  24.         return scraper_class(**kwargs)
  25.    
  26.     def run_scraper(self, name, *args, **kwargs):
  27.         """运行指定的爬虫
  28.         
  29.         参数:
  30.             name: 爬虫名称
  31.             args, kwargs: 传递给爬虫方法的参数
  32.             
  33.         返回:
  34.             爬虫返回的数据
  35.         """
  36.         scraper = self.get_scraper(name)
  37.         
  38.         if hasattr(scraper, 'scrape'):
  39.             return scraper.scrape(*args, **kwargs)
  40.         elif hasattr(scraper, 'run'):
  41.             return scraper.run(*args, **kwargs)
  42.         else:
  43.             raise AttributeError(f"爬虫 '{name}' 没有 'scrape' 或 'run' 方法")
  44. # 使用示例
  45. def scraper_manager_example():
  46.     # 创建爬虫管理器
  47.     manager = ScraperManager()
  48.    
  49.     # 注册基础爬虫
  50.     manager.register_scraper('basic', BasicScraper)
  51.    
  52.     # 注册动态爬虫
  53.     manager.register_scraper('dynamic', DynamicScraper, headless=True, wait_time=15)
  54.    
  55.     # 使用基础爬虫爬取数据
  56.     url = "http://books.toscrape.com/"
  57.     selectors = {
  58.         "title": ".product_pod h3 a",
  59.         "price": ".price_color",
  60.         "rating": ".star-rating"
  61.     }
  62.    
  63.     books_data = manager.run_scraper('basic', url, selectors)
  64.    
  65.     print(f"使用基础爬虫爬取到 {len(books_data)} 本书的信息")
  66.    
  67.     return books_data
  68. # 执行示例
  69. if __name__ == "__main__":
  70.     data = scraper_manager_example()
  71.     print(data.head())
复制代码
4.5 署理IP和请求头轮换

为了制止被目的网站封锁,我们可以实当署理IP和请求头轮换功能:
  1. import random
  2. import time
  3. from fake_useragent import UserAgent
  4. class ProxyRotator:
  5.     """代理IP轮换器"""
  6.    
  7.     def __init__(self, proxies=None):
  8.         """初始化代理轮换器
  9.         
  10.         参数:
  11.             proxies: 代理列表,格式为 [{'http': 'http://ip:port', 'https': 'https://ip:port'}, ...]
  12.         """
  13.         self.proxies = proxies or []
  14.         self.current_index = 0
  15.    
  16.     def add_proxy(self, proxy):
  17.         """添加代理"""
  18.         self.proxies.append(proxy)
  19.    
  20.     def get_proxy(self):
  21.         """获取下一个代理"""
  22.         if not self.proxies:
  23.             return None
  24.         
  25.         proxy = self.proxies[self.current_index]
  26.         self.current_index = (self.current_index + 1) % len(self.proxies)
  27.         return proxy
  28.    
  29.     def remove_proxy(self, proxy):
  30.         """移除失效的代理"""
  31.         if proxy in self.proxies:
  32.             self.proxies.remove(proxy)
  33.             self.current_index = self.current_index % max(1, len(self.proxies))
  34. class UserAgentRotator:
  35.     """User-Agent轮换器"""
  36.    
  37.     def __init__(self, use_fake_ua=True):
  38.         """初始化User-Agent轮换器"""
  39.         self.use_fake_ua = use_fake_ua
  40.         self.ua = UserAgent() if use_fake_ua else None
  41.         
  42.         # 预定义的User-Agent列表(备用)
  43.         self.user_agents = [
  44.             'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
  45.             'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15',
  46.             'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
  47.             'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36',
  48.             'Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1'
  49.         ]
  50.    
  51.     def get_random_ua(self):
  52.         """获取随机User-Agent"""
  53.         if self.use_fake_ua and self.ua:
  54.             try:
  55.                 return self.ua.random
  56.             except:
  57.                 pass
  58.         
  59.         return random.choice(self.user_agents)
  60. class EnhancedScraper(BasicScraper):
  61.     """增强型爬虫,支持代理和请求头轮换"""
  62.    
  63.     def __init__(self, proxy_rotator=None, ua_rotator=None, retry_times=3, retry_delay=2):
  64.         """初始化增强型爬虫
  65.         
  66.         参数:
  67.             proxy_rotator: 代理轮换器
  68.             ua_rotator: User-Agent轮换器
  69.             retry_times: 请求失败重试次数
  70.             retry_delay: 重试延迟时间(秒)
  71.         """
  72.         super().__init__()
  73.         self.proxy_rotator = proxy_rotator or ProxyRotator()
  74.         self.ua_rotator = ua_rotator or UserAgentRotator()
  75.         self.retry_times = retry_times
  76.         self.retry_delay = retry_delay
  77.    
  78.     def fetch_page(self, url, params=None):
  79.         """获取网页内容,支持代理和重试"""
  80.         for attempt in range(self.retry_times):
  81.             try:
  82.                 # 获取代理和User-Agent
  83.                 proxy = self.proxy_rotator.get_proxy()
  84.                 user_agent = self.ua_rotator.get_random_ua()
  85.                
  86.                 # 更新请求头
  87.                 self.headers['User-Agent'] = user_agent
  88.                
  89.                 # 发送请求
  90.                 response = self.session.get(
  91.                     url,
  92.                     headers=self.headers,
  93.                     params=params,
  94.                     proxies=proxy,
  95.                     timeout=10
  96.                 )
  97.                
  98.                 # 检查请求是否成功
  99.                 response.raise_for_status()
  100.                 return response.text
  101.             
  102.             except requests.exceptions.RequestException as e:
  103.                 print(f"请求错误 (尝试 {attempt+1}/{self.retry_times}): {e}")
  104.                
  105.                 # 如果是代理问题,移除当前代理
  106.                 if proxy and (isinstance(e, requests.exceptions.ProxyError) or
  107.                              isinstance(e, requests.exceptions.ConnectTimeout)):
  108.                     self.proxy_rotator.remove_proxy(proxy)
  109.                
  110.                 # 最后一次尝试失败
  111.                 if attempt == self.retry_times - 1:
  112.                     return None
  113.                
  114.                 # 延迟后重试
  115.                 time.sleep(self.retry_delay)
  116.         
  117.         return None
  118. # 使用示例
  119. def enhanced_scraper_example():
  120.     # 创建代理轮换器
  121.     proxy_rotator = ProxyRotator([
  122.         {'http': 'http://proxy1.example.com:8080', 'https': 'https://proxy1.example.com:8080'},
  123.         {'http': 'http://proxy2.example.com:8080', 'https': 'https://proxy2.example.com:8080'}
  124.     ])
  125.    
  126.     # 创建User-Agent轮换器
  127.     ua_rotator = UserAgentRotator()
  128.    
  129.     # 创建增强型爬虫
  130.     scraper = EnhancedScraper(proxy_rotator, ua_rotator, retry_times=3)
  131.    
  132.     # 爬取数据
  133.     url = "http://books.toscrape.com/"
  134.     selectors = {
  135.         "title": ".product_pod h3 a",
  136.         "price": ".price_color",
  137.         "rating": ".star-rating"
  138.     }
  139.    
  140.     books_data = scraper.scrape(url, selectors)
  141.     return books_data
  142. # 执行示例
  143. if __name__ == "__main__":
  144.     data = enhanced_scraper_example()
  145.     print(f"爬取到 {len(data)} 本书的信息")
  146.     print(data.head())
复制代码
5. 数据存储模块

数据存储模块负责将爬取的数据生存到不同类型的存储系统中,包罗关系型数据库、NoSQL数据库和文件系统。
5.1 SQLite数据库存储

SQLite是一种轻量级的关系型数据库,恰当单机应用和原型开发
  1. import sqlite3
  2. import pandas as pd
  3. import os
  4. import logging
  5. import csv
  6. from datetime import datetime
  7. class SQLiteStorage:
  8.     """SQLite数据存储类"""
  9.    
  10.     def __init__(self, db_path):
  11.         """初始化SQLite数据库连接
  12.         
  13.         参数:
  14.             db_path: 数据库文件路径
  15.         """
  16.         self.db_path = db_path
  17.         self.logger = self._setup_logger()
  18.         
  19.         # 确保数据库目录存在
  20.         os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True)
  21.         
  22.         try:
  23.             self.conn = sqlite3.connect(db_path)
  24.             self.cursor = self.conn.cursor()
  25.             self.logger.info(f"成功连接到SQLite数据库: {db_path}")
  26.         except sqlite3.Error as e:
  27.             self.logger.error(f"连接SQLite数据库时出错: {e}")
  28.             raise
  29.    
  30.     def _setup_logger(self):
  31.         """设置日志记录器"""
  32.         logger = logging.getLogger('SQLiteStorage')
  33.         logger.setLevel(logging.INFO)
  34.         
  35.         if not logger.handlers:
  36.             handler = logging.StreamHandler()
  37.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  38.             handler.setFormatter(formatter)
  39.             logger.addHandler(handler)
  40.         
  41.         return logger
  42.    
  43.     def create_table(self, table_name, columns):
  44.         """创建数据表
  45.         
  46.         参数:
  47.             table_name: 表名
  48.             columns: 列定义字典,键为列名,值为数据类型
  49.         """
  50.         try:
  51.             # 构建CREATE TABLE语句
  52.             columns_str = ', '.join([f"{col} {dtype}" for col, dtype in columns.items()])
  53.             query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_str})"
  54.             
  55.             # 执行SQL
  56.             self.cursor.execute(query)
  57.             self.conn.commit()
  58.             self.logger.info(f"成功创建表: {table_name}")
  59.             return True
  60.         except sqlite3.Error as e:
  61.             self.logger.error(f"创建表 {table_name} 时出错: {e}")
  62.             self.conn.rollback()
  63.             return False
  64.    
  65.     def insert_data(self, table_name, data):
  66.         """插入数据
  67.         
  68.         参数:
  69.             table_name: 表名
  70.             data: 要插入的数据,可以是DataFrame或列表
  71.         """
  72.         try:
  73.             if isinstance(data, pd.DataFrame):
  74.                 # 使用pandas的to_sql方法插入DataFrame
  75.                 data.to_sql(table_name, self.conn, if_exists='append', index=False)
  76.                 self.logger.info(f"成功插入 {len(data)} 行数据到表 {table_name}")
  77.             elif isinstance(data, list) and len(data) > 0:
  78.                 # 处理列表数据
  79.                 if isinstance(data[0], dict):
  80.                     # 字典列表
  81.                     if not data:
  82.                         return True
  83.                     
  84.                     # 获取所有键
  85.                     columns = list(data[0].keys())
  86.                     
  87.                     # 准备INSERT语句
  88.                     placeholders = ', '.join(['?'] * len(columns))
  89.                     columns_str = ', '.join(columns)
  90.                     query = f"INSERT INTO {table_name} ({columns_str}) VALUES ({placeholders})"
  91.                     
  92.                     # 准备数据
  93.                     values = [[row.get(col) for col in columns] for row in data]
  94.                     
  95.                     # 执行插入
  96.                     self.cursor.executemany(query, values)
  97.                 else:
  98.                     # 值列表
  99.                     placeholders = ', '.join(['?'] * len(data[0]))
  100.                     query = f"INSERT INTO {table_name} VALUES ({placeholders})"
  101.                     self.cursor.executemany(query, data)
  102.                
  103.                 self.conn.commit()
  104.                 self.logger.info(f"成功插入 {len(data)} 行数据到表 {table_name}")
  105.             else:
  106.                 self.logger.warning(f"没有数据可插入到表 {table_name}")
  107.             
  108.             return True
  109.         except Exception as e:
  110.             self.logger.error(f"插入数据到表 {table_name} 时出错: {e}")
  111.             self.conn.rollback()
  112.             return False
  113.    
  114.     def query_data(self, query, params=None):
  115.         """执行查询
  116.         
  117.         参数:
  118.             query: SQL查询语句
  119.             params: 查询参数(可选)
  120.             
  121.         返回:
  122.             pandas.DataFrame: 查询结果
  123.         """
  124.         try:
  125.             if params:
  126.                 return pd.read_sql_query(query, self.conn, params=params)
  127.             else:
  128.                 return pd.read_sql_query(query, self.conn)
  129.         except Exception as e:
  130.             self.logger.error(f"执行查询时出错: {e}")
  131.             return pd.DataFrame()
  132.    
  133.     def execute_query(self, query, params=None):
  134.         """执行任意SQL查询
  135.         
  136.         参数:
  137.             query: SQL查询语句
  138.             params: 查询参数(可选)
  139.             
  140.         返回:
  141.             bool: 是否成功
  142.         """
  143.         try:
  144.             if params:
  145.                 self.cursor.execute(query, params)
  146.             else:
  147.                 self.cursor.execute(query)
  148.             
  149.             self.conn.commit()
  150.             return True
  151.         except Exception as e:
  152.             self.logger.error(f"执行查询时出错: {e}")
  153.             self.conn.rollback()
  154.             return False
  155.    
  156.     def table_exists(self, table_name):
  157.         """检查表是否存在
  158.         
  159.         参数:
  160.             table_name: 表名
  161.             
  162.         返回:
  163.             bool: 表是否存在
  164.         """
  165.         query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
  166.         self.cursor.execute(query, (table_name,))
  167.         return self.cursor.fetchone() is not None
  168.    
  169.     def get_table_info(self, table_name):
  170.         """获取表信息
  171.         
  172.         参数:
  173.             table_name: 表名
  174.             
  175.         返回:
  176.             list: 表的列信息
  177.         """
  178.         if not self.table_exists(table_name):
  179.             return []
  180.         
  181.         query = f"PRAGMA table_info({table_name})"
  182.         return self.cursor.execute(query).fetchall()
  183.    
  184.     def close(self):
  185.         """关闭数据库连接"""
  186.         if hasattr(self, 'conn') and self.conn:
  187.             self.conn.close()
  188.             self.logger.info("数据库连接已关闭")
  189.    
  190.     def __enter__(self):
  191.         """上下文管理器入口"""
  192.         return self
  193.    
  194.     def __exit__(self, exc_type, exc_val, exc_tb):
  195.         """上下文管理器退出"""
  196.         self.close()
  197. # 使用示例
  198. def sqlite_example():
  199.     # 创建SQLite存储实例
  200.     db = SQLiteStorage('data/books.db')
  201.    
  202.     try:
  203.         # 创建表
  204.         db.create_table('books', {
  205.             'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
  206.             'title': 'TEXT NOT NULL',
  207.             'price': 'REAL',
  208.             'rating': 'INTEGER',
  209.             'category': 'TEXT',
  210.             'description': 'TEXT',
  211.             'created_at': 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP'
  212.         })
  213.         
  214.         # 准备示例数据
  215.         books_data = pd.DataFrame({
  216.             'title': ['Python编程', '数据科学入门', '机器学习实战'],
  217.             'price': [59.9, 69.9, 79.9],
  218.             'rating': [5, 4, 5],
  219.             'category': ['编程', '数据科学', '机器学习'],
  220.             'description': ['Python基础教程', '数据分析入门', '机器学习算法详解']
  221.         })
  222.         
  223.         # 插入数据
  224.         db.insert_data('books', books_data)
  225.         
  226.         # 查询数据
  227.         results = db.query_data("SELECT * FROM books WHERE rating >= ?", (4,))
  228.         print(f"查询结果: {len(results)} 行")
  229.         print(results)
  230.         
  231.         return results
  232.    
  233.     finally:
  234.         # 确保关闭连接
  235.         db.close()
  236. if __name__ == "__main__":
  237.     sqlite_example()
复制代码
5.2 MongoDB数据库存储

MongoDB是一种流行的NoSQL数据库,恰当存储非布局化或半布局化数据:
  1. import pymongo
  2. import pandas as pd
  3. import json
  4. import logging
  5. from bson import ObjectId
  6. from datetime import datetime
  7. class MongoDBStorage:
  8.     """MongoDB数据存储类"""
  9.    
  10.     def __init__(self, connection_string, database_name):
  11.         """初始化MongoDB连接
  12.         
  13.         参数:
  14.             connection_string: MongoDB连接字符串
  15.             database_name: 数据库名称
  16.         """
  17.         self.connection_string = connection_string
  18.         self.database_name = database_name
  19.         self.logger = self._setup_logger()
  20.         
  21.         try:
  22.             # 连接到MongoDB
  23.             self.client = pymongo.MongoClient(connection_string)
  24.             self.db = self.client[database_name]
  25.             
  26.             # 测试连接
  27.             self.client.server_info()
  28.             self.logger.info(f"成功连接到MongoDB数据库: {database_name}")
  29.         except Exception as e:
  30.             self.logger.error(f"连接MongoDB数据库时出错: {e}")
  31.             raise
  32.    
  33.     def _setup_logger(self):
  34.         """设置日志记录器"""
  35.         logger = logging.getLogger('MongoDBStorage')
  36.         logger.setLevel(logging.INFO)
  37.         
  38.         if not logger.handlers:
  39.             handler = logging.StreamHandler()
  40.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  41.             handler.setFormatter(formatter)
  42.             logger.addHandler(handler)
  43.         
  44.         return logger
  45.    
  46.     def _convert_to_json_serializable(self, data):
  47.         """转换数据为JSON可序列化格式
  48.         
  49.         参数:
  50.             data: 要转换的数据
  51.             
  52.         返回:
  53.             转换后的数据
  54.         """
  55.         if isinstance(data, dict):
  56.             return {k: self._convert_to_json_serializable(v) for k, v in data.items()}
  57.         elif isinstance(data, list):
  58.             return [self._convert_to_json_serializable(item) for item in data]
  59.         elif isinstance(data, (ObjectId, datetime)):
  60.             return str(data)
  61.         else:
  62.             return data
  63.    
  64.     def insert_document(self, collection_name, document):
  65.         """插入单个文档
  66.         
  67.         参数:
  68.             collection_name: 集合名称
  69.             document: 要插入的文档(字典)
  70.             
  71.         返回:
  72.             插入的文档ID
  73.         """
  74.         try:
  75.             collection = self.db[collection_name]
  76.             result = collection.insert_one(document)
  77.             self.logger.info(f"成功插入文档到集合 {collection_name}, ID: {result.inserted_id}")
  78.             return result.inserted_id
  79.         except Exception as e:
  80.             self.logger.error(f"插入文档到集合 {collection_name} 时出错: {e}")
  81.             return None
  82.    
  83.     def insert_many(self, collection_name, documents):
  84.         """插入多个文档
  85.         
  86.         参数:
  87.             collection_name: 集合名称
  88.             documents: 要插入的文档列表
  89.             
  90.         返回:
  91.             插入的文档ID列表
  92.         """
  93.         try:
  94.             collection = self.db[collection_name]
  95.             result = collection.insert_many(documents)
  96.             self.logger.info(f"成功插入 {len(result.inserted_ids)} 个文档到集合 {collection_name}")
  97.             return result.inserted_ids
  98.         except Exception as e:
  99.             self.logger.error(f"插入多个文档到集合 {collection_name} 时出错: {e}")
  100.             return []
  101.    
  102.     def insert_dataframe(self, collection_name, df):
  103.         """插入DataFrame数据
  104.         
  105.         参数:
  106.             collection_name: 集合名称
  107.             df: pandas DataFrame
  108.             
  109.         返回:
  110.             bool: 是否成功
  111.         """
  112.         try:
  113.             if df.empty:
  114.                 self.logger.warning(f"DataFrame为空,未插入数据到集合 {collection_name}")
  115.                 return True
  116.             
  117.             # 将DataFrame转换为字典列表
  118.             records = df.to_dict('records')
  119.             
  120.             # 插入数据
  121.             collection = self.db[collection_name]
  122.             result = collection.insert_many(records)
  123.             
  124.             self.logger.info(f"成功插入 {len(result.inserted_ids)} 行数据到集合 {collection_name}")
  125.             return True
  126.         except Exception as e:
  127.             self.logger.error(f"插入DataFrame到集合 {collection_name} 时出错: {e}")
  128.             return False
  129.    
  130.     def find_documents(self, collection_name, query=None, projection=None, limit=0):
  131.         """查询文档
  132.         
  133.         参数:
  134.             collection_name: 集合名称
  135.             query: 查询条件(可选)
  136.             projection: 投影字段(可选)
  137.             limit: 结果限制数量(可选)
  138.             
  139.         返回:
  140.             pandas.DataFrame: 查询结果
  141.         """
  142.         try:
  143.             collection = self.db[collection_name]
  144.             
  145.             # 执行查询
  146.             if query is None:
  147.                 query = {}
  148.             
  149.             cursor = collection.find(query, projection)
  150.             
  151.             if limit > 0:
  152.                 cursor = cursor.limit(limit)
  153.             
  154.             # 将结果转换为列表
  155.             results = list(cursor)
  156.             
  157.             # 将ObjectId转换为字符串
  158.             for doc in results:
  159.                 if '_id' in doc:
  160.                     doc['_id'] = str(doc['_id'])
  161.             
  162.             # 转换为DataFrame
  163.             if results:
  164.                 return pd.DataFrame(results)
  165.             else:
  166.                 return pd.DataFrame()
  167.         except Exception as e:
  168.             self.logger.error(f"查询集合 {collection_name} 时出错: {e}")
  169.             return pd.DataFrame()
  170.    
  171.     def update_document(self, collection_name, query, update_data, upsert=False):
  172.         """更新文档
  173.         
  174.         参数:
  175.             collection_name: 集合名称
  176.             query: 查询条件
  177.             update_data: 更新数据
  178.             upsert: 如果不存在是否插入
  179.             
  180.         返回:
  181.             int: 更新的文档数量
  182.         """
  183.         try:
  184.             collection = self.db[collection_name]
  185.             
  186.             # 确保update_data使用$set操作符
  187.             if not any(k.startswith('$') for k in update_data.keys()):
  188.                 update_data = {'$set': update_data}
  189.             
  190.             result = collection.update_one(query, update_data, upsert=upsert)
  191.             
  192.             self.logger.info(f"更新集合 {collection_name} 中的文档: 匹配 {result.matched_count}, 修改 {result.modified_count}")
  193.             return result.modified_count
  194.         except Exception as e:
  195.             self.logger.error(f"更新集合 {collection_name} 中的文档时出错: {e}")
  196.             return 0
  197.    
  198.     def update_many(self, collection_name, query, update_data):
  199.         """更新多个文档
  200.         
  201.         参数:
  202.             collection_name: 集合名称
  203.             query: 查询条件
  204.             update_data: 更新数据
  205.             
  206.         返回:
  207.             int: 更新的文档数量
  208.         """
  209.         try:
  210.             collection = self.db[collection_name]
  211.             
  212.             # 确保update_data使用$set操作符
  213.             if not any(k.startswith('$') for k in update_data.keys()):
  214.                 update_data = {'$set': update_data}
  215.             
  216.             result = collection.update_many(query, update_data)
  217.             
  218.             self.logger.info(f"更新集合 {collection_name} 中的多个文档: 匹配 {result.matched_count}, 修改 {result.modified_count}")
  219.             return result.modified_count
  220.         except Exception as e:
  221.             self.logger.error(f"更新集合 {collection_name} 中的多个文档时出错: {e}")
  222.             return 0
  223.    
  224.     def delete_document(self, collection_name, query):
  225.         """删除文档
  226.         
  227.         参数:
  228.             collection_name: 集合名称
  229.             query: 查询条件
  230.             
  231.         返回:
  232.             int: 删除的文档数量
  233.         """
  234.         try:
  235.             collection = self.db[collection_name]
  236.             result = collection.delete_one(query)
  237.             
  238.             self.logger.info(f"从集合 {collection_name} 中删除了 {result.deleted_count} 个文档")
  239.             return result.deleted_count
  240.         except Exception as e:
  241.             self.logger.error(f"从集合 {collection_name} 中删除文档时出错: {e}")
  242.             return 0
  243.    
  244.     def delete_many(self, collection_name, query):
  245.         """删除多个文档
  246.         
  247.         参数:
  248.             collection_name: 集合名称
  249.             query: 查询条件
  250.             
  251.         返回:
  252.             int: 删除的文档数量
  253.         """
  254.         try:
  255.             collection = self.db[collection_name]
  256.             result = collection.delete_many(query)
  257.             
  258.             self.logger.info(f"从集合 {collection_name} 中删除了 {result.deleted_count} 个文档")
  259.             return result.deleted_count
  260.         except Exception as e:
  261.             self.logger.error(f"从集合 {collection_name} 中删除多个文档时出错: {e}")
  262.             return 0
  263.    
  264.     def create_index(self, collection_name, keys, **kwargs):
  265.         """创建索引
  266.         
  267.         参数:
  268.             collection_name: 集合名称
  269.             keys: 索引键
  270.             **kwargs: 其他索引选项
  271.             
  272.         返回:
  273.             str: 创建的索引名称
  274.         """
  275.         try:
  276.             collection = self.db[collection_name]
  277.             index_name = collection.create_index(keys, **kwargs)
  278.             
  279.             self.logger.info(f"在集合 {collection_name} 上创建索引: {index_name}")
  280.             return index_name
  281.         except Exception as e:
  282.             self.logger.error(f"在集合 {collection_name} 上创建索引时出错: {e}")
  283.             return None
  284.    
  285.     def drop_collection(self, collection_name):
  286.         """删除集合
  287.         
  288.         参数:
  289.             collection_name: 集合名称
  290.             
  291.         返回:
  292.             bool: 是否成功
  293.         """
  294.         try:
  295.             self.db.drop_collection(collection_name)
  296.             self.logger.info(f"成功删除集合: {collection_name}")
  297.             return True
  298.         except Exception as e:
  299.             self.logger.error(f"删除集合 {collection_name} 时出错: {e}")
  300.             return False
  301.    
  302.     def close(self):
  303.         """关闭数据库连接"""
  304.         if hasattr(self, 'client') and self.client:
  305.             self.client.close()
  306.             self.logger.info("MongoDB连接已关闭")
  307.    
  308.     def __enter__(self):
  309.         """上下文管理器入口"""
  310.         return self
  311.    
  312.     def __exit__(self, exc_type, exc_val, exc_tb):
  313.         """上下文管理器退出"""
  314.         self.close()
  315. # 使用示例
  316. def mongodb_example():
  317.     # 创建MongoDB存储实例
  318.     mongo = MongoDBStorage('mongodb://localhost:27017', 'web_scraping_db')
  319.    
  320.     try:
  321.         # 准备示例数据
  322.         products_data = pd.DataFrame({
  323.             'name': ['智能手机', '笔记本电脑', '平板电脑'],
  324.             'price': [2999, 4999, 3999],
  325.             'brand': ['品牌A', '品牌B', '品牌A'],
  326.             'features': [
  327.                 ['5G', '高清摄像头', '快速充电'],
  328.                 ['高性能CPU', '大内存', 'SSD'],
  329.                 ['触控屏', '长续航', '轻薄']
  330.             ],
  331.             'in_stock': [True, False, True],
  332.             'last_updated': [datetime.now() for _ in range(3)]
  333.         })
  334.         
  335.         # 插入DataFrame数据
  336.         mongo.insert_dataframe('products', products_data)
  337.         
  338.         # 插入单个文档
  339.         review = {
  340.             'product_id': '123456',
  341.             'user': '用户A',
  342.             'rating': 5,
  343.             'comment': '非常好用的产品',
  344.             'date': datetime.now()
  345.         }
  346.         review_id = mongo.insert_document('reviews', review)
  347.         
  348.         # 查询数据
  349.         results = mongo.find_documents('products', {'brand': '品牌A'})
  350.         print(f"查询结果: {len(results)} 行")
  351.         print(results)
  352.         
  353.         # 更新数据
  354.         mongo.update_document('products', {'name': '智能手机'}, {'$set': {'price': 2899}})
  355.         
  356.         # 创建索引
  357.         mongo.create_index('products', [('name', pymongo.ASCENDING)], unique=True)
  358.         
  359.         return results
  360.    
  361.     finally:
  362.         # 确保关闭连接
  363.         mongo.close()
  364. if __name__ == "__main__":
  365.     mongodb_example()
复制代码
5.3 CSV文件存储

CSV是一种常用的数据交换格式,恰当存储表格数据:
  1. import pandas as pd
  2. import os
  3. import logging
  4. import csv
  5. from datetime import datetime
  6. class CSVStorage:
  7.     """CSV文件存储类"""
  8.    
  9.     def __init__(self, base_dir='data/csv'):
  10.         """初始化CSV存储
  11.         
  12.         参数:
  13.             base_dir: CSV文件存储的基础目录
  14.         """
  15.         self.base_dir = base_dir
  16.         self.logger = self._setup_logger()
  17.         
  18.         # 确保目录存在
  19.         os.makedirs(base_dir, exist_ok=True)
  20.         self.logger.info(f"CSV存储目录: {base_dir}")
  21.    
  22.     def _setup_logger(self):
  23.         """设置日志记录器"""
  24.         logger = logging.getLogger('CSVStorage')
  25.         logger.setLevel(logging.INFO)
  26.         
  27.         if not logger.handlers:
  28.             handler = logging.StreamHandler()
  29.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  30.             handler.setFormatter(formatter)
  31.             logger.addHandler(handler)
  32.         
  33.         return logger
  34.    
  35.     def _get_file_path(self, file_name):
  36.         """获取文件完整路径
  37.         
  38.         参数:
  39.             file_name: 文件名
  40.             
  41.         返回:
  42.             str: 文件完整路径
  43.         """
  44.         # 确保文件名有.csv后缀
  45.         if not file_name.endswith('.csv'):
  46.             file_name += '.csv'
  47.         
  48.         return os.path.join(self.base_dir, file_name)
  49.    
  50.     def save_dataframe(self, df, file_name, index=False):
  51.         """保存DataFrame到CSV文件
  52.         
  53.         参数:
  54.             df: 要保存的DataFrame
  55.             file_name: 文件名
  56.             index: 是否保存索引
  57.             
  58.         返回:
  59.             bool: 是否成功
  60.         """
  61.         try:
  62.             file_path = self._get_file_path(file_name)
  63.             df.to_csv(file_path, index=index, encoding='utf-8')
  64.             self.logger.info(f"成功保存 {len(df)} 行数据到文件: {file_path}")
  65.             return True
  66.         except Exception as e:
  67.             self.logger.error(f"保存数据到文件 {file_name} 时出错: {e}")
  68.             return False
  69.    
  70.     def append_dataframe(self, df, file_name, index=False):
  71.         """追加DataFrame到CSV文件
  72.         
  73.         参数:
  74.             df: 要追加的DataFrame
  75.             file_name: 文件名
  76.             index: 是否保存索引
  77.             
  78.         返回:
  79.             bool: 是否成功
  80.         """
  81.         try:
  82.             file_path = self._get_file_path(file_name)
  83.             
  84.             # 检查文件是否存在
  85.             file_exists = os.path.isfile(file_path)
  86.             
  87.             # 如果文件存在,追加数据;否则创建新文件
  88.             df.to_csv(file_path, mode='a', header=not file_exists, index=index, encoding='utf-8')
  89.             
  90.             self.logger.info(f"成功追加 {len(df)} 行数据到文件: {file_path}")
  91.             return True
  92.         except Exception as e:
  93.             self.logger.error(f"追加数据到文件 {file_name} 时出错: {e}")
  94.             return False
  95.    
  96.     def load_csv(self, file_name, **kwargs):
  97.         """加载CSV文件到DataFrame
  98.         
  99.         参数:
  100.             file_name: 文件名
  101.             **kwargs: 传递给pd.read_csv的参数
  102.             
  103.         返回:
  104.             pandas.DataFrame: 加载的数据
  105.         """
  106.         try:
  107.             file_path = self._get_file_path(file_name)
  108.             
  109.             if not os.path.isfile(file_path):
  110.                 self.logger.warning(f"文件不存在: {file_path}")
  111.                 return pd.DataFrame()
  112.             
  113.             df = pd.read_csv(file_path, **kwargs)
  114.             self.logger.info(f"成功从文件 {file_path} 加载 {len(df)} 行数据")
  115.             return df
  116.         except Exception as e:
  117.             self.logger.error(f"从文件 {file_name} 加载数据时出错: {e}")
  118.             return pd.DataFrame()
  119.    
  120.     def save_records(self, records, file_name, fieldnames=None):
  121.         """保存记录列表到CSV文件
  122.         
  123.         参数:
  124.             records: 字典列表
  125.             file_name: 文件名
  126.             fieldnames: 字段名列表(可选)
  127.             
  128.         返回:
  129.             bool: 是否成功
  130.         """
  131.         try:
  132.             file_path = self._get_file_path(file_name)
  133.             
  134.             if not records:
  135.                 self.logger.warning(f"没有记录可保存到文件: {file_path}")
  136.                 return True
  137.             
  138.             # 如果未提供字段名,使用第一条记录的键
  139.             if fieldnames is None:
  140.                 fieldnames = list(records[0].keys())
  141.             
  142.             with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
  143.                 writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
  144.                 writer.writeheader()
  145.                 writer.writerows(records)
  146.             
  147.             self.logger.info(f"成功保存 {len(records)} 条记录到文件: {file_path}")
  148.             return True
  149.         except Exception as e:
  150.             self.logger.error(f"保存记录到文件 {file_name} 时出错: {e}")
  151.             return False
  152.    
  153.     def append_records(self, records, file_name, fieldnames=None):
  154.         """追加记录列表到CSV文件
  155.         
  156.         参数:
  157.             records: 字典列表
  158.             file_name: 文件名
  159.             fieldnames: 字段名列表(可选)
  160.             
  161.         返回:
  162.             bool: 是否成功
  163.         """
  164.         try:
  165.             file_path = self._get_file_path(file_name)
  166.             
  167.             if not records:
  168.                 self.logger.warning(f"没有记录可追加到文件: {file_path}")
  169.                 return True
  170.             
  171.             # 检查文件是否存在
  172.             file_exists = os.path.isfile(file_path)
  173.             
  174.             # 如果未提供字段名,使用第一条记录的键
  175.             if fieldnames is None:
  176.                 fieldnames = list(records[0].keys())
  177.             
  178.             with open(file_path, 'a', newline='', encoding='utf-8') as csvfile:
  179.                 writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
  180.                
  181.                 # 如果文件不存在,写入表头
  182.                 if not file_exists:
  183.                     writer.writeheader()
  184.                
  185.                 writer.writerows(records)
  186.             
  187.             self.logger.info(f"成功追加 {len(records)} 条记录到文件: {file_path}")
  188.             return True
  189.         except Exception as e:
  190.             self.logger.error(f"追加记录到文件 {file_name} 时出错: {e}")
  191.             return False
  192.    
  193.     def file_exists(self, file_name):
  194.         """检查文件是否存在
  195.         
  196.         参数:
  197.             file_name: 文件名
  198.             
  199.         返回:
  200.             bool: 文件是否存在
  201.         """
  202.         file_path = self._get_file_path(file_name)
  203.         return os.path.isfile(file_path)
  204.    
  205.     def list_files(self):
  206.         """列出所有CSV文件
  207.         
  208.         返回:
  209.             list: CSV文件列表
  210.         """
  211.         try:
  212.             files = [f for f in os.listdir(self.base_dir) if f.endswith('.csv')]
  213.             self.logger.info(f"找到 {len(files)} 个CSV文件")
  214.             return files
  215.         except Exception as e:
  216.             self.logger.error(f"列出CSV文件时出错: {e}")
  217.             return []
  218.    
  219.     def delete_file(self, file_name):
  220.         """删除CSV文件
  221.         
  222.         参数:
  223.             file_name: 文件名
  224.             
  225.         返回:
  226.             bool: 是否成功
  227.         """
  228.         try:
  229.             file_path = self._get_file_path(file_name)
  230.             
  231.             if not os.path.isfile(file_path):
  232.                 self.logger.warning(f"文件不存在,无法删除: {file_path}")
  233.                 return False
  234.             
  235.             os.remove(file_path)
  236.             self.logger.info(f"成功删除文件: {file_path}")
  237.             return True
  238.         except Exception as e:
  239.             self.logger.error(f"删除文件 {file_name} 时出错: {e}")
  240.             return False
  241. # 使用示例
  242. def csv_example():
  243.     # 创建CSV存储实例
  244.     csv_storage = CSVStorage('data/csv')
  245.    
  246.     # 准备示例数据
  247.     data = pd.DataFrame({
  248.         'date': [datetime.now().strftime('%Y-%m-%d %H:%M:%S') for _ in range(3)],
  249.         'category': ['电子产品', '家居', '食品'],
  250.         'item_count': [120, 85, 200],
  251.         'average_price': [1500.75, 350.25, 45.50]
  252.     })
  253.    
  254.     # 保存数据
  255.     csv_storage.save_dataframe(data, 'inventory')
  256.    
  257.     # 加载数据
  258.     loaded_data = csv_storage.load_csv('inventory')
  259.     print(f"加载的数据: {len(loaded_data)} 行")
  260.     print(loaded_data)
  261.    
  262.     # 追加数据
  263.     new_data = pd.DataFrame({
  264.         'date': [datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
  265.         'category': ['服装'],
  266.         'item_count': [150],
  267.         'average_price': [250.00]
  268.     })
  269.     csv_storage.append_dataframe(new_data, 'inventory')
  270.    
  271.     return loaded_data
  272. if __name__ == "__main__":
  273.     csv_example()
复制代码
5.4 存储工厂

创建一个存储工厂类,用于统一管理不同类型的存储:
  1. class StorageFactory:
  2.     """存储工厂类,用于创建和管理不同类型的存储"""
  3.    
  4.     def __init__(self):
  5.         self.storage_classes = {}
  6.         self.storage_instances = {}
  7.    
  8.     def register_storage(self, storage_type, storage_class):
  9.         """注册存储类
  10.         
  11.         参数:
  12.             storage_type: 存储类型名称
  13.             storage_class: 存储类
  14.         """
  15.         self.storage_classes[storage_type] = storage_class
  16.         print(f"已注册存储类型: {storage_type}")
  17.    
  18.     def get_storage(self, storage_type, **kwargs):
  19.         """获取存储实例
  20.         
  21.         参数:
  22.             storage_type: 存储类型名称
  23.             **kwargs: 传递给存储类构造函数的参数
  24.             
  25.         返回:
  26.             存储实例
  27.         """
  28.         # 检查存储类型是否已注册
  29.         if storage_type not in self.storage_classes:
  30.             raise ValueError(f"未注册的存储类型: {storage_type}")
  31.         
  32.         # 创建存储实例的键
  33.         instance_key = f"{storage_type}_{hash(frozenset(kwargs.items()))}"
  34.         
  35.         # 如果实例不存在,创建新实例
  36.         if instance_key not in self.storage_instances:
  37.             storage_class = self.storage_classes[storage_type]
  38.             self.storage_instances[instance_key] = storage_class(**kwargs)
  39.         
  40.         return self.storage_instances[instance_key]
  41.    
  42.     def close_all(self):
  43.         """关闭所有存储连接"""
  44.         for instance_key, storage in self.storage_instances.items():
  45.             if hasattr(storage, 'close'):
  46.                 storage.close()
  47.         
  48.         self.storage_instances.clear()
  49.         print("已关闭所有存储连接")
  50. # 使用示例
  51. def storage_factory_example():
  52.     # 创建存储工厂
  53.     factory = StorageFactory()
  54.    
  55.     # 注册存储类
  56.     factory.register_storage('sqlite', SQLiteStorage)
  57.     factory.register_storage('mongodb', MongoDBStorage)
  58.     factory.register_storage('csv', CSVStorage)
  59.    
  60.     # 获取SQLite存储实例
  61.     sqlite_storage = factory.get_storage('sqlite', db_path='data/example.db')
  62.    
  63.     # 获取MongoDB存储实例
  64.     mongo_storage = factory.get_storage('mongodb',
  65.                                        connection_string='mongodb://localhost:27017',
  66.                                        database_name='example_db')
  67.    
  68.     # 获取CSV存储实例
  69.     csv_storage = factory.get_storage('csv', base_dir='data/csv_files')
  70.    
  71.     # 使用存储实例...
  72.    
  73.     # 关闭所有连接
  74.     factory.close_all()
  75.    
  76.     return "存储工厂示例完成"
  77. if __name__ == "__main__":
  78.     storage_factory_example()
复制代码
6. 数据分析模块

数据分析模块负责对爬取的数据进行清洗、转换、分析和挖掘,从而提取有代价的信息和洞察。
6.1 数据清洗与预处置处罚

数据清洗是数据分析的第一步,用于处置处罚缺失值、非常值和格式不一致的数据:
  1. import pandas as pd
  2. import numpy as np
  3. import re
  4. from datetime import datetime
  5. import logging
  6. class DataCleaner:
  7.     """数据清洗类"""
  8.    
  9.     def __init__(self):
  10.         """初始化数据清洗器"""
  11.         self.logger = self._setup_logger()
  12.    
  13.     def _setup_logger(self):
  14.         """设置日志记录器"""
  15.         logger = logging.getLogger('DataCleaner')
  16.         logger.setLevel(logging.INFO)
  17.         
  18.         if not logger.handlers:
  19.             handler = logging.StreamHandler()
  20.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  21.             handler.setFormatter(formatter)
  22.             logger.addHandler(handler)
  23.         
  24.         return logger
  25.    
  26.     def handle_missing_values(self, df, strategy='drop', fill_value=None):
  27.         """处理缺失值
  28.         
  29.         参数:
  30.             df: 输入DataFrame
  31.             strategy: 处理策略,可选'drop'(删除)、'fill'(填充)
  32.             fill_value: 填充值,当strategy为'fill'时使用
  33.             
  34.         返回:
  35.             处理后的DataFrame
  36.         """
  37.         if df.empty:
  38.             self.logger.warning("输入DataFrame为空")
  39.             return df
  40.         
  41.         missing_count = df.isnull().sum().sum()
  42.         self.logger.info(f"检测到 {missing_count} 个缺失值")
  43.         
  44.         if missing_count == 0:
  45.             return df
  46.         
  47.         if strategy == 'drop':
  48.             # 删除包含缺失值的行
  49.             result = df.dropna()
  50.             self.logger.info(f"删除了 {len(df) - len(result)} 行含有缺失值的数据")
  51.             return result
  52.         
  53.         elif strategy == 'fill':
  54.             # 填充缺失值
  55.             if isinstance(fill_value, dict):
  56.                 # 对不同列使用不同的填充值
  57.                 result = df.fillna(fill_value)
  58.                 self.logger.info(f"使用指定值填充了缺失值: {fill_value}")
  59.             else:
  60.                 # 使用相同的值填充所有缺失值
  61.                 result = df.fillna(fill_value)
  62.                 self.logger.info(f"使用 {fill_value} 填充了所有缺失值")
  63.             return result
  64.         
  65.         else:
  66.             self.logger.error(f"未知的缺失值处理策略: {strategy}")
  67.             return df
  68.    
  69.     def remove_duplicates(self, df, subset=None):
  70.         """删除重复行
  71.         
  72.         参数:
  73.             df: 输入DataFrame
  74.             subset: 用于判断重复的列,默认使用所有列
  75.             
  76.         返回:
  77.             处理后的DataFrame
  78.         """
  79.         if df.empty:
  80.             return df
  81.         
  82.         # 删除重复行
  83.         result = df.drop_duplicates(subset=subset)
  84.         
  85.         removed_count = len(df) - len(result)
  86.         self.logger.info(f"删除了 {removed_count} 行重复数据")
  87.         
  88.         return result
  89.    
  90.     def handle_outliers(self, df, columns, method='zscore', threshold=3.0):
  91.         """处理异常值
  92.         
  93.         参数:
  94.             df: 输入DataFrame
  95.             columns: 要处理的列名列表
  96.             method: 异常值检测方法,可选'zscore'、'iqr'
  97.             threshold: 阈值,zscore方法使用
  98.             
  99.         返回:
  100.             处理后的DataFrame
  101.         """
  102.         if df.empty:
  103.             return df
  104.         
  105.         result = df.copy()
  106.         outliers_count = 0
  107.         
  108.         for col in columns:
  109.             if col not in df.columns:
  110.                 self.logger.warning(f"列 {col} 不存在")
  111.                 continue
  112.             
  113.             if not pd.api.types.is_numeric_dtype(df[col]):
  114.                 self.logger.warning(f"列 {col} 不是数值类型,跳过异常值检测")
  115.                 continue
  116.             
  117.             # 获取非缺失值
  118.             values = df[col].dropna()
  119.             
  120.             if method == 'zscore':
  121.                 # 使用Z-score方法检测异常值
  122.                 mean = values.mean()
  123.                 std = values.std()
  124.                 if std == 0:
  125.                     self.logger.warning(f"列 {col} 的标准差为0,跳过异常值检测")
  126.                     continue
  127.                
  128.                 z_scores = np.abs((values - mean) / std)
  129.                 outliers = values[z_scores > threshold].index
  130.                
  131.             elif method == 'iqr':
  132.                 # 使用IQR方法检测异常值
  133.                 q1 = values.quantile(0.25)
  134.                 q3 = values.quantile(0.75)
  135.                 iqr = q3 - q1
  136.                 lower_bound = q1 - 1.5 * iqr
  137.                 upper_bound = q3 + 1.5 * iqr
  138.                 outliers = values[(values < lower_bound) | (values > upper_bound)].index
  139.                
  140.             else:
  141.                 self.logger.error(f"未知的异常值检测方法: {method}")
  142.                 continue
  143.             
  144.             # 将异常值设为NaN
  145.             result.loc[outliers, col] = np.nan
  146.             outliers_count += len(outliers)
  147.         
  148.         self.logger.info(f"检测并处理了 {outliers_count} 个异常值")
  149.         return result
  150.    
  151.     def normalize_text(self, df, text_columns):
  152.         """文本标准化处理
  153.         
  154.         参数:
  155.             df: 输入DataFrame
  156.             text_columns: 要处理的文本列名列表
  157.             
  158.         返回:
  159.             处理后的DataFrame
  160.         """
  161.         if df.empty:
  162.             return df
  163.         
  164.         result = df.copy()
  165.         
  166.         for col in text_columns:
  167.             if col not in df.columns:
  168.                 self.logger.warning(f"列 {col} 不存在")
  169.                 continue
  170.             
  171.             if not pd.api.types.is_string_dtype(df[col]):
  172.                 self.logger.warning(f"列 {col} 不是文本类型")
  173.                 continue
  174.             
  175.             # 文本处理:去除多余空格、转为小写
  176.             result[col] = df[col].str.strip().str.lower()
  177.             
  178.             # 去除特殊字符
  179.             result[col] = result[col].apply(lambda x: re.sub(r'[^\w\s]', '', str(x)) if pd.notna(x) else x)
  180.             
  181.             self.logger.info(f"完成列 {col} 的文本标准化处理")
  182.         
  183.         return result
  184.    
  185.     def convert_data_types(self, df, type_dict):
  186.         """转换数据类型
  187.         
  188.         参数:
  189.             df: 输入DataFrame
  190.             type_dict: 类型转换字典,键为列名,值为目标类型
  191.             
  192.         返回:
  193.             处理后的DataFrame
  194.         """
  195.         if df.empty:
  196.             return df
  197.         
  198.         result = df.copy()
  199.         
  200.         for col, dtype in type_dict.items():
  201.             if col not in df.columns:
  202.                 self.logger.warning(f"列 {col} 不存在")
  203.                 continue
  204.             
  205.             try:
  206.                 result[col] = result[col].astype(dtype)
  207.                 self.logger.info(f"将列 {col} 的类型转换为 {dtype}")
  208.             except Exception as e:
  209.                 self.logger.error(f"转换列 {col} 的类型时出错: {e}")
  210.         
  211.         return result
  212.    
  213.     def parse_dates(self, df, date_columns, date_format=None):
  214.         """解析日期列
  215.         
  216.         参数:
  217.             df: 输入DataFrame
  218.             date_columns: 日期列名列表
  219.             date_format: 日期格式字符串(可选)
  220.             
  221.         返回:
  222.             处理后的DataFrame
  223.         """
  224.         if df.empty:
  225.             return df
  226.         
  227.         result = df.copy()
  228.         
  229.         for col in date_columns:
  230.             if col not in df.columns:
  231.                 self.logger.warning(f"列 {col} 不存在")
  232.                 continue
  233.             
  234.             try:
  235.                 if date_format:
  236.                     result[col] = pd.to_datetime(result[col], format=date_format)
  237.                 else:
  238.                     result[col] = pd.to_datetime(result[col])
  239.                
  240.                 self.logger.info(f"将列 {col} 转换为日期时间类型")
  241.             except Exception as e:
  242.                 self.logger.error(f"转换列 {col} 为日期时间类型时出错: {e}")
  243.         
  244.         return result
  245.    
  246.     def clean_data(self, df, config=None):
  247.         """综合数据清洗
  248.         
  249.         参数:
  250.             df: 输入DataFrame
  251.             config: 清洗配置字典
  252.             
  253.         返回:
  254.             清洗后的DataFrame
  255.         """
  256.         if df.empty:
  257.             return df
  258.         
  259.         if config is None:
  260.             config = {}
  261.         
  262.         result = df.copy()
  263.         
  264.         # 处理缺失值
  265.         if 'missing_values' in config:
  266.             missing_config = config['missing_values']
  267.             result = self.handle_missing_values(
  268.                 result,
  269.                 strategy=missing_config.get('strategy', 'drop'),
  270.                 fill_value=missing_config.get('fill_value')
  271.             )
  272.         
  273.         # 删除重复行
  274.         if config.get('remove_duplicates', True):
  275.             subset = config.get('duplicate_subset')
  276.             result = self.remove_duplicates(result, subset=subset)
  277.         
  278.         # 处理异常值
  279.         if 'outliers' in config:
  280.             outlier_config = config['outliers']
  281.             result = self.handle_outliers(
  282.                 result,
  283.                 columns=outlier_config.get('columns', []),
  284.                 method=outlier_config.get('method', 'zscore'),
  285.                 threshold=outlier_config.get('threshold', 3.0)
  286.             )
  287.         
  288.         # 文本标准化
  289.         if 'text_columns' in config:
  290.             result = self.normalize_text(result, config['text_columns'])
  291.         
  292.         # 转换数据类型
  293.         if 'type_conversions' in config:
  294.             result = self.convert_data_types(result, config['type_conversions'])
  295.         
  296.         # 解析日期
  297.         if 'date_columns' in config:
  298.             date_config = config['date_columns']
  299.             if isinstance(date_config, list):
  300.                 result = self.parse_dates(result, date_config)
  301.             elif isinstance(date_config, dict):
  302.                 for col, format_str in date_config.items():
  303.                     result = self.parse_dates(result, [col], date_format=format_str)
  304.         
  305.         self.logger.info(f"数据清洗完成,从 {len(df)} 行处理为 {len(result)} 行")
  306.         return result
  307. # 使用示例
  308. def data_cleaning_example():
  309.     # 创建示例数据
  310.     data = {
  311.         'product_name': ['iPhone 13  ', 'Samsung Galaxy', 'Xiaomi Mi 11', 'iPhone 13', None],
  312.         'price': [5999, 4999, 3999, 5999, 2999],
  313.         'rating': [4.8, 4.6, 4.5, 4.8, 10.0],  # 包含异常值
  314.         'reviews_count': ['120', '98', '75', '120', '30'],  # 字符串类型
  315.         'release_date': ['2021-09-15', '2021-08-20', '2021-03-10', '2021-09-15', '2022-01-01']
  316.     }
  317.     df = pd.DataFrame(data)
  318.    
  319.     # 创建数据清洗器
  320.     cleaner = DataCleaner()
  321.    
  322.     # 配置清洗参数
  323.     config = {
  324.         'missing_values': {'strategy': 'drop'},
  325.         'remove_duplicates': True,
  326.         'outliers': {
  327.             'columns': ['rating', 'price'],
  328.             'method': 'zscore',
  329.             'threshold': 2.5
  330.         },
  331.         'text_columns': ['product_name'],
  332.         'type_conversions': {'reviews_count': 'int'},
  333.         'date_columns': {'release_date': '%Y-%m-%d'}
  334.     }
  335.    
  336.     # 执行数据清洗
  337.     cleaned_df = cleaner.clean_data(df, config)
  338.    
  339.     print("原始数据:")
  340.     print(df)
  341.     print("\n清洗后的数据:")
  342.     print(cleaned_df)
  343.    
  344.     return cleaned_df
  345. if __name__ == "__main__":
  346.     data_cleaning_example()
复制代码
6.2 统计分析

统计分析用于计算数据的基本统计量和分布特性:
  1. import pandas as pd
  2. import numpy as np
  3. import scipy.stats as stats
  4. import logging
  5. class StatisticalAnalyzer:
  6.     """统计分析类"""
  7.    
  8.     def __init__(self):
  9.         """初始化统计分析器"""
  10.         self.logger = self._setup_logger()
  11.    
  12.     def _setup_logger(self):
  13.         """设置日志记录器"""
  14.         logger = logging.getLogger('StatisticalAnalyzer')
  15.         logger.setLevel(logging.INFO)
  16.         
  17.         if not logger.handlers:
  18.             handler = logging.StreamHandler()
  19.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  20.             handler.setFormatter(formatter)
  21.             logger.addHandler(handler)
  22.         
  23.         return logger
  24.    
  25.     def describe_data(self, df, include=None):
  26.         """生成数据描述性统计
  27.         
  28.         参数:
  29.             df: 输入DataFrame
  30.             include: 包含的数据类型,默认为None(所有数值列)
  31.             
  32.         返回:
  33.             描述性统计结果DataFrame
  34.         """
  35.         if df.empty:
  36.             self.logger.warning("输入DataFrame为空")
  37.             return pd.DataFrame()
  38.         
  39.         try:
  40.             stats_df = df.describe(include=include, percentiles=[.1, .25, .5, .75, .9])
  41.             self.logger.info("生成描述性统计完成")
  42.             return stats_df
  43.         except Exception as e:
  44.             self.logger.error(f"生成描述性统计时出错: {e}")
  45.             return pd.DataFrame()
  46.    
  47.     def correlation_analysis(self, df, method='pearson'):
  48.         """相关性分析
  49.         
  50.         参数:
  51.             df: 输入DataFrame
  52.             method: 相关系数计算方法,可选'pearson'、'spearman'、'kendall'
  53.             
  54.         返回:
  55.             相关系数矩阵DataFrame
  56.         """
  57.         if df.empty:
  58.             self.logger.warning("输入DataFrame为空")
  59.             return pd.DataFrame()
  60.         
  61.         # 筛选数值型列
  62.         numeric_df = df.select_dtypes(include=['number'])
  63.         
  64.         if numeric_df.empty:
  65.             self.logger.warning("没有数值型列可进行相关性分析")
  66.             return pd.DataFrame()
  67.         
  68.         try:
  69.             corr_matrix = numeric_df.corr(method=method)
  70.             self.logger.info(f"使用 {method} 方法完成相关性分析")
  71.             return corr_matrix
  72.         except Exception as e:
  73.             self.logger.error(f"计算相关系数时出错: {e}")
  74.             return pd.DataFrame()
  75.    
  76.     def frequency_analysis(self, df, column, normalize=False, bins=None):
  77.         """频率分析
  78.         
  79.         参数:
  80.             df: 输入DataFrame
  81.             column: 要分析的列名
  82.             normalize: 是否归一化频率
  83.             bins: 数值型数据的分箱数量
  84.             
  85.         返回:
  86.             频率分析结果Series
  87.         """
  88.         if df.empty or column not in df.columns:
  89.             self.logger.warning(f"输入DataFrame为空或不包含列 {column}")
  90.             return pd.Series()
  91.         
  92.         try:
  93.             # 检查列的数据类型
  94.             if pd.api.types.is_numeric_dtype(df[column]) and bins is not None:
  95.                 # 数值型数据,进行分箱
  96.                 freq = pd.cut(df[column], bins=bins).value_counts(normalize=normalize)
  97.                 self.logger.info(f"对数值列 {column} 进行分箱频率分析,分箱数量: {bins}")
  98.             else:
  99.                 # 分类数据,直接计算频率
  100.                 freq = df[column].value_counts(normalize=normalize)
  101.                 self.logger.info(f"对列 {column} 进行频率分析")
  102.             
  103.             return freq
  104.         except Exception as e:
  105.             self.logger.error(f"进行频率分析时出错: {e}")
  106.             return pd.Series()
  107.    
  108.     def group_analysis(self, df, group_by, agg_dict):
  109.         """分组分析
  110.         
  111.         参数:
  112.             df: 输入DataFrame
  113.             group_by: 分组列名或列名列表
  114.             agg_dict: 聚合字典,键为列名,值为聚合函数或函数列表
  115.             
  116.         返回:
  117.             分组分析结果DataFrame
  118.         """
  119.         if df.empty:
  120.             self.logger.warning("输入DataFrame为空")
  121.             return pd.DataFrame()
  122.         
  123.         try:
  124.             result = df.groupby(group_by).agg(agg_dict)
  125.             self.logger.info(f"按 {group_by} 完成分组分析")
  126.             return result
  127.         except Exception as e:
  128.             self.logger.error(f"进行分组分析时出错: {e}")
  129.             return pd.DataFrame()
  130.    
  131.     def time_series_analysis(self, df, date_column, value_column, freq='D'):
  132.         """时间序列分析
  133.         
  134.         参数:
  135.             df: 输入DataFrame
  136.             date_column: 日期列名
  137.             value_column: 值列名
  138.             freq: 重采样频率,如'D'(天)、'W'(周)、'M'(月)
  139.             
  140.         返回:
  141.             重采样后的时间序列DataFrame
  142.         """
  143.         if df.empty or date_column not in df.columns or value_column not in df.columns:
  144.             self.logger.warning(f"输入DataFrame为空或缺少必要的列")
  145.             return pd.DataFrame()
  146.         
  147.         try:
  148.             # 确保日期列是datetime类型
  149.             if not pd.api.types.is_datetime64_dtype(df[date_column]):
  150.                 df = df.copy()
  151.                 df[date_column] = pd.to_datetime(df[date_column])
  152.             
  153.             # 设置日期索引
  154.             ts_df = df.set_index(date_column)
  155.             
  156.             # 按指定频率重采样并计算均值
  157.             resampled = ts_df[value_column].resample(freq).mean()
  158.             
  159.             self.logger.info(f"完成时间序列分析,重采样频率: {freq}")
  160.             return resampled.reset_index()
  161.         except Exception as e:
  162.             self.logger.error(f"进行时间序列分析时出错: {e}")
  163.             return pd.DataFrame()
  164.    
  165.     def hypothesis_testing(self, df, column1, column2=None, test_type='ttest'):
  166.         """假设检验
  167.         
  168.         参数:
  169.             df: 输入DataFrame
  170.             column1: 第一个数据列名
  171.             column2: 第二个数据列名(对于双样本检验)
  172.             test_type: 检验类型,可选'ttest'、'anova'、'chi2'等
  173.             
  174.         返回:
  175.             检验结果字典
  176.         """
  177.         if df.empty or column1 not in df.columns:
  178.             self.logger.warning(f"输入DataFrame为空或不包含列 {column1}")
  179.             return {}
  180.         
  181.         try:
  182.             result = {}
  183.             
  184.             if test_type == 'ttest':
  185.                 # t检验
  186.                 if column2 and column2 in df.columns:
  187.                     # 双样本t检验
  188.                     t_stat, p_value = stats.ttest_ind(
  189.                         df[column1].dropna(),
  190.                         df[column2].dropna(),
  191.                         equal_var=False  # 不假设方差相等
  192.                     )
  193.                     result = {
  194.                         'test': 'Independent Samples t-test',
  195.                         't_statistic': t_stat,
  196.                         'p_value': p_value,
  197.                         'significant': p_value < 0.05
  198.                     }
  199.                     self.logger.info(f"完成独立样本t检验: {column1} vs {column2}")
  200.                 else:
  201.                     # 单样本t检验(与0比较)
  202.                     t_stat, p_value = stats.ttest_1samp(df[column1].dropna(), 0)
  203.                     result = {
  204.                         'test': 'One Sample t-test',
  205.                         't_statistic': t_stat,
  206.                         'p_value': p_value,
  207.                         'significant': p_value < 0.05
  208.                     }
  209.                     self.logger.info(f"完成单样本t检验: {column1}")
  210.             
  211.             elif test_type == 'chi2' and column2 and column2 in df.columns:
  212.                 # 卡方检验(分类变量)
  213.                 contingency_table = pd.crosstab(df[column1], df[column2])
  214.                 chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)
  215.                 result = {
  216.                     'test': 'Chi-square Test',
  217.                     'chi2_statistic': chi2,
  218.                     'p_value': p_value,
  219.                     'degrees_of_freedom': dof,
  220.                     'significant': p_value < 0.05
  221.                 }
  222.                 self.logger.info(f"完成卡方检验: {column1} vs {column2}")
  223.             
  224.             else:
  225.                 self.logger.warning(f"不支持的检验类型: {test_type}")
  226.             
  227.             return result
  228.         
  229.         except Exception as e:
  230.             self.logger.error(f"进行假设检验时出错: {e}")
  231.             return {'error': str(e)}
  232. # 使用示例
  233. def statistical_analysis_example():
  234.     # 创建示例数据
  235.     np.random.seed(42)
  236.     n_samples = 200
  237.    
  238.     # 生成特征
  239.     X = np.random.randn(n_samples, 3)  # 3个特征
  240.    
  241.     # 生成分类目标变量
  242.     y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0
  243.    
  244.     # 生成回归目标变量
  245.     y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  246.    
  247.     # 创建DataFrame
  248.     data = pd.DataFrame(
  249.         X,
  250.         columns=['feature_1', 'feature_2', 'feature_3']
  251.     )
  252.     data['target_class'] = y_class.astype(int)
  253.     data['target_reg'] = y_reg
  254.    
  255.     # 添加一些派生列
  256.     data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
  257.     data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
  258.     data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)
  259.    
  260.     # 创建统计分析器
  261.     analyzer = StatisticalAnalyzer()
  262.    
  263.     # 描述性统计
  264.     desc_stats = analyzer.describe_data(data)
  265.     print("描述性统计:")
  266.     print(desc_stats)
  267.    
  268.     # 相关性分析
  269.     corr_matrix = analyzer.correlation_analysis(data)
  270.     print("\n相关性矩阵:")
  271.     print(corr_matrix)
  272.    
  273.     # 频率分析
  274.     category_freq = analyzer.frequency_analysis(data, 'feature_1', normalize=True)
  275.     print("\n特征1频率分析:")
  276.     print(category_freq)
  277.    
  278.     # 分组分析
  279.     group_result = analyzer.group_analysis(
  280.         data,
  281.         'feature_1',
  282.         {'target_class': ['mean', 'sum'], 'feature_2': 'mean', 'feature_3': 'mean'}
  283.     )
  284.     print("\n分组分析结果:")
  285.     print(group_result)
  286.    
  287.     # 时间序列分析
  288.     ts_result = analyzer.time_series_analysis(data, 'feature_1', 'target_reg', freq='W')
  289.     print("\n时间序列分析结果(周均值):")
  290.     print(ts_result.head())
  291.    
  292.     # 假设检验
  293.     test_result = analyzer.hypothesis_testing(data, 'feature_1', test_type='ttest')
  294.     print("\n假设检验结果:")
  295.     print(test_result)
  296.    
  297.     return {
  298.         'desc_stats': desc_stats,
  299.         'corr_matrix': corr_matrix,
  300.         'category_freq': category_freq,
  301.         'group_result': group_result,
  302.         'ts_result': ts_result,
  303.         'test_result': test_result
  304.     }
  305. if __name__ == "__main__":
  306.     statistical_analysis_example()
复制代码
6.3 数据挖掘

数据挖掘是从大量数据中发现模式和关系的过程,包罗聚类分析、分类和回归模型等:
  1. import pandas as pd
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import seaborn as sns
  5. from sklearn.cluster import KMeans, DBSCAN
  6. from sklearn.preprocessing import StandardScaler, MinMaxScaler
  7. from sklearn.decomposition import PCA
  8. from sklearn.model_selection import train_test_split, cross_val_score
  9. from sklearn.linear_model import LinearRegression, LogisticRegression
  10. from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
  11. from sklearn.metrics import (
  12.     accuracy_score, precision_score, recall_score, f1_score,
  13.     mean_squared_error, r2_score, silhouette_score
  14. )
  15. import logging
  16. class DataMiner:
  17.     """数据挖掘类"""
  18.    
  19.     def __init__(self):
  20.         """初始化数据挖掘器"""
  21.         self.logger = self._setup_logger()
  22.         self.models = {}  # 存储训练好的模型
  23.    
  24.     def _setup_logger(self):
  25.         """设置日志记录器"""
  26.         logger = logging.getLogger('DataMiner')
  27.         logger.setLevel(logging.INFO)
  28.         
  29.         if not logger.handlers:
  30.             handler = logging.StreamHandler()
  31.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  32.             handler.setFormatter(formatter)
  33.             logger.addHandler(handler)
  34.         
  35.         return logger
  36.    
  37.     def preprocess_data(self, df, scale_method='standard', categorical_cols=None):
  38.         """数据预处理
  39.         
  40.         参数:
  41.             df: 输入DataFrame
  42.             scale_method: 缩放方法,可选'standard'、'minmax'
  43.             categorical_cols: 分类变量列名列表
  44.             
  45.         返回:
  46.             预处理后的DataFrame和预处理器
  47.         """
  48.         if df.empty:
  49.             self.logger.warning("输入DataFrame为空")
  50.             return df, None
  51.         
  52.         # 处理缺失值
  53.         df_clean = df.dropna()
  54.         if len(df_clean) < len(df):
  55.             self.logger.info(f"删除了 {len(df) - len(df_clean)} 行含有缺失值的数据")
  56.         
  57.         # 处理分类变量
  58.         if categorical_cols:
  59.             df_encoded = pd.get_dummies(df_clean, columns=categorical_cols)
  60.             self.logger.info(f"对 {len(categorical_cols)} 个分类变量进行了独热编码")
  61.         else:
  62.             df_encoded = df_clean
  63.         
  64.         # 数值变量缩放
  65.         numeric_cols = df_encoded.select_dtypes(include=['number']).columns
  66.         
  67.         if scale_method == 'standard':
  68.             scaler = StandardScaler()
  69.             self.logger.info("使用StandardScaler进行标准化")
  70.         elif scale_method == 'minmax':
  71.             scaler = MinMaxScaler()
  72.             self.logger.info("使用MinMaxScaler进行归一化")
  73.         else:
  74.             self.logger.warning(f"未知的缩放方法: {scale_method},不进行缩放")
  75.             return df_encoded, None
  76.         
  77.         if len(numeric_cols) > 0:
  78.             df_encoded[numeric_cols] = scaler.fit_transform(df_encoded[numeric_cols])
  79.             self.logger.info(f"对 {len(numeric_cols)} 个数值变量进行了缩放")
  80.         
  81.         return df_encoded, scaler
  82.    
  83.     def reduce_dimensions(self, df, n_components=2, method='pca'):
  84.         """降维
  85.         
  86.         参数:
  87.             df: 输入DataFrame
  88.             n_components: 目标维度
  89.             method: 降维方法,目前支持'pca'
  90.             
  91.         返回:
  92.             降维后的DataFrame和降维器
  93.         """
  94.         if df.empty:
  95.             self.logger.warning("输入DataFrame为空")
  96.             return df, None
  97.         
  98.         # 确保数据为数值型
  99.         numeric_df = df.select_dtypes(include=['number'])
  100.         
  101.         if numeric_df.empty:
  102.             self.logger.warning("没有数值型列可进行降维")
  103.             return df, None
  104.         
  105.         try:
  106.             if method == 'pca':
  107.                 reducer = PCA(n_components=n_components)
  108.                 reduced_data = reducer.fit_transform(numeric_df)
  109.                
  110.                 # 创建包含降维结果的DataFrame
  111.                 result_df = pd.DataFrame(
  112.                     reduced_data,
  113.                     columns=[f'PC{i+1}' for i in range(n_components)],
  114.                     index=df.index
  115.                 )
  116.                
  117.                 # 计算解释方差比例
  118.                 explained_variance = reducer.explained_variance_ratio_.sum()
  119.                 self.logger.info(f"PCA降维完成,保留了 {n_components} 个主成分,解释了 {explained_variance:.2%} 的方差")
  120.                
  121.                 return result_df, reducer
  122.             else:
  123.                 self.logger.warning(f"不支持的降维方法: {method}")
  124.                 return df, None
  125.         except Exception as e:
  126.             self.logger.error(f"降维过程中出错: {e}")
  127.             return df, None
  128.    
  129.     def cluster_data(self, df, method='kmeans', n_clusters=3, eps=0.5, min_samples=5):
  130.         """聚类分析
  131.         
  132.         参数:
  133.             df: 输入DataFrame
  134.             method: 聚类方法,可选'kmeans'、'dbscan'
  135.             n_clusters: KMeans的簇数量
  136.             eps: DBSCAN的邻域半径
  137.             min_samples: DBSCAN的最小样本数
  138.             
  139.         返回:
  140.             带有聚类标签的DataFrame和聚类器
  141.         """
  142.         if df.empty:
  143.             self.logger.warning("输入DataFrame为空")
  144.             return df, None
  145.         
  146.         # 确保数据为数值型
  147.         numeric_df = df.select_dtypes(include=['number'])
  148.         
  149.         if numeric_df.empty:
  150.             self.logger.warning("没有数值型列可进行聚类")
  151.             return df, None
  152.         
  153.         try:
  154.             result_df = df.copy()
  155.             
  156.             if method == 'kmeans':
  157.                 # K-means聚类
  158.                 clusterer = KMeans(n_clusters=n_clusters, random_state=42)
  159.                 labels = clusterer.fit_predict(numeric_df)
  160.                
  161.                 # 计算轮廓系数
  162.                 if n_clusters > 1 and len(numeric_df) > n_clusters:
  163.                     silhouette = silhouette_score(numeric_df, labels)
  164.                     self.logger.info(f"K-means聚类完成,轮廓系数: {silhouette:.4f}")
  165.                 else:
  166.                     self.logger.info("K-means聚类完成,但无法计算轮廓系数(簇数过少或数据量不足)")
  167.                
  168.             elif method == 'dbscan':
  169.                 # DBSCAN聚类
  170.                 clusterer = DBSCAN(eps=eps, min_samples=min_samples)
  171.                 labels = clusterer.fit_predict(numeric_df)
  172.                
  173.                 # 计算聚类统计信息
  174.                 n_clusters_found = len(set(labels)) - (1 if -1 in labels else 0)
  175.                 n_noise = list(labels).count(-1)
  176.                 self.logger.info(f"DBSCAN聚类完成,发现 {n_clusters_found} 个簇,{n_noise} 个噪声点")
  177.                
  178.             else:
  179.                 self.logger.warning(f"不支持的聚类方法: {method}")
  180.                 return df, None
  181.             
  182.             # 添加聚类标签
  183.             result_df['cluster'] = labels
  184.             
  185.             return result_df, clusterer
  186.         
  187.         except Exception as e:
  188.             self.logger.error(f"聚类过程中出错: {e}")
  189.             return df, None
  190.    
  191.     def train_classifier(self, df, target_col, feature_cols=None, model_type='random_forest', test_size=0.2):
  192.         """训练分类模型
  193.         
  194.         参数:
  195.             df: 输入DataFrame
  196.             target_col: 目标变量列名
  197.             feature_cols: 特征列名列表,默认使用所有数值列
  198.             model_type: 模型类型,可选'logistic'、'random_forest'
  199.             test_size: 测试集比例
  200.             
  201.         返回:
  202.             模型评估指标字典和训练好的模型
  203.         """
  204.         if df.empty or target_col not in df.columns:
  205.             self.logger.warning(f"输入DataFrame为空或不包含目标列 {target_col}")
  206.             return {}, None
  207.         
  208.         try:
  209.             # 准备特征和目标变量
  210.             if feature_cols is None:
  211.                 # 使用除目标列外的所有数值列作为特征
  212.                 feature_cols = df.select_dtypes(include=['number']).columns.tolist()
  213.                 if target_col in feature_cols:
  214.                     feature_cols.remove(target_col)
  215.             
  216.             if not feature_cols:
  217.                 self.logger.warning("没有可用的特征列")
  218.                 return {}, None
  219.             
  220.             X = df[feature_cols]
  221.             y = df[target_col]
  222.             
  223.             # 划分训练集和测试集
  224.             X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)
  225.             
  226.             # 训练模型
  227.             if model_type == 'logistic':
  228.                 model = LogisticRegression(max_iter=1000, random_state=42)
  229.                 model_name = 'Logistic Regression'
  230.             elif model_type == 'random_forest':
  231.                 model = RandomForestClassifier(n_estimators=100, random_state=42)
  232.                 model_name = 'Random Forest'
  233.             else:
  234.                 self.logger.warning(f"不支持的分类模型类型: {model_type}")
  235.                 return {}, None
  236.             
  237.             model.fit(X_train, y_train)
  238.             
  239.             # 在测试集上评估
  240.             y_pred = model.predict(X_test)
  241.             
  242.             # 计算评估指标
  243.             metrics = {
  244.                 'accuracy': accuracy_score(y_test, y_pred),
  245.                 'precision': precision_score(y_test, y_pred, average='weighted'),
  246.                 'recall': recall_score(y_test, y_pred, average='weighted'),
  247.                 'f1': f1_score(y_test, y_pred, average='weighted')
  248.             }
  249.             
  250.             # 交叉验证
  251.             cv_scores = cross_val_score(model, X, y, cv=5)
  252.             metrics['cv_accuracy_mean'] = cv_scores.mean()
  253.             metrics['cv_accuracy_std'] = cv_scores.std()
  254.             
  255.             self.logger.info(f"{model_name}分类模型训练完成,准确率: {metrics['accuracy']:.4f}")
  256.             
  257.             # 存储模型
  258.             model_id = f"{model_type}_classifier_{target_col}"
  259.             self.models[model_id] = {
  260.                 'model': model,
  261.                 'feature_cols': feature_cols,
  262.                 'target_col': target_col,
  263.                 'metrics': metrics
  264.             }
  265.             
  266.             return metrics, model
  267.         
  268.         except Exception as e:
  269.             self.logger.error(f"训练分类模型时出错: {e}")
  270.             return {}, None
  271.    
  272.     def train_regressor(self, df, target_col, feature_cols=None, model_type='linear', test_size=0.2):
  273.         """训练回归模型
  274.         
  275.         参数:
  276.             df: 输入DataFrame
  277.             target_col: 目标变量列名
  278.             feature_cols: 特征列名列表,默认使用所有数值列
  279.             model_type: 模型类型,可选'linear'、'random_forest'
  280.             test_size: 测试集比例
  281.             
  282.         返回:
  283.             模型评估指标字典和训练好的模型
  284.         """
  285.         if df.empty or target_col not in df.columns:
  286.             self.logger.warning(f"输入DataFrame为空或不包含目标列 {target_col}")
  287.             return {}, None
  288.         
  289.         try:
  290.             # 准备特征和目标变量
  291.             if feature_cols is None:
  292.                 # 使用除目标列外的所有数值列作为特征
  293.                 feature_cols = df.select_dtypes(include=['number']).columns.tolist()
  294.                 if target_col in feature_cols:
  295.                     feature_cols.remove(target_col)
  296.             
  297.             if not feature_cols:
  298.                 self.logger.warning("没有可用的特征列")
  299.                 return {}, None
  300.             
  301.             X = df[feature_cols]
  302.             y = df[target_col]
  303.             
  304.             # 划分训练集和测试集
  305.             X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)
  306.             
  307.             # 训练模型
  308.             if model_type == 'linear':
  309.                 model = LinearRegression()
  310.                 model_name = 'Linear Regression'
  311.             elif model_type == 'random_forest':
  312.                 model = RandomForestRegressor(n_estimators=100, random_state=42)
  313.                 model_name = 'Random Forest'
  314.             else:
  315.                 self.logger.warning(f"不支持的回归模型类型: {model_type}")
  316.                 return {}, None
  317.             
  318.             model.fit(X_train, y_train)
  319.             
  320.             # 在测试集上评估
  321.             y_pred = model.predict(X_test)
  322.             
  323.             # 计算评估指标
  324.             metrics = {
  325.                 'mse': mean_squared_error(y_test, y_pred),
  326.                 'rmse': np.sqrt(mean_squared_error(y_test, y_pred)),
  327.                 'r2': r2_score(y_test, y_pred)
  328.             }
  329.             
  330.             # 交叉验证
  331.             cv_scores = cross_val_score(model, X, y, cv=5, scoring='r2')
  332.             metrics['cv_r2_mean'] = cv_scores.mean()
  333.             metrics['cv_r2_std'] = cv_scores.std()
  334.             
  335.             self.logger.info(f"{model_name}回归模型训练完成,R²: {metrics['r2']:.4f}")
  336.             
  337.             # 存储模型
  338.             model_id = f"{model_type}_regressor_{target_col}"
  339.             self.models[model_id] = {
  340.                 'model': model,
  341.                 'feature_cols': feature_cols,
  342.                 'target_col': target_col,
  343.                 'metrics': metrics
  344.             }
  345.             
  346.             return metrics, model
  347.         
  348.         except Exception as e:
  349.             self.logger.error(f"训练回归模型时出错: {e}")
  350.             return {}, None
  351.    
  352.     def predict(self, model_id, new_data):
  353.         """使用训练好的模型进行预测
  354.         
  355.         参数:
  356.             model_id: 模型ID
  357.             new_data: 新数据DataFrame
  358.             
  359.         返回:
  360.             预测结果
  361.         """
  362.         if model_id not in self.models:
  363.             self.logger.warning(f"模型ID {model_id} 不存在")
  364.             return None
  365.         
  366.         model_info = self.models[model_id]
  367.         model = model_info['model']
  368.         feature_cols = model_info['feature_cols']
  369.         
  370.         # 检查新数据是否包含所有特征列
  371.         missing_cols = [col for col in feature_cols if col not in new_data.columns]
  372.         if missing_cols:
  373.             self.logger.warning(f"新数据缺少特征列: {missing_cols}")
  374.             return None
  375.         
  376.         try:
  377.             # 提取特征
  378.             X_new = new_data[feature_cols]
  379.             
  380.             # 进行预测
  381.             predictions = model.predict(X_new)
  382.             
  383.             self.logger.info(f"使用模型 {model_id} 完成预测,预测样本数: {len(predictions)}")
  384.             
  385.             return predictions
  386.         
  387.         except Exception as e:
  388.             self.logger.error(f"预测过程中出错: {e}")
  389.             return None
  390.    
  391.     def get_feature_importance(self, model_id):
  392.         """获取特征重要性
  393.         
  394.         参数:
  395.             model_id: 模型ID
  396.             
  397.         返回:
  398.             特征重要性DataFrame
  399.         """
  400.         if model_id not in self.models:
  401.             self.logger.warning(f"模型ID {model_id} 不存在")
  402.             return pd.DataFrame()
  403.         
  404.         model_info = self.models[model_id]
  405.         model = model_info['model']
  406.         feature_cols = model_info['feature_cols']
  407.         
  408.         # 检查模型是否有feature_importances_属性
  409.         if not hasattr(model, 'feature_importances_'):
  410.             self.logger.warning(f"模型 {model_id} 不支持特征重要性分析")
  411.             
  412.             # 对于线性模型,可以使用系数作为特征重要性
  413.             if hasattr(model, 'coef_'):
  414.                 importances = np.abs(model.coef_)
  415.                 if importances.ndim > 1:
  416.                     importances = importances.mean(axis=0)
  417.             else:
  418.                 return pd.DataFrame()
  419.         else:
  420.             importances = model.feature_importances_
  421.         
  422.         # 创建特征重要性DataFrame
  423.         importance_df = pd.DataFrame({
  424.             'feature': feature_cols,
  425.             'importance': importances
  426.         })
  427.         
  428.         # 按重要性降序排序
  429.         importance_df = importance_df.sort_values('importance', ascending=False)
  430.         
  431.         self.logger.info(f"获取模型 {model_id} 的特征重要性")
  432.         
  433.         return importance_df
  434. # 使用示例
  435. def data_mining_example():
  436.     # 创建示例数据
  437.     np.random.seed(42)
  438.     n_samples = 200
  439.    
  440.     # 生成特征
  441.     X = np.random.randn(n_samples, 5)  # 5个特征
  442.    
  443.     # 生成分类目标变量
  444.     y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0
  445.    
  446.     # 生成回归目标变量
  447.     y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  448.    
  449.     # 创建DataFrame
  450.     data = pd.DataFrame(
  451.         X,
  452.         columns=[f'feature_{i+1}' for i in range(5)]
  453.     )
  454.     data['target_class'] = y_class.astype(int)
  455.     data['target_reg'] = y_reg
  456.    
  457.     # 添加一些派生列
  458.     data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
  459.     data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
  460.     data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)
  461.    
  462.     # 创建数据挖掘器
  463.     miner = DataMiner()
  464.    
  465.     # 数据预处理
  466.     print("数据预处理...")
  467.     data_processed, scaler = miner.preprocess_data(
  468.         data,
  469.         scale_method='standard',
  470.         categorical_cols=['month', 'day_of_week']
  471.     )
  472.    
  473.     # 降维分析
  474.     print("\n降维分析...")
  475.     data_reduced, pca = miner.reduce_dimensions(
  476.         data_processed.drop(['target_class', 'target_reg'], axis=1),
  477.         n_components=2
  478.     )
  479.    
  480.     # 聚类分析
  481.     print("\n聚类分析...")
  482.     data_clustered, kmeans = miner.cluster_data(
  483.         data_reduced,
  484.         method='kmeans',
  485.         n_clusters=3
  486.     )
  487.    
  488.     # 分类模型
  489.     print("\n训练分类模型...")
  490.     class_metrics, classifier = miner.train_classifier(
  491.         data_processed,
  492.         target_col='target_class',
  493.         model_type='random_forest'
  494.     )
  495.     print(f"分类模型评估指标: {class_metrics}")
  496.    
  497.     # 回归模型
  498.     print("\n训练回归模型...")
  499.     reg_metrics, regressor = miner.train_regressor(
  500.         data_processed,
  501.         target_col='target_reg',
  502.         model_type='random_forest'
  503.     )
  504.     print(f"回归模型评估指标: {reg_metrics}")
  505.    
  506.     # 特征重要性
  507.     print("\n特征重要性分析...")
  508.     importance = miner.get_feature_importance('random_forest_regressor_target_reg')
  509.     print(importance)
  510.    
  511.     return {
  512.         'data_processed': data_processed,
  513.         'data_reduced': data_reduced,
  514.         'data_clustered': data_clustered,
  515.         'class_metrics': class_metrics,
  516.         'reg_metrics': reg_metrics,
  517.         'feature_importance': importance
  518.     }
  519. if __name__ == "__main__":
  520.     data_mining_example()
复制代码
6.4 特性工程

特性工程是数据分析和机器学习中至关重要的一步,它可以明显进步模型性能:
  1. import pandas as pd
  2. import numpy as np
  3. from sklearn.preprocessing import PolynomialFeatures
  4. from sklearn.feature_selection import SelectKBest, f_regression, mutual_info_regression
  5. import logging
  6. class FeatureEngineer:
  7.     """特征工程类"""
  8.    
  9.     def __init__(self):
  10.         """初始化特征工程器"""
  11.         self.logger = self._setup_logger()
  12.    
  13.     def _setup_logger(self):
  14.         """设置日志记录器"""
  15.         logger = logging.getLogger('FeatureEngineer')
  16.         logger.setLevel(logging.INFO)
  17.         
  18.         if not logger.handlers:
  19.             handler = logging.StreamHandler()
  20.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  21.             handler.setFormatter(formatter)
  22.             logger.addHandler(handler)
  23.         
  24.         return logger
  25.    
  26.     def create_polynomial_features(self, df, feature_cols, degree=2, include_bias=False):
  27.         """创建多项式特征
  28.         
  29.         参数:
  30.             df: 输入DataFrame
  31.             feature_cols: 特征列名列表
  32.             degree: 多项式次数
  33.             include_bias: 是否包含偏置项
  34.             
  35.         返回:
  36.             包含多项式特征的DataFrame
  37.         """
  38.         if df.empty or not feature_cols:
  39.             self.logger.warning("输入DataFrame为空或未指定特征列")
  40.             return df
  41.         
  42.         try:
  43.             # 提取特征
  44.             X = df[feature_cols].values
  45.             
  46.             # 创建多项式特征
  47.             poly = PolynomialFeatures(degree=degree, include_bias=include_bias)
  48.             poly_features = poly.fit_transform(X)
  49.             
  50.             # 创建特征名称
  51.             feature_names = poly.get_feature_names_out(feature_cols)
  52.             
  53.             # 创建包含多项式特征的DataFrame
  54.             poly_df = pd.DataFrame(poly_features, columns=feature_names, index=df.index)
  55.             
  56.             # 合并原始DataFrame和多项式特征
  57.             result_df = pd.concat([df.drop(feature_cols, axis=1), poly_df], axis=1)
  58.             
  59.             self.logger.info(f"创建了 {poly_features.shape[1]} 个多项式特征,次数: {degree}")
  60.             
  61.             return result_df
  62.         
  63.         except Exception as e:
  64.             self.logger.error(f"创建多项式特征时出错: {e}")
  65.             return df
  66.    
  67.     def create_interaction_features(self, df, feature_cols):
  68.         """创建交互特征
  69.         
  70.         参数:
  71.             df: 输入DataFrame
  72.             feature_cols: 特征列名列表
  73.             
  74.         返回:
  75.             包含交互特征的DataFrame
  76.         """
  77.         if df.empty or len(feature_cols) < 2:
  78.             self.logger.warning("输入DataFrame为空或特征列不足")
  79.             return df
  80.         
  81.         try:
  82.             result_df = df.copy()
  83.             interaction_count = 0
  84.             
  85.             # 创建两两特征的交互项
  86.             for i in range(len(feature_cols)):
  87.                 for j in range(i+1, len(feature_cols)):
  88.                     col1 = feature_cols[i]
  89.                     col2 = feature_cols[j]
  90.                     
  91.                     # 创建交互特征
  92.                     interaction_name = f"{col1}_x_{col2}"
  93.                     result_df[interaction_name] = df[col1] * df[col2]
  94.                     interaction_count += 1
  95.             
  96.             self.logger.info(f"创建了 {interaction_count} 个交互特征")
  97.             
  98.             return result_df
  99.         
  100.         except Exception as e:
  101.             self.logger.error(f"创建交互特征时出错: {e}")
  102.             return df
  103.    
  104.     def create_binning_features(self, df, feature_col, bins=5, strategy='uniform'):
  105.         """创建分箱特征
  106.         
  107.         参数:
  108.             df: 输入DataFrame
  109.             feature_col: 要分箱的特征列名
  110.             bins: 分箱数量或边界列表
  111.             strategy: 分箱策略,可选'uniform'、'quantile'
  112.             
  113.         返回:
  114.             包含分箱特征的DataFrame
  115.         """
  116.         if df.empty or feature_col not in df.columns:
  117.             self.logger.warning(f"输入DataFrame为空或不包含列 {feature_col}")
  118.             return df
  119.         
  120.         try:
  121.             result_df = df.copy()
  122.             
  123.             # 确定分箱边界
  124.             if isinstance(bins, int):
  125.                 if strategy == 'uniform':
  126.                     # 均匀分箱
  127.                     bin_edges = np.linspace(
  128.                         df[feature_col].min(),
  129.                         df[feature_col].max(),
  130.                         bins + 1
  131.                     )
  132.                 elif strategy == 'quantile':
  133.                     # 分位数分箱
  134.                     bin_edges = np.percentile(
  135.                         df[feature_col],
  136.                         np.linspace(0, 100, bins + 1)
  137.                     )
  138.                 else:
  139.                     self.logger.warning(f"不支持的分箱策略: {strategy}")
  140.                     return df
  141.             else:
  142.                 # 使用指定的分箱边界
  143.                 bin_edges = bins
  144.             
  145.             # 创建分箱特征
  146.             binned_feature = pd.cut(
  147.                 df[feature_col],
  148.                 bins=bin_edges,
  149.                 labels=False,
  150.                 include_lowest=True
  151.             )
  152.             
  153.             # 添加分箱特征
  154.             result_df[f"{feature_col}_bin"] = binned_feature
  155.             
  156.             # 创建独热编码的分箱特征
  157.             bin_dummies = pd.get_dummies(
  158.                 binned_feature,
  159.                 prefix=f"{feature_col}_bin",
  160.                 prefix_sep="_"
  161.             )
  162.             
  163.             # 合并结果
  164.             result_df = pd.concat([result_df, bin_dummies], axis=1)
  165.             
  166.             self.logger.info(f"对特征 {feature_col} 创建了 {len(bin_edges)-1} 个分箱特征")
  167.             
  168.             return result_df
  169.         
  170.         except Exception as e:
  171.             self.logger.error(f"创建分箱特征时出错: {e}")
  172.             return df
  173.    
  174.     def select_best_features(self, df, feature_cols, target_col, k=5, method='f_regression'):
  175.         """选择最佳特征
  176.         
  177.         参数:
  178.             df: 输入DataFrame
  179.             feature_cols: 特征列名列表
  180.             target_col: 目标变量列名
  181.             k: 选择的特征数量
  182.             method: 特征选择方法,可选'f_regression'、'mutual_info'
  183.             
  184.         返回:
  185.             包含选定特征的DataFrame和特征得分
  186.         """
  187.         if df.empty or not feature_cols or target_col not in df.columns:
  188.             self.logger.warning("输入DataFrame为空或未指定特征列或目标列")
  189.             return df, {}
  190.         
  191.         try:
  192.             # 提取特征和目标变量
  193.             X = df[feature_cols]
  194.             y = df[target_col]
  195.             
  196.             # 选择特征选择器
  197.             if method == 'f_regression':
  198.                 selector = SelectKBest(score_func=f_regression, k=k)
  199.                 method_name = "F回归"
  200.             elif method == 'mutual_info':
  201.                 selector = SelectKBest(score_func=mutual_info_regression, k=k)
  202.                 method_name = "互信息"
  203.             else:
  204.                 self.logger.warning(f"不支持的特征选择方法: {method}")
  205.                 return df, {}
  206.             
  207.             # 拟合选择器
  208.             selector.fit(X, y)
  209.             
  210.             # 获取选定的特征索引
  211.             selected_indices = selector.get_support(indices=True)
  212.             selected_features = [feature_cols[i] for i in selected_indices]
  213.             
  214.             # 创建特征得分字典
  215.             feature_scores = dict(zip(feature_cols, selector.scores_))
  216.             
  217.             # 创建包含选定特征的DataFrame
  218.             result_df = df.copy()
  219.             dropped_features = [col for col in feature_cols if col not in selected_features]
  220.             if dropped_features:
  221.                 result_df = result_df.drop(dropped_features, axis=1)
  222.             
  223.             self.logger.info(f"使用 {method_name} 方法选择了 {len(selected_features)} 个最佳特征")
  224.             
  225.             return result_df, feature_scores
  226.         
  227.         except Exception as e:
  228.             self.logger.error(f"选择最佳特征时出错: {e}")
  229.             return df, {}
  230. # 使用示例
  231. def feature_engineering_example():
  232.     # 创建示例数据
  233.     np.random.seed(42)
  234.     n_samples = 200
  235.    
  236.     # 生成特征
  237.     X = np.random.randn(n_samples, 3)  # 3个特征
  238.    
  239.     # 生成目标变量(回归)
  240.     y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  241.    
  242.     # 创建DataFrame
  243.     data = pd.DataFrame(
  244.         X,
  245.         columns=['feature_1', 'feature_2', 'feature_3']
  246.     )
  247.     data['target'] = y
  248.    
  249.     # 创建特征工程器
  250.     engineer = FeatureEngineer()
  251.    
  252.     # 创建多项式特征
  253.     print("创建多项式特征...")
  254.     poly_data = engineer.create_polynomial_features(
  255.         data,
  256.         ['feature_1', 'feature_2', 'feature_3'],
  257.         degree=2
  258.     )
  259.     print(f"多项式特征后的列: {poly_data.columns.tolist()}")
  260.    
  261.     # 创建交互特征
  262.     print("\n创建交互特征...")
  263.     interaction_data = engineer.create_interaction_features(
  264.         data,
  265.         ['feature_1', 'feature_2', 'feature_3']
  266.     )
  267.     print(f"交互特征后的列: {interaction_data.columns.tolist()}")
  268.    
  269.     # 创建分箱特征
  270.     print("\n创建分箱特征...")
  271.     binned_data = engineer.create_binning_features(
  272.         data,
  273.         'feature_1',
  274.         bins=5,
  275.         strategy='quantile'
  276.     )
  277.     print(f"分箱特征后的列: {binned_data.columns.tolist()}")
  278.    
  279.     # 特征选择
  280.     print("\n特征选择...")
  281.     # 首先创建更多特征用于选择
  282.     combined_data = engineer.create_polynomial_features(
  283.         data,
  284.         ['feature_1', 'feature_2', 'feature_3'],
  285.         degree=2
  286.     )
  287.    
  288.     # 选择最佳特征
  289.     selected_data, feature_scores = engineer.select_best_features(
  290.         combined_data,
  291.         [col for col in combined_data.columns if col != 'target'],
  292.         'target',
  293.         k=5,
  294.         method='f_regression'
  295.     )
  296.    
  297.     print("特征得分:")
  298.     for feature, score in sorted(feature_scores.items(), key=lambda x: x[1], reverse=True):
  299.         print(f"{feature}: {score:.4f}")
  300.    
  301.     print(f"\n选择的特征: {[col for col in selected_data.columns if col != 'target']}")
  302.    
  303.     return {
  304.         'original_data': data,
  305.         'poly_data': poly_data,
  306.         'interaction_data': interaction_data,
  307.         'binned_data': binned_data,
  308.         'selected_data': selected_data,
  309.         'feature_scores': feature_scores
  310.     }
  311. if __name__ == "__main__":
  312.     feature_engineering_example()
复制代码
6.5 数据分析模块集成

以下是如何将数据清洗、统计分析、数据挖掘和特性工程组件集成到一个完备的数据分析流程中:
  1. def complete_data_analysis_pipeline(data, config=None):
  2.     """完整的数据分析流程
  3.    
  4.     参数:
  5.         data: 输入DataFrame
  6.         config: 配置字典
  7.         
  8.     返回:
  9.         分析结果字典
  10.     """
  11.     if config is None:
  12.         config = {}
  13.    
  14.     results = {'original_data': data}
  15.    
  16.     # 1. 数据清洗
  17.     print("1. 执行数据清洗...")
  18.     cleaner = DataCleaner()
  19.     clean_config = config.get('cleaning', {})
  20.     cleaned_data = cleaner.clean_data(data, clean_config)
  21.     results['cleaned_data'] = cleaned_data
  22.    
  23.     # 2. 统计分析
  24.     print("\n2. 执行统计分析...")
  25.     analyzer = StatisticalAnalyzer()
  26.    
  27.     # 描述性统计
  28.     desc_stats = analyzer.describe_data(cleaned_data)
  29.     results['descriptive_stats'] = desc_stats
  30.    
  31.     # 相关性分析
  32.     corr_matrix = analyzer.correlation_analysis(cleaned_data)
  33.     results['correlation_matrix'] = corr_matrix
  34.    
  35.     # 3. 特征工程
  36.     print("\n3. 执行特征工程...")
  37.     engineer = FeatureEngineer()
  38.     feature_config = config.get('feature_engineering', {})
  39.    
  40.     engineered_data = cleaned_data.copy()
  41.    
  42.     # 应用多项式特征
  43.     if 'polynomial' in feature_config:
  44.         poly_config = feature_config['polynomial']
  45.         engineered_data = engineer.create_polynomial_features(
  46.             engineered_data,
  47.             poly_config.get('features', []),
  48.             degree=poly_config.get('degree', 2)
  49.         )
  50.    
  51.     # 应用交互特征
  52.     if 'interaction' in feature_config:
  53.         interaction_config = feature_config['interaction']
  54.         engineered_data = engineer.create_interaction_features(
  55.             engineered_data,
  56.             interaction_config.get('features', [])
  57.         )
  58.    
  59.     # 应用分箱特征
  60.     if 'binning' in feature_config:
  61.         for bin_config in feature_config['binning']:
  62.             engineered_data = engineer.create_binning_features(
  63.                 engineered_data,
  64.                 bin_config.get('feature'),
  65.                 bins=bin_config.get('bins', 5),
  66.                 strategy=bin_config.get('strategy', 'uniform')
  67.             )
  68.    
  69.     results['engineered_data'] = engineered_data
  70.    
  71.     # 4. 数据挖掘
  72.     print("\n4. 执行数据挖掘...")
  73.     miner = DataMiner()
  74.     mining_config = config.get('mining', {})
  75.    
  76.     # 数据预处理
  77.     processed_data, scaler = miner.preprocess_data(
  78.         engineered_data,
  79.         scale_method=mining_config.get('scale_method', 'standard'),
  80.         categorical_cols=mining_config.get('categorical_cols', [])
  81.     )
  82.     results['processed_data'] = processed_data
  83.    
  84.     # 降维分析
  85.     if 'dimensionality_reduction' in mining_config:
  86.         dr_config = mining_config['dimensionality_reduction']
  87.         reduced_data, reducer = miner.reduce_dimensions(
  88.             processed_data,
  89.             n_components=dr_config.get('n_components', 2),
  90.             method=dr_config.get('method', 'pca')
  91.         )
  92.         results['reduced_data'] = reduced_data
  93.    
  94.     # 聚类分析
  95.     if 'clustering' in mining_config:
  96.         cluster_config = mining_config['clustering']
  97.         data_to_cluster = results.get('reduced_data', processed_data)
  98.         clustered_data, clusterer = miner.cluster_data(
  99.             data_to_cluster,
  100.             method=cluster_config.get('method', 'kmeans'),
  101.             n_clusters=cluster_config.get('n_clusters', 3)
  102.         )
  103.         results['clustered_data'] = clustered_data
  104.    
  105.     # 模型训练
  106.     if 'models' in mining_config:
  107.         models_results = {}
  108.         
  109.         for model_config in mining_config['models']:
  110.             model_type = model_config.get('type')
  111.             target = model_config.get('target')
  112.             features = model_config.get('features')
  113.             
  114.             if model_type == 'classifier':
  115.                 metrics, model = miner.train_classifier(
  116.                     processed_data,
  117.                     target_col=target,
  118.                     feature_cols=features,
  119.                     model_type=model_config.get('algorithm', 'random_forest')
  120.                 )
  121.                 models_results[f'classifier_{target}'] = {
  122.                     'metrics': metrics,
  123.                     'model_id': f"{model_config.get('algorithm', 'random_forest')}_classifier_{target}"
  124.                 }
  125.                
  126.             elif model_type == 'regressor':
  127.                 metrics, model = miner.train_regressor(
  128.                     processed_data,
  129.                     target_col=target,
  130.                     feature_cols=features,
  131.                     model_type=model_config.get('algorithm', 'random_forest')
  132.                 )
  133.                 models_results[f'regressor_{target}'] = {
  134.                     'metrics': metrics,
  135.                     'model_id': f"{model_config.get('algorithm', 'random_forest')}_regressor_{target}"
  136.                 }
  137.         
  138.         results['models'] = models_results
  139.    
  140.     print("\n数据分析流程完成!")
  141.     return results
  142. # 使用示例
  143. def data_analysis_example():
  144.     # 创建示例数据
  145.     np.random.seed(42)
  146.     n_samples = 500
  147.    
  148.     # 生成特征
  149.     X = np.random.randn(n_samples, 4)  # 4个特征
  150.    
  151.     # 生成分类目标变量
  152.     y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0
  153.    
  154.     # 生成回归目标变量
  155.     y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  156.    
  157.     # 创建DataFrame
  158.     data = pd.DataFrame(
  159.         X,
  160.         columns=[f'feature_{i+1}' for i in range(4)]
  161.     )
  162.     data['category'] = np.random.choice(['A', 'B', 'C', 'D'], n_samples)
  163.     data['target_class'] = y_class.astype(int)
  164.     data['target_reg'] = y_reg
  165.    
  166.     # 添加一些派生列
  167.     data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
  168.     data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
  169.     data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)
  170.    
  171.     # 配置分析流程
  172.     config = {
  173.         'cleaning': {
  174.             'missing_values': {'strategy': 'drop'},
  175.             'remove_duplicates': True,
  176.             'outliers': {
  177.                 'columns': ['feature_1', 'feature_2', 'feature_3', 'feature_4'],
  178.                 'method': 'zscore',
  179.                 'threshold': 3.0
  180.             },
  181.             'text_columns': [],
  182.             'type_conversions': {},
  183.             'date_columns': {}
  184.         },
  185.         'feature_engineering': {
  186.             'polynomial': {
  187.                 'features': ['feature_1', 'feature_2'],
  188.                 'degree': 2
  189.             },
  190.             'interaction': {
  191.                 'features': ['feature_1', 'feature_2', 'feature_3']
  192.             },
  193.             'binning': [
  194.                 {
  195.                     'feature': 'feature_4',
  196.                     'bins': 5,
  197.                     'strategy': 'quantile'
  198.                 }
  199.             ]
  200.         },
  201.         'mining': {
  202.             'scale_method': 'standard',
  203.             'categorical_cols': ['category'],
  204.             'dimensionality_reduction': {
  205.                 'n_components': 2,
  206.                 'method': 'pca'
  207.             },
  208.             'clustering': {
  209.                 'method': 'kmeans',
  210.                 'n_clusters': 3
  211.             },
  212.             'models': [
  213.                 {
  214.                     'type': 'classifier',
  215.                     'target': 'target_class',
  216.                     'algorithm': 'random_forest'
  217.                 },
  218.                 {
  219.                     'type': 'regressor',
  220.                     'target': 'target_reg',
  221.                     'algorithm': 'random_forest'
  222.                 }
  223.             ]
  224.         }
  225.     }
  226.    
  227.     # 执行分析流程
  228.     results = complete_data_analysis_pipeline(data, config)
  229.    
  230.     # 打印部分结果
  231.     print("\n描述性统计:")
  232.     print(results['descriptive_stats'])
  233.    
  234.     print("\n模型性能:")
  235.     for model_name, model_info in results['models'].items():
  236.         print(f"{model_name}: {model_info['metrics']}")
  237.    
  238.     return results
  239. if __name__ == "__main__":
  240.     data_analysis_example()
复制代码
7. 数据可视化模块

数据可视化是将数据转化为图形表现的过程,通过视觉元素如图表、图形和地图,使复杂数据更容易理解和分析。
7.1 静态可视化

静态可视化是指天生不可交互的图表,重要使用Matplotlib和Seaborn库:
  1. import matplotlib.pyplot as plt
  2. import seaborn as sns
  3. import pandas as pd
  4. import numpy as np
  5. import matplotlib.ticker as ticker
  6. from matplotlib.colors import LinearSegmentedColormap
  7. import logging
  8. from pathlib import Path
  9. class StaticVisualizer:
  10.     """静态可视化类"""
  11.    
  12.     def __init__(self, output_dir='visualizations'):
  13.         """初始化可视化器
  14.         
  15.         参数:
  16.             output_dir: 输出目录
  17.         """
  18.         self.output_dir = output_dir
  19.         self.logger = self._setup_logger()
  20.         self._setup_style()
  21.         self._ensure_output_dir()
  22.    
  23.     def _setup_logger(self):
  24.         """设置日志记录器"""
  25.         logger = logging.getLogger('StaticVisualizer')
  26.         logger.setLevel(logging.INFO)
  27.         
  28.         if not logger.handlers:
  29.             handler = logging.StreamHandler()
  30.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  31.             handler.setFormatter(formatter)
  32.             logger.addHandler(handler)
  33.         
  34.         return logger
  35.    
  36.     def _setup_style(self):
  37.         """设置可视化样式"""
  38.         # 设置Seaborn样式
  39.         sns.set(style="whitegrid")
  40.         
  41.         # 设置Matplotlib参数
  42.         plt.rcParams['figure.figsize'] = (10, 6)
  43.         plt.rcParams['font.size'] = 12
  44.         plt.rcParams['axes.labelsize'] = 14
  45.         plt.rcParams['axes.titlesize'] = 16
  46.         plt.rcParams['xtick.labelsize'] = 12
  47.         plt.rcParams['ytick.labelsize'] = 12
  48.         plt.rcParams['legend.fontsize'] = 12
  49.         plt.rcParams['figure.titlesize'] = 20
  50.    
  51.     def _ensure_output_dir(self):
  52.         """确保输出目录存在"""
  53.         Path(self.output_dir).mkdir(parents=True, exist_ok=True)
  54.         self.logger.info(f"输出目录: {self.output_dir}")
  55.    
  56.     def save_figure(self, fig, filename, dpi=300):
  57.         """保存图表
  58.         
  59.         参数:
  60.             fig: 图表对象
  61.             filename: 文件名
  62.             dpi: 分辨率
  63.         """
  64.         filepath = Path(self.output_dir) / filename
  65.         fig.savefig(filepath, dpi=dpi, bbox_inches='tight')
  66.         self.logger.info(f"图表已保存: {filepath}")
  67.         
  68.         return filepath
  69.    
  70.     def plot_bar_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
  71.                       color='skyblue', figsize=(10, 6), save_as=None, **kwargs):
  72.         """绘制条形图
  73.         
  74.         参数:
  75.             data: DataFrame
  76.             x: x轴列名
  77.             y: y轴列名
  78.             title: 图表标题
  79.             xlabel: x轴标签
  80.             ylabel: y轴标签
  81.             color: 条形颜色
  82.             figsize: 图表大小
  83.             save_as: 保存文件名
  84.             **kwargs: 其他参数
  85.             
  86.         返回:
  87.             matplotlib图表对象
  88.         """
  89.         try:
  90.             # 创建图表
  91.             fig, ax = plt.subplots(figsize=figsize)
  92.             
  93.             # 绘制条形图
  94.             sns.barplot(x=x, y=y, data=data, color=color, ax=ax, **kwargs)
  95.             
  96.             # 设置标题和标签
  97.             if title:
  98.                 ax.set_title(title)
  99.             if xlabel:
  100.                 ax.set_xlabel(xlabel)
  101.             if ylabel:
  102.                 ax.set_ylabel(ylabel)
  103.             
  104.             # 格式化y轴标签
  105.             ax.yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))
  106.             
  107.             # 添加数值标签
  108.             for p in ax.patches:
  109.                 ax.annotate(f'{p.get_height():,.0f}',
  110.                            (p.get_x() + p.get_width() / 2., p.get_height()),
  111.                            ha='center', va='bottom', fontsize=10)
  112.             
  113.             # 调整布局
  114.             plt.tight_layout()
  115.             
  116.             # 保存图表
  117.             if save_as:
  118.                 self.save_figure(fig, save_as)
  119.             
  120.             return fig
  121.             
  122.         except Exception as e:
  123.             self.logger.error(f"绘制条形图时出错: {e}")
  124.             return None
  125.    
  126.     def plot_line_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
  127.                        color='royalblue', figsize=(12, 6), save_as=None, **kwargs):
  128.         """绘制折线图
  129.         
  130.         参数:
  131.             data: DataFrame
  132.             x: x轴列名
  133.             y: y轴列名或列名列表
  134.             title: 图表标题
  135.             xlabel: x轴标签
  136.             ylabel: y轴标签
  137.             color: 线条颜色或颜色列表
  138.             figsize: 图表大小
  139.             save_as: 保存文件名
  140.             **kwargs: 其他参数
  141.             
  142.         返回:
  143.             matplotlib图表对象
  144.         """
  145.         try:
  146.             # 创建图表
  147.             fig, ax = plt.subplots(figsize=figsize)
  148.             
  149.             # 处理多条线的情况
  150.             if isinstance(y, list):
  151.                 if not isinstance(color, list):
  152.                     color = [plt.cm.tab10(i) for i in range(len(y))]
  153.                
  154.                 for i, col in enumerate(y):
  155.                     data.plot(x=x, y=col, ax=ax, label=col, color=color[i % len(color)], **kwargs)
  156.             else:
  157.                 data.plot(x=x, y=y, ax=ax, color=color, **kwargs)
  158.             
  159.             # 设置标题和标签
  160.             if title:
  161.                 ax.set_title(title)
  162.             if xlabel:
  163.                 ax.set_xlabel(xlabel)
  164.             if ylabel:
  165.                 ax.set_ylabel(ylabel)
  166.             
  167.             # 添加网格线
  168.             ax.grid(True, linestyle='--', alpha=0.7)
  169.             
  170.             # 添加图例
  171.             if isinstance(y, list) and len(y) > 1:
  172.                 ax.legend()
  173.             
  174.             # 调整布局
  175.             plt.tight_layout()
  176.             
  177.             # 保存图表
  178.             if save_as:
  179.                 self.save_figure(fig, save_as)
  180.             
  181.             return fig
  182.             
  183.         except Exception as e:
  184.             self.logger.error(f"绘制折线图时出错: {e}")
  185.             return None
  186.    
  187.     def plot_pie_chart(self, data, values, names, title=None, figsize=(10, 10),
  188.                       colors=None, autopct='%1.1f%%', save_as=None, **kwargs):
  189.         """绘制饼图
  190.         
  191.         参数:
  192.             data: DataFrame
  193.             values: 值列名
  194.             names: 名称列名
  195.             title: 图表标题
  196.             figsize: 图表大小
  197.             colors: 颜色列表
  198.             autopct: 百分比格式
  199.             save_as: 保存文件名
  200.             **kwargs: 其他参数
  201.             
  202.         返回:
  203.             matplotlib图表对象
  204.         """
  205.         try:
  206.             # 准备数据
  207.             if isinstance(data, pd.DataFrame):
  208.                 values_data = data[values].values
  209.                 names_data = data[names].values
  210.             else:
  211.                 values_data = values
  212.                 names_data = names
  213.             
  214.             # 创建图表
  215.             fig, ax = plt.subplots(figsize=figsize)
  216.             
  217.             # 绘制饼图
  218.             wedges, texts, autotexts = ax.pie(
  219.                 values_data,
  220.                 labels=names_data,
  221.                 autopct=autopct,
  222.                 colors=colors,
  223.                 startangle=90,
  224.                 **kwargs
  225.             )
  226.             
  227.             # 设置标题
  228.             if title:
  229.                 ax.set_title(title)
  230.             
  231.             # 设置等比例
  232.             ax.axis('equal')
  233.             
  234.             # 调整文本样式
  235.             plt.setp(autotexts, size=10, weight='bold')
  236.             
  237.             # 调整布局
  238.             plt.tight_layout()
  239.             
  240.             # 保存图表
  241.             if save_as:
  242.                 self.save_figure(fig, save_as)
  243.             
  244.             return fig
  245.             
  246.         except Exception as e:
  247.             self.logger.error(f"绘制饼图时出错: {e}")
  248.             return None
  249.    
  250.     def plot_histogram(self, data, column, bins=30, title=None, xlabel=None, ylabel='频率',
  251.                       color='skyblue', kde=True, figsize=(10, 6), save_as=None, **kwargs):
  252.         """绘制直方图
  253.         
  254.         参数:
  255.             data: DataFrame
  256.             column: 列名
  257.             bins: 分箱数量
  258.             title: 图表标题
  259.             xlabel: x轴标签
  260.             ylabel: y轴标签
  261.             color: 直方图颜色
  262.             kde: 是否显示核密度估计
  263.             figsize: 图表大小
  264.             save_as: 保存文件名
  265.             **kwargs: 其他参数
  266.             
  267.         返回:
  268.             matplotlib图表对象
  269.         """
  270.         try:
  271.             # 创建图表
  272.             fig, ax = plt.subplots(figsize=figsize)
  273.             
  274.             # 绘制直方图
  275.             sns.histplot(data=data, x=column, bins=bins, kde=kde, color=color, ax=ax, **kwargs)
  276.             
  277.             # 设置标题和标签
  278.             if title:
  279.                 ax.set_title(title)
  280.             if xlabel:
  281.                 ax.set_xlabel(xlabel)
  282.             if ylabel:
  283.                 ax.set_ylabel(ylabel)
  284.             
  285.             # 调整布局
  286.             plt.tight_layout()
  287.             
  288.             # 保存图表
  289.             if save_as:
  290.                 self.save_figure(fig, save_as)
  291.             
  292.             return fig
  293.             
  294.         except Exception as e:
  295.             self.logger.error(f"绘制直方图时出错: {e}")
  296.             return None
  297.    
  298.     def plot_scatter(self, data, x, y, title=None, xlabel=None, ylabel=None,
  299.                     hue=None, palette='viridis', size=None, figsize=(10, 8), save_as=None, **kwargs):
  300.         """绘制散点图
  301.         
  302.         参数:
  303.             data: DataFrame
  304.             x: x轴列名
  305.             y: y轴列名
  306.             title: 图表标题
  307.             xlabel: x轴标签
  308.             ylabel: y轴标签
  309.             hue: 分组变量
  310.             palette: 颜色调色板
  311.             size: 点大小变量
  312.             figsize: 图表大小
  313.             save_as: 保存文件名
  314.             **kwargs: 其他参数
  315.             
  316.         返回:
  317.             matplotlib图表对象
  318.         """
  319.         try:
  320.             # 创建图表
  321.             fig, ax = plt.subplots(figsize=figsize)
  322.             
  323.             # 绘制散点图
  324.             scatter = sns.scatterplot(
  325.                 data=data,
  326.                 x=x,
  327.                 y=y,
  328.                 hue=hue,
  329.                 palette=palette,
  330.                 size=size,
  331.                 ax=ax,
  332.                 **kwargs
  333.             )
  334.             
  335.             # 设置标题和标签
  336.             if title:
  337.                 ax.set_title(title)
  338.             if xlabel:
  339.                 ax.set_xlabel(xlabel)
  340.             if ylabel:
  341.                 ax.set_ylabel(ylabel)
  342.             
  343.             # 添加网格线
  344.             ax.grid(True, linestyle='--', alpha=0.7)
  345.             
  346.             # 如果有分组变量,调整图例
  347.             if hue:
  348.                 plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
  349.             
  350.             # 调整布局
  351.             plt.tight_layout()
  352.             
  353.             # 保存图表
  354.             if save_as:
  355.                 self.save_figure(fig, save_as)
  356.             
  357.             return fig
  358.             
  359.         except Exception as e:
  360.             self.logger.error(f"绘制散点图时出错: {e}")
  361.             return None
  362.    
  363.     def plot_heatmap(self, data, title=None, cmap='viridis', annot=True, fmt='.2f',
  364.                     figsize=(12, 10), save_as=None, **kwargs):
  365.         """绘制热力图
  366.         
  367.         参数:
  368.             data: DataFrame或矩阵
  369.             title: 图表标题
  370.             cmap: 颜色映射
  371.             annot: 是否显示数值
  372.             fmt: 数值格式
  373.             figsize: 图表大小
  374.             save_as: 保存文件名
  375.             **kwargs: 其他参数
  376.             
  377.         返回:
  378.             matplotlib图表对象
  379.         """
  380.         try:
  381.             # 创建图表
  382.             fig, ax = plt.subplots(figsize=figsize)
  383.             
  384.             # 绘制热力图
  385.             heatmap = sns.heatmap(
  386.                 data,
  387.                 cmap=cmap,
  388.                 annot=annot,
  389.                 fmt=fmt,
  390.                 linewidths=.5,
  391.                 ax=ax,
  392.                 **kwargs
  393.             )
  394.             
  395.             # 设置标题
  396.             if title:
  397.                 ax.set_title(title)
  398.             
  399.             # 调整布局
  400.             plt.tight_layout()
  401.             
  402.             # 保存图表
  403.             if save_as:
  404.                 self.save_figure(fig, save_as)
  405.             
  406.             return fig
  407.             
  408.         except Exception as e:
  409.             self.logger.error(f"绘制热力图时出错: {e}")
  410.             return None
  411.    
  412.     def plot_box(self, data, x=None, y=None, title=None, xlabel=None, ylabel=None,
  413.                hue=None, palette='Set3', figsize=(12, 8), save_as=None, **kwargs):
  414.         """绘制箱线图
  415.         
  416.         参数:
  417.             data: DataFrame
  418.             x: x轴列名
  419.             y: y轴列名
  420.             title: 图表标题
  421.             xlabel: x轴标签
  422.             ylabel: y轴标签
  423.             hue: 分组变量
  424.             palette: 颜色调色板
  425.             figsize: 图表大小
  426.             save_as: 保存文件名
  427.             **kwargs: 其他参数
  428.             
  429.         返回:
  430.             matplotlib图表对象
  431.         """
  432.         try:
  433.             # 创建图表
  434.             fig, ax = plt.subplots(figsize=figsize)
  435.             
  436.             # 绘制箱线图
  437.             sns.boxplot(
  438.                 data=data,
  439.                 x=x,
  440.                 y=y,
  441.                 hue=hue,
  442.                 palette=palette,
  443.                 ax=ax,
  444.                 **kwargs
  445.             )
  446.             
  447.             # 设置标题和标签
  448.             if title:
  449.                 ax.set_title(title)
  450.             if xlabel:
  451.                 ax.set_xlabel(xlabel)
  452.             if ylabel:
  453.                 ax.set_ylabel(ylabel)
  454.             
  455.             # 如果有分组变量,调整图例
  456.             if hue:
  457.                 plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
  458.             
  459.             # 调整布局
  460.             plt.tight_layout()
  461.             
  462.             # 保存图表
  463.             if save_as:
  464.                 self.save_figure(fig, save_as)
  465.             
  466.             return fig
  467.             
  468.         except Exception as e:
  469.             self.logger.error(f"绘制箱线图时出错: {e}")
  470.             return None
  471.    
  472.     def plot_multiple_charts(self, data, chart_configs, title=None, figsize=(15, 10),
  473.                            nrows=None, ncols=None, save_as=None):
  474.         """绘制多个子图
  475.         
  476.         参数:
  477.             data: DataFrame
  478.             chart_configs: 子图配置列表,每个配置是一个字典,包含:
  479.                 - 'type': 图表类型 ('bar', 'line', 'scatter', 'hist', 'box', 'pie')
  480.                 - 'x', 'y': 数据列名
  481.                 - 'title': 子图标题
  482.                 - 其他特定图表类型的参数
  483.             title: 总标题
  484.             figsize: 图表大小
  485.             nrows: 行数,如果为None则自动计算
  486.             ncols: 列数,如果为None则自动计算
  487.             save_as: 保存文件名
  488.             
  489.         返回:
  490.             matplotlib图表对象
  491.         """
  492.         try:
  493.             # 确定子图布局
  494.             n_charts = len(chart_configs)
  495.             
  496.             if nrows is None and ncols is None:
  497.                 # 自动计算行列数
  498.                 ncols = min(3, n_charts)
  499.                 nrows = (n_charts + ncols - 1) // ncols
  500.             elif nrows is None:
  501.                 nrows = (n_charts + ncols - 1) // ncols
  502.             elif ncols is None:
  503.                 ncols = (n_charts + nrows - 1) // nrows
  504.             
  505.             # 创建图表
  506.             fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
  507.             
  508.             # 确保axes是二维数组
  509.             if nrows == 1 and ncols == 1:
  510.                 axes = np.array([[axes]])
  511.             elif nrows == 1:
  512.                 axes = axes.reshape(1, -1)
  513.             elif ncols == 1:
  514.                 axes = axes.reshape(-1, 1)
  515.             
  516.             # 绘制每个子图
  517.             for i, config in enumerate(chart_configs):
  518.                 if i >= nrows * ncols:
  519.                     self.logger.warning(f"子图数量超过布局容量,跳过第{i+1}个子图")
  520.                     break
  521.                
  522.                 # 获取当前子图的轴
  523.                 row, col = i // ncols, i % ncols
  524.                 ax = axes[row, col]
  525.                
  526.                 # 根据类型绘制不同的图表
  527.                 chart_type = config.get('type', 'bar').lower()
  528.                
  529.                 if chart_type == 'bar':
  530.                     sns.barplot(
  531.                         data=data,
  532.                         x=config.get('x'),
  533.                         y=config.get('y'),
  534.                         hue=config.get('hue'),
  535.                         palette=config.get('palette', 'viridis'),
  536.                         ax=ax
  537.                     )
  538.                 elif chart_type == 'line':
  539.                     if isinstance(config.get('y'), list):
  540.                         for y_col in config.get('y'):
  541.                             data.plot(
  542.                                 x=config.get('x'),
  543.                                 y=y_col,
  544.                                 ax=ax,
  545.                                 label=y_col
  546.                             )
  547.                     else:
  548.                         data.plot(
  549.                             x=config.get('x'),
  550.                             y=config.get('y'),
  551.                             ax=ax
  552.                         )
  553.                 elif chart_type == 'scatter':
  554.                     sns.scatterplot(
  555.                         data=data,
  556.                         x=config.get('x'),
  557.                         y=config.get('y'),
  558.                         hue=config.get('hue'),
  559.                         palette=config.get('palette', 'viridis'),
  560.                         ax=ax
  561.                     )
  562.                 elif chart_type == 'hist':
  563.                     sns.histplot(
  564.                         data=data,
  565.                         x=config.get('x'),
  566.                         bins=config.get('bins', 30),
  567.                         kde=config.get('kde', True),
  568.                         ax=ax
  569.                     )
  570.                 elif chart_type == 'box':
  571.                     sns.boxplot(
  572.                         data=data,
  573.                         x=config.get('x'),
  574.                         y=config.get('y'),
  575.                         hue=config.get('hue'),
  576.                         palette=config.get('palette', 'viridis'),
  577.                         ax=ax
  578.                     )
  579.                 elif chart_type == 'pie':
  580.                     # 饼图需要特殊处理
  581.                     values = data[config.get('values')].values
  582.                     names = data[config.get('names')].values
  583.                     ax.pie(
  584.                         values,
  585.                         labels=names,
  586.                         autopct='%1.1f%%',
  587.                         startangle=90
  588.                     )
  589.                     ax.axis('equal')
  590.                
  591.                 # 设置子图标题和标签
  592.                 if 'title' in config:
  593.                     ax.set_title(config['title'])
  594.                 if 'xlabel' in config:
  595.                     ax.set_xlabel(config['xlabel'])
  596.                 if 'ylabel' in config:
  597.                     ax.set_ylabel(config['ylabel'])
  598.             
  599.             # 隐藏多余的子图
  600.             for i in range(n_charts, nrows * ncols):
  601.                 row, col = i // ncols, i % ncols
  602.                 fig.delaxes(axes[row, col])
  603.             
  604.             # 设置总标题
  605.             if title:
  606.                 fig.suptitle(title, fontsize=16)
  607.                 plt.subplots_adjust(top=0.9)
  608.             
  609.             # 调整布局
  610.             plt.tight_layout()
  611.             
  612.             # 保存图表
  613.             if save_as:
  614.                 self.save_figure(fig, save_as)
  615.             
  616.             return fig
  617.             
  618.         except Exception as e:
  619.             self.logger.error(f"绘制多个子图时出错: {e}")
  620.             return None
  621. # 使用示例
  622. def static_visualization_example():
  623.     """静态可视化示例"""
  624.     # 创建示例数据
  625.     np.random.seed(42)
  626.     n_samples = 200
  627.    
  628.     # 生成特征
  629.     X = np.random.randn(n_samples, 3)  # 3个特征
  630.    
  631.     # 生成目标变量(回归)
  632.     y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  633.    
  634.     # 创建DataFrame
  635.     data = pd.DataFrame(
  636.         X,
  637.         columns=['feature_1', 'feature_2', 'feature_3']
  638.     )
  639.     data['target'] = y
  640.    
  641.     # 添加一些派生列
  642.     data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
  643.     data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
  644.     data['sales_per_customer'] = data['target'] / np.random.poisson(10, n_samples)
  645.    
  646.     # 创建可视化器
  647.     visualizer = StaticVisualizer(output_dir='visualizations/static')
  648.    
  649.     # 1. 绘制条形图 - 按月份的销售额
  650.     monthly_sales = data.groupby('month')['target'].sum().reset_index()
  651.     monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
  652.                                            categories=['Jan', 'Feb', 'Mar', 'Apr'],
  653.                                            ordered=True)
  654.     monthly_sales = monthly_sales.sort_values('month')
  655.    
  656.     visualizer.plot_bar_chart(
  657.         data=monthly_sales,
  658.         x='month',
  659.         y='target',
  660.         title='Monthly Sales',
  661.         xlabel='Month',
  662.         ylabel='Total Sales',
  663.         color='skyblue',
  664.         save_as='monthly_sales_bar.png'
  665.     )
  666.    
  667.     # 2. 绘制折线图 - 销售额和利润趋势
  668.     visualizer.plot_line_chart(
  669.         data=data,
  670.         x='feature_1',
  671.         y=['target', 'sales_per_customer'],
  672.         title='Sales and Profit Trends',
  673.         xlabel='Feature 1',
  674.         ylabel='Amount',
  675.         figsize=(14, 7),
  676.         save_as='sales_profit_trend.png'
  677.     )
  678.    
  679.     # 3. 绘制饼图 - 按区域的销售额分布
  680.     region_sales = data.groupby('day_of_week')['target'].sum().reset_index()
  681.    
  682.     visualizer.plot_pie_chart(
  683.         data=region_sales,
  684.         values='target',
  685.         names='day_of_week',
  686.         title='Sales Distribution by Region',
  687.         save_as='region_sales_pie.png'
  688.     )
  689.    
  690.     # 4. 绘制直方图 - 每位客户销售额分布
  691.     visualizer.plot_histogram(
  692.         data=data,
  693.         column='sales_per_customer',
  694.         bins=20,
  695.         title='Distribution of Sales per Customer',
  696.         xlabel='Sales per Customer',
  697.         ylabel='Frequency',
  698.         save_as='sales_per_customer_hist.png'
  699.     )
  700.    
  701.     # 5. 绘制散点图 - 客户数量与销售额的关系
  702.     visualizer.plot_scatter(
  703.         data=data,
  704.         x='feature_2',
  705.         y='target',
  706.         title='Relationship between Number of Customers and Sales',
  707.         xlabel='Number of Customers',
  708.         ylabel='Sales',
  709.         hue='month',
  710.         save_as='customers_sales_scatter.png'
  711.     )
  712.    
  713.     # 6. 绘制热力图 - 相关性矩阵
  714.     correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'target']].corr()
  715.    
  716.     visualizer.plot_heatmap(
  717.         data=correlation_matrix,
  718.         title='Correlation Matrix',
  719.         save_as='correlation_heatmap.png'
  720.     )
  721.    
  722.     # 7. 绘制箱线图 - 按区域的销售额分布
  723.     visualizer.plot_box(
  724.         data=data,
  725.         x='day_of_week',
  726.         y='target',
  727.         title='Sales Distribution by Region',
  728.         xlabel='Region',
  729.         ylabel='Sales',
  730.         save_as='region_sales_box.png'
  731.     )
  732.    
  733.     # 8. 绘制多个子图
  734.     chart_configs = [
  735.         {
  736.             'type': 'bar',
  737.             'x': 'month',
  738.             'y': 'target',
  739.             'title': 'Sales by Month'
  740.         },
  741.         {
  742.             'type': 'line',
  743.             'x': 'feature_1',
  744.             'y': 'target',
  745.             'title': 'Sales Trend'
  746.         },
  747.         {
  748.             'type': 'scatter',
  749.             'x': 'feature_2',
  750.             'y': 'target',
  751.             'title': 'Sales vs Feature 2'
  752.         },
  753.         {
  754.             'type': 'hist',
  755.             'x': 'sales_per_customer',
  756.             'title': 'Sales per Customer Distribution'
  757.         }
  758.     ]
  759.    
  760.     visualizer.plot_multiple_charts(
  761.         data=data,
  762.         chart_configs=chart_configs,
  763.         title='Sales Dashboard',
  764.         save_as='sales_dashboard.png'
  765.     )
  766.    
  767.     print("静态可视化示例完成,图表已保存到 'visualizations/static' 目录")
  768.    
  769.     return {
  770.         'sales_data': data,
  771.         'visualizer': visualizer
  772.     }
  773. if __name__ == "__main__":
  774.     static_visualization_example()
复制代码
7.2 交互式可视化

交互式可视化允许用户与图表进行交互,比方缩放、悬停查察详情、筛选数据等,重要使用Plotly和Bokeh库:
  1. import plotly.express as px
  2. import plotly.graph_objects as go
  3. from plotly.subplots import make_subplots
  4. import pandas as pd
  5. import numpy as np
  6. import logging
  7. from pathlib import Path
  8. import json
  9. import plotly.io as pio
  10. class InteractiveVisualizer:
  11.     """交互式可视化类"""
  12.    
  13.     def __init__(self, output_dir='visualizations/interactive'):
  14.         """初始化可视化器
  15.         
  16.         参数:
  17.             output_dir: 输出目录
  18.         """
  19.         self.output_dir = output_dir
  20.         self.logger = self._setup_logger()
  21.         self._setup_style()
  22.         self._ensure_output_dir()
  23.    
  24.     def _setup_logger(self):
  25.         """设置日志记录器"""
  26.         logger = logging.getLogger('InteractiveVisualizer')
  27.         logger.setLevel(logging.INFO)
  28.         
  29.         if not logger.handlers:
  30.             handler = logging.StreamHandler()
  31.             formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  32.             handler.setFormatter(formatter)
  33.             logger.addHandler(handler)
  34.         
  35.         return logger
  36.    
  37.     def _setup_style(self):
  38.         """设置可视化样式"""
  39.         # 设置Plotly模板
  40.         pio.templates.default = "plotly_white"
  41.    
  42.     def _ensure_output_dir(self):
  43.         """确保输出目录存在"""
  44.         Path(self.output_dir).mkdir(parents=True, exist_ok=True)
  45.         self.logger.info(f"输出目录: {self.output_dir}")
  46.    
  47.     def save_figure(self, fig, filename, include_plotlyjs='cdn'):
  48.         """保存图表
  49.         
  50.         参数:
  51.             fig: 图表对象
  52.             filename: 文件名
  53.             include_plotlyjs: 是否包含plotly.js
  54.         """
  55.         filepath = Path(self.output_dir) / filename
  56.         
  57.         # 保存为HTML
  58.         if filename.endswith('.html'):
  59.             fig.write_html(filepath, include_plotlyjs=include_plotlyjs)
  60.         # 保存为JSON
  61.         elif filename.endswith('.json'):
  62.             with open(filepath, 'w') as f:
  63.                 json.dump(fig.to_dict(), f)
  64.         # 保存为图像
  65.         else:
  66.             fig.write_image(filepath)
  67.         
  68.         self.logger.info(f"图表已保存: {filepath}")
  69.         
  70.         return filepath
  71.    
  72.     def plot_bar_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
  73.                       color=None, barmode='group', figsize=(900, 600),
  74.                       save_as=None, **kwargs):
  75.         """绘制交互式条形图
  76.         
  77.         参数:
  78.             data: DataFrame
  79.             x: x轴列名
  80.             y: y轴列名或列名列表
  81.             title: 图表标题
  82.             xlabel: x轴标签
  83.             ylabel: y轴标签
  84.             color: 分组变量
  85.             barmode: 条形模式 ('group', 'stack', 'relative', 'overlay')
  86.             figsize: 图表大小 (宽, 高)
  87.             save_as: 保存文件名
  88.             **kwargs: 其他参数
  89.             
  90.         返回:
  91.             plotly图表对象
  92.         """
  93.         try:
  94.             # 处理y为列表的情况
  95.             if isinstance(y, list):
  96.                 fig = go.Figure()
  97.                
  98.                 for col in y:
  99.                     fig.add_trace(go.Bar(
  100.                         x=data[x],
  101.                         y=data[col],
  102.                         name=col
  103.                     ))
  104.                
  105.                 fig.update_layout(barmode=barmode)
  106.             else:
  107.                 # 使用Plotly Express创建条形图
  108.                 fig = px.bar(
  109.                     data,
  110.                     x=x,
  111.                     y=y,
  112.                     color=color,
  113.                     barmode=barmode,
  114.                     **kwargs
  115.                 )
  116.             
  117.             # 更新布局
  118.             fig.update_layout(
  119.                 title=title,
  120.                 xaxis_title=xlabel,
  121.                 yaxis_title=ylabel,
  122.                 width=figsize[0],
  123.                 height=figsize[1],
  124.                 hovermode='closest'
  125.             )
  126.             
  127.             # 保存图表
  128.             if save_as:
  129.                 self.save_figure(fig, save_as)
  130.             
  131.             return fig
  132.             
  133.         except Exception as e:
  134.             self.logger.error(f"绘制交互式条形图时出错: {e}")
  135.             return None
  136.    
  137.     def plot_line_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
  138.                        color=None, line_shape='linear', figsize=(900, 600),
  139.                        save_as=None, **kwargs):
  140.         """绘制交互式折线图
  141.         
  142.         参数:
  143.             data: DataFrame
  144.             x: x轴列名
  145.             y: y轴列名或列名列表
  146.             title: 图表标题
  147.             xlabel: x轴标签
  148.             ylabel: y轴标签
  149.             color: 分组变量
  150.             line_shape: 线条形状 ('linear', 'spline', 'hv', 'vh', 'hvh', 'vhv')
  151.             figsize: 图表大小 (宽, 高)
  152.             save_as: 保存文件名
  153.             **kwargs: 其他参数
  154.             
  155.         返回:
  156.             plotly图表对象
  157.         """
  158.         try:
  159.             # 处理y为列表的情况
  160.             if isinstance(y, list):
  161.                 fig = go.Figure()
  162.                
  163.                 for col in y:
  164.                     fig.add_trace(go.Scatter(
  165.                         x=data[x],
  166.                         y=data[col],
  167.                         mode='lines+markers',
  168.                         name=col,
  169.                         line=dict(shape=line_shape)
  170.                     ))
  171.             else:
  172.                 # 使用Plotly Express创建折线图
  173.                 fig = px.line(
  174.                     data,
  175.                     x=x,
  176.                     y=y,
  177.                     color=color,
  178.                     line_shape=line_shape,
  179.                     **kwargs
  180.                 )
  181.                
  182.                 # 添加标记点
  183.                 fig.update_traces(mode='lines+markers')
  184.             
  185.             # 更新布局
  186.             fig.update_layout(
  187.                 title=title,
  188.                 xaxis_title=xlabel,
  189.                 yaxis_title=ylabel,
  190.                 width=figsize[0],
  191.                 height=figsize[1],
  192.                 hovermode='closest'
  193.             )
  194.             
  195.             # 保存图表
  196.             if save_as:
  197.                 self.save_figure(fig, save_as)
  198.             
  199.             return fig
  200.             
  201.         except Exception as e:
  202.             self.logger.error(f"绘制交互式折线图时出错: {e}")
  203.             return None
  204.    
  205.     def plot_pie_chart(self, data, values, names, title=None, figsize=(800, 800),
  206.                       hole=0, save_as=None, **kwargs):
  207.         """绘制交互式饼图/环形图
  208.         
  209.         参数:
  210.             data: DataFrame
  211.             values: 值列名
  212.             names: 名称列名
  213.             title: 图表标题
  214.             figsize: 图表大小 (宽, 高)
  215.             hole: 中心孔大小 (0-1),0为饼图,>0为环形图
  216.             save_as: 保存文件名
  217.             **kwargs: 其他参数
  218.             
  219.         返回:
  220.             plotly图表对象
  221.         """
  222.         try:
  223.             # 使用Plotly Express创建饼图/环形图
  224.             fig = px.pie(
  225.                 data,
  226.                 values=values,
  227.                 names=names,
  228.                 hole=hole,
  229.                 **kwargs
  230.             )
  231.             
  232.             # 更新布局
  233.             fig.update_layout(
  234.                 title=title,
  235.                 width=figsize[0],
  236.                 height=figsize[1]
  237.             )
  238.             
  239.             # 更新轨迹
  240.             fig.update_traces(
  241.                 textposition='inside',
  242.                 textinfo='percent+label',
  243.                 hoverinfo='label+percent+value'
  244.             )
  245.             
  246.             # 保存图表
  247.             if save_as:
  248.                 self.save_figure(fig, save_as)
  249.             
  250.             return fig
  251.             
  252.         except Exception as e:
  253.             self.logger.error(f"绘制交互式饼图时出错: {e}")
  254.             return None
  255.    
  256.     def plot_histogram(self, data, column, bins=30, title=None, xlabel=None, ylabel='频率',
  257.                       color=None, figsize=(900, 600), save_as=None, **kwargs):
  258.         """绘制交互式直方图
  259.         
  260.         参数:
  261.             data: DataFrame
  262.             column: 列名
  263.             bins: 分箱数量
  264.             title: 图表标题
  265.             xlabel: x轴标签
  266.             ylabel: y轴标签
  267.             color: 分组变量
  268.             figsize: 图表大小 (宽, 高)
  269.             save_as: 保存文件名
  270.             **kwargs: 其他参数
  271.             
  272.         返回:
  273.             plotly图表对象
  274.         """
  275.         try:
  276.             # 使用Plotly Express创建直方图
  277.             fig = px.histogram(
  278.                 data,
  279.                 x=column,
  280.                 color=color,
  281.                 nbins=bins,
  282.                 marginal='rug',  # 添加边缘分布
  283.                 **kwargs
  284.             )
  285.             
  286.             # 更新布局
  287.             fig.update_layout(
  288.                 title=title,
  289.                 xaxis_title=xlabel,
  290.                 yaxis_title=ylabel,
  291.                 width=figsize[0],
  292.                 height=figsize[1],
  293.                 bargap=0.1  # 条形之间的间隙
  294.             )
  295.             
  296.             # 保存图表
  297.             if save_as:
  298.                 self.save_figure(fig, save_as)
  299.             
  300.             return fig
  301.             
  302.         except Exception as e:
  303.             self.logger.error(f"绘制交互式直方图时出错: {e}")
  304.             return None
  305.    
  306.     def plot_scatter(self, data, x, y, title=None, xlabel=None, ylabel=None,
  307.                     color=None, size=None, hover_name=None, figsize=(900, 600),
  308.                     save_as=None, **kwargs):
  309.         """绘制交互式散点图
  310.         
  311.         参数:
  312.             data: DataFrame
  313.             x: x轴列名
  314.             y: y轴列名
  315.             title: 图表标题
  316.             xlabel: x轴标签
  317.             ylabel: y轴标签
  318.             color: 分组变量
  319.             size: 点大小变量
  320.             hover_name: 悬停显示的标识列
  321.             figsize: 图表大小 (宽, 高)
  322.             save_as: 保存文件名
  323.             **kwargs: 其他参数
  324.             
  325.         返回:
  326.             plotly图表对象
  327.         """
  328.         try:
  329.             # 使用Plotly Express创建散点图
  330.             fig = px.scatter(
  331.                 data,
  332.                 x=x,
  333.                 y=y,
  334.                 color=color,
  335.                 size=size,
  336.                 hover_name=hover_name,
  337.                 **kwargs
  338.             )
  339.             
  340.             # 更新布局
  341.             fig.update_layout(
  342.                 title=title,
  343.                 xaxis_title=xlabel,
  344.                 yaxis_title=ylabel,
  345.                 width=figsize[0],
  346.                 height=figsize[1],
  347.                 hovermode='closest'
  348.             )
  349.             
  350.             # 添加趋势线
  351.             if 'trendline' not in kwargs:
  352.                 fig.update_layout(
  353.                     shapes=[{
  354.                         'type': 'line',
  355.                         'x0': data[x].min(),
  356.                         'y0': data[y].min(),
  357.                         'x1': data[x].max(),
  358.                         'y1': data[y].max(),
  359.                         'line': {
  360.                             'color': 'rgba(0,0,0,0.2)',
  361.                             'width': 2,
  362.                             'dash': 'dash'
  363.                         }
  364.                     }]
  365.                 )
  366.             
  367.             # 保存图表
  368.             if save_as:
  369.                 self.save_figure(fig, save_as)
  370.             
  371.             return fig
  372.             
  373.         except Exception as e:
  374.             self.logger.error(f"绘制交互式散点图时出错: {e}")
  375.             return None
  376.    
  377.     def plot_heatmap(self, data, title=None, figsize=(900, 700),
  378.                     colorscale='Viridis', save_as=None, **kwargs):
  379.         """绘制交互式热力图
  380.         
  381.         参数:
  382.             data: DataFrame或矩阵
  383.             title: 图表标题
  384.             figsize: 图表大小 (宽, 高)
  385.             colorscale: 颜色映射
  386.             save_as: 保存文件名
  387.             **kwargs: 其他参数
  388.             
  389.         返回:
  390.             plotly图表对象
  391.         """
  392.         try:
  393.             # 创建热力图
  394.             fig = go.Figure(data=go.Heatmap(
  395.                 z=data.values,
  396.                 x=data.columns,
  397.                 y=data.index,
  398.                 colorscale=colorscale,
  399.                 **kwargs
  400.             ))
  401.             
  402.             # 更新布局
  403.             fig.update_layout(
  404.                 title=title,
  405.                 width=figsize[0],
  406.                 height=figsize[1]
  407.             )
  408.             
  409.             # 保存图表
  410.             if save_as:
  411.                 self.save_figure(fig, save_as)
  412.             
  413.             return fig
  414.             
  415.         except Exception as e:
  416.             self.logger.error(f"绘制交互式热力图时出错: {e}")
  417.             return None
  418.    
  419.     def plot_box(self, data, x=None, y=None, title=None, xlabel=None, ylabel=None,
  420.                color=None, figsize=(900, 600), save_as=None, **kwargs):
  421.         """绘制交互式箱线图
  422.         
  423.         参数:
  424.             data: DataFrame
  425.             x: x轴列名
  426.             y: y轴列名
  427.             title: 图表标题
  428.             xlabel: x轴标签
  429.             ylabel: y轴标签
  430.             color: 分组变量
  431.             figsize: 图表大小 (宽, 高)
  432.             save_as: 保存文件名
  433.             **kwargs: 其他参数
  434.             
  435.         返回:
  436.             plotly图表对象
  437.         """
  438.         try:
  439.             # 使用Plotly Express创建箱线图
  440.             fig = px.box(
  441.                 data,
  442.                 x=x,
  443.                 y=y,
  444.                 color=color,
  445.                 **kwargs
  446.             )
  447.             
  448.             # 更新布局
  449.             fig.update_layout(
  450.                 title=title,
  451.                 xaxis_title=xlabel,
  452.                 yaxis_title=ylabel,
  453.                 width=figsize[0],
  454.                 height=figsize[1]
  455.             )
  456.             
  457.             # 保存图表
  458.             if save_as:
  459.                 self.save_figure(fig, save_as)
  460.             
  461.             return fig
  462.             
  463.         except Exception as e:
  464.             self.logger.error(f"绘制交互式箱线图时出错: {e}")
  465.             return None
  466.    
  467.     def plot_bubble(self, data, x, y, size, title=None, xlabel=None, ylabel=None,
  468.                    color=None, hover_name=None, figsize=(900, 600), save_as=None, **kwargs):
  469.         """绘制交互式气泡图
  470.         
  471.         参数:
  472.             data: DataFrame
  473.             x: x轴列名
  474.             y: y轴列名
  475.             size: 气泡大小列名
  476.             title: 图表标题
  477.             xlabel: x轴标签
  478.             ylabel: y轴标签
  479.             color: 分组变量
  480.             hover_name: 悬停显示的标识列
  481.             figsize: 图表大小 (宽, 高)
  482.             save_as: 保存文件名
  483.             **kwargs: 其他参数
  484.             
  485.         返回:
  486.             plotly图表对象
  487.         """
  488.         try:
  489.             # 使用Plotly Express创建气泡图
  490.             fig = px.scatter(
  491.                 data,
  492.                 x=x,
  493.                 y=y,
  494.                 size=size,
  495.                 color=color,
  496.                 hover_name=hover_name,
  497.                 **kwargs
  498.             )
  499.             
  500.             # 更新布局
  501.             fig.update_layout(
  502.                 title=title,
  503.                 xaxis_title=xlabel,
  504.                 yaxis_title=ylabel,
  505.                 width=figsize[0],
  506.                 height=figsize[1],
  507.                 hovermode='closest'
  508.             )
  509.             
  510.             # 保存图表
  511.             if save_as:
  512.                 self.save_figure(fig, save_as)
  513.             
  514.             return fig
  515.             
  516.         except Exception as e:
  517.             self.logger.error(f"绘制交互式气泡图时出错: {e}")
  518.             return None
  519.    
  520.     def plot_3d_scatter(self, data, x, y, z, title=None, xlabel=None, ylabel=None, zlabel=None,
  521.                        color=None, size=None, hover_name=None, figsize=(900, 700),
  522.                        save_as=None, **kwargs):
  523.         """绘制交互式3D散点图
  524.         
  525.         参数:
  526.             data: DataFrame
  527.             x: x轴列名
  528.             y: y轴列名
  529.             z: z轴列名
  530.             title: 图表标题
  531.             xlabel: x轴标签
  532.             ylabel: y轴标签
  533.             zlabel: z轴标签
  534.             color: 分组变量
  535.             size: 点大小变量
  536.             hover_name: 悬停显示的标识列
  537.             figsize: 图表大小 (宽, 高)
  538.             save_as: 保存文件名
  539.             **kwargs: 其他参数
  540.             
  541.         返回:
  542.             plotly图表对象
  543.         """
  544.         try:
  545.             # 使用Plotly Express创建3D散点图
  546.             fig = px.scatter_3d(
  547.                 data,
  548.                 x=x,
  549.                 y=y,
  550.                 z=z,
  551.                 color=color,
  552.                 size=size,
  553.                 hover_name=hover_name,
  554.                 **kwargs
  555.             )
  556.             
  557.             # 更新布局
  558.             fig.update_layout(
  559.                 title=title,
  560.                 scene=dict(
  561.                     xaxis_title=xlabel,
  562.                     yaxis_title=ylabel,
  563.                     zaxis_title=zlabel
  564.                 ),
  565.                 width=figsize[0],
  566.                 height=figsize[1]
  567.             )
  568.             
  569.             # 保存图表
  570.             if save_as:
  571.                 self.save_figure(fig, save_as)
  572.             
  573.             return fig
  574.             
  575.         except Exception as e:
  576.             self.logger.error(f"绘制交互式3D散点图时出错: {e}")
  577.             return None
  578.    
  579.     def plot_choropleth_map(self, data, locations, color, title=None,
  580.                            location_mode='ISO-3', figsize=(900, 600),
  581.                            colorscale='Viridis', save_as=None, **kwargs):
  582.         """绘制交互式地理热力图
  583.         
  584.         参数:
  585.             data: DataFrame
  586.             locations: 地理位置列名
  587.             color: 颜色值列名
  588.             title: 图表标题
  589.             location_mode: 地理位置模式 ('ISO-3', 'country names', 等)
  590.             figsize: 图表大小 (宽, 高)
  591.             colorscale: 颜色映射
  592.             save_as: 保存文件名
  593.             **kwargs: 其他参数
  594.             
  595.         返回:
  596.             plotly图表对象
  597.         """
  598.         try:
  599.             # 使用Plotly Express创建地理热力图
  600.             fig = px.choropleth(
  601.                 data,
  602.                 locations=locations,
  603.                 color=color,
  604.                 locationmode=location_mode,
  605.                 color_continuous_scale=colorscale,
  606.                 **kwargs
  607.             )
  608.             
  609.             # 更新布局
  610.             fig.update_layout(
  611.                 title=title,
  612.                 width=figsize[0],
  613.                 height=figsize[1],
  614.                 geo=dict(
  615.                     showframe=False,
  616.                     showcoastlines=True,
  617.                     projection_type='equirectangular'
  618.                 )
  619.             )
  620.             
  621.             # 保存图表
  622.             if save_as:
  623.                 self.save_figure(fig, save_as)
  624.             
  625.             return fig
  626.             
  627.         except Exception as e:
  628.             self.logger.error(f"绘制交互式地理热力图时出错: {e}")
  629.             return None
  630.     def plot_multiple_charts(self, chart_configs, title=None, figsize=(1000, 800),
  631.                            rows=None, cols=None, subplot_titles=None, save_as=None):
  632.         """绘制多个子图
  633.         
  634.         参数:
  635.             chart_configs: 子图配置列表,每个配置是一个字典,包含:
  636.                 - 'data': 数据
  637.                 - 'type': 图表类型 ('bar', 'line', 'scatter', 'pie', 'box', 'heatmap', 等)
  638.                 - 'x', 'y': 数据列名
  639.                 - 'row', 'col': 子图位置
  640.                 - 其他特定图表类型的参数
  641.             title: 总标题
  642.             figsize: 图表大小 (宽, 高)
  643.             rows: 行数
  644.             cols: 列数
  645.             subplot_titles: 子图标题列表
  646.             save_as: 保存文件名
  647.             
  648.         返回:
  649.             plotly图表对象
  650.         """
  651.         try:
  652.             # 确定子图布局
  653.             if rows is None or cols is None:
  654.                 # 查找最大的row和col值
  655.                 max_row = max([config.get('row', 1) for config in chart_configs])
  656.                 max_col = max([config.get('col', 1) for config in chart_configs])
  657.                 rows = max(rows or 0, max_row)
  658.                 cols = max(cols or 0, max_col)
  659.             
  660.             # 创建子图
  661.             fig = make_subplots(
  662.                 rows=rows,
  663.                 cols=cols,
  664.                 subplot_titles=subplot_titles,
  665.                 specs=[[{"type": "xy"} for _ in range(cols)] for _ in range(rows)]
  666.             )
  667.             
  668.             # 添加每个子图
  669.             for config in chart_configs:
  670.                 data = config.get('data')
  671.                 chart_type = config.get('type', 'scatter').lower()
  672.                 row = config.get('row', 1)
  673.                 col = config.get('col', 1)
  674.                
  675.                 if chart_type == 'bar':
  676.                     trace = go.Bar(
  677.                         x=data[config.get('x')],
  678.                         y=data[config.get('y')],
  679.                         name=config.get('name', config.get('y')),
  680.                         marker_color=config.get('color')
  681.                     )
  682.                 elif chart_type == 'line':
  683.                     trace = go.Scatter(
  684.                         x=data[config.get('x')],
  685.                         y=data[config.get('y')],
  686.                         mode='lines+markers',
  687.                         name=config.get('name', config.get('y')),
  688.                         line=dict(color=config.get('color'))
  689.                     )
  690.                 elif chart_type == 'scatter':
  691.                     trace = go.Scatter(
  692.                         x=data[config.get('x')],
  693.                         y=data[config.get('y')],
  694.                         mode='markers',
  695.                         name=config.get('name', config.get('y')),
  696.                         marker=dict(
  697.                             color=config.get('color'),
  698.                             size=config.get('size', 10)
  699.                         )
  700.                     )
  701.                 elif chart_type == 'pie':
  702.                     trace = go.Pie(
  703.                         values=data[config.get('values')],
  704.                         labels=data[config.get('names')],
  705.                         name=config.get('name', '')
  706.                     )
  707.                 elif chart_type == 'box':
  708.                     trace = go.Box(
  709.                         x=data[config.get('x')] if 'x' in config else None,
  710.                         y=data[config.get('y')],
  711.                         name=config.get('name', config.get('y'))
  712.                     )
  713.                 elif chart_type == 'heatmap':
  714.                     # 热力图需要特殊处理
  715.                     if isinstance(data, pd.DataFrame):
  716.                         z_data = data.values
  717.                         x_data = data.columns
  718.                         y_data = data.index
  719.                     else:
  720.                         z_data = data
  721.                         x_data = config.get('x')
  722.                         y_data = config.get('y')
  723.                     
  724.                     trace = go.Heatmap(
  725.                         z=z_data,
  726.                         x=x_data,
  727.                         y=y_data,
  728.                         colorscale=config.get('colorscale', 'Viridis')
  729.                     )
  730.                 else:
  731.                     self.logger.warning(f"未知的图表类型: {chart_type}")
  732.                     continue
  733.                
  734.                 fig.add_trace(trace, row=row, col=col)
  735.                
  736.                 # 更新轴标签
  737.                 if 'xlabel' in config:
  738.                     fig.update_xaxes(title_text=config['xlabel'], row=row, col=col)
  739.                 if 'ylabel' in config:
  740.                     fig.update_yaxes(title_text=config['ylabel'], row=row, col=col)
  741.             
  742.             # 更新布局
  743.             fig.update_layout(
  744.                 title=title,
  745.                 width=figsize[0],
  746.                 height=figsize[1],
  747.                 showlegend=True
  748.             )
  749.             
  750.             # 保存图表
  751.             if save_as:
  752.                 self.save_figure(fig, save_as)
  753.             
  754.             return fig
  755.             
  756.         except Exception as e:
  757.             self.logger.error(f"绘制多个子图时出错: {e}")
  758.             return None
  759. # 使用示例
  760. def interactive_visualization_example():
  761.     """交互式可视化示例"""
  762.     # 创建示例数据
  763.     np.random.seed(42)
  764.     n_samples = 200
  765.    
  766.     # 生成特征
  767.     X = np.random.randn(n_samples, 3)  # 3个特征
  768.    
  769.     # 生成目标变量(回归)
  770.     y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  771.    
  772.     # 创建DataFrame
  773.     data = pd.DataFrame(
  774.         X,
  775.         columns=['feature_1', 'feature_2', 'feature_3']
  776.     )
  777.     data['target'] = y
  778.    
  779.     # 添加一些派生列
  780.     data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
  781.     data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
  782.     data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
  783.     data['sales'] = data['target'] * 100 + 500
  784.     data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
  785.     data['customers'] = np.random.poisson(50, n_samples)
  786.     data['sales_per_customer'] = data['sales'] / data['customers']
  787.    
  788.     # 创建一些国家数据
  789.     countries = ['USA', 'CAN', 'MEX', 'BRA', 'ARG', 'GBR', 'FRA', 'DEU', 'ITA', 'ESP',
  790.                 'RUS', 'CHN', 'JPN', 'IND', 'AUS']
  791.     country_codes = ['USA', 'CAN', 'MEX', 'BRA', 'ARG', 'GBR', 'FRA', 'DEU', 'ITA', 'ESP',
  792.                     'RUS', 'CHN', 'JPN', 'IND', 'AUS']
  793.     country_data = pd.DataFrame({
  794.         'country': countries,
  795.         'code': country_codes,
  796.         'gdp': np.random.uniform(100, 1000, len(countries)),
  797.         'population': np.random.uniform(10, 500, len(countries))
  798.     })
  799.    
  800.     # 创建可视化器
  801.     visualizer = InteractiveVisualizer(output_dir='visualizations/interactive')
  802.    
  803.     # 1. 绘制交互式条形图 - 按月份的销售额
  804.     monthly_sales = data.groupby('month')['sales'].sum().reset_index()
  805.     monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
  806.                                            categories=['Jan', 'Feb', 'Mar', 'Apr'],
  807.                                            ordered=True)
  808.     monthly_sales = monthly_sales.sort_values('month')
  809.    
  810.     bar_fig = visualizer.plot_bar_chart(
  811.         data=monthly_sales,
  812.         x='month',
  813.         y='sales',
  814.         title='Monthly Sales',
  815.         xlabel='Month',
  816.         ylabel='Total Sales',
  817.         save_as='monthly_sales_bar.html'
  818.     )
  819.    
  820.     # 2. 绘制交互式折线图 - 销售额和利润趋势
  821.     line_fig = visualizer.plot_line_chart(
  822.         data=data.sort_values('feature_1').iloc[:50],  # 使用部分数据
  823.         x='feature_1',
  824.         y=['sales', 'profit'],
  825.         title='Sales and Profit Trends',
  826.         xlabel='Feature 1',
  827.         ylabel='Amount',
  828.         save_as='sales_profit_trend.html'
  829.     )
  830.    
  831.     # 3. 绘制交互式饼图 - 按区域的销售额分布
  832.     region_sales = data.groupby('region')['sales'].sum().reset_index()
  833.    
  834.     pie_fig = visualizer.plot_pie_chart(
  835.         data=region_sales,
  836.         values='sales',
  837.         names='region',
  838.         title='Sales Distribution by Region',
  839.         save_as='region_sales_pie.html'
  840.     )
  841.    
  842.     # 4. 绘制交互式环形图 - 按星期几的销售额分布
  843.     day_sales = data.groupby('day_of_week')['sales'].sum().reset_index()
  844.    
  845.     donut_fig = visualizer.plot_pie_chart(
  846.         data=day_sales,
  847.         values='sales',
  848.         names='day_of_week',
  849.         title='Sales Distribution by Day of Week',
  850.         hole=0.4,  # 环形图
  851.         save_as='day_sales_donut.html'
  852.     )
  853.    
  854.     # 5. 绘制交互式直方图 - 每位客户销售额分布
  855.     hist_fig = visualizer.plot_histogram(
  856.         data=data,
  857.         column='sales_per_customer',
  858.         bins=20,
  859.         title='Distribution of Sales per Customer',
  860.         xlabel='Sales per Customer',
  861.         ylabel='Frequency',
  862.         color='region',  # 按区域分组
  863.         save_as='sales_per_customer_hist.html'
  864.     )
  865.    
  866.     # 6. 绘制交互式散点图 - 客户数量与销售额的关系
  867.     scatter_fig = visualizer.plot_scatter(
  868.         data=data,
  869.         x='customers',
  870.         y='sales',
  871.         title='Relationship between Number of Customers and Sales',
  872.         xlabel='Number of Customers',
  873.         ylabel='Sales',
  874.         color='region',
  875.         size='profit',  # 使用利润作为点大小
  876.         hover_name='month',  # 悬停显示月份
  877.         save_as='customers_sales_scatter.html'
  878.     )
  879.    
  880.     # 7. 绘制交互式热力图 - 相关性矩阵
  881.     correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
  882.    
  883.     heatmap_fig = visualizer.plot_heatmap(
  884.         data=correlation_matrix,
  885.         title='Correlation Matrix',
  886.         save_as='correlation_heatmap.html'
  887.     )
  888.    
  889.     # 8. 绘制交互式箱线图 - 按区域的销售额分布
  890.     box_fig = visualizer.plot_box(
  891.         data=data,
  892.         x='region',
  893.         y='sales',
  894.         title='Sales Distribution by Region',
  895.         xlabel='Region',
  896.         ylabel='Sales',
  897.         color='region',
  898.         save_as='region_sales_box.html'
  899.     )
  900.    
  901.     # 9. 绘制交互式气泡图 - 特征与销售额和利润的关系
  902.     bubble_fig = visualizer.plot_bubble(
  903.         data=data,
  904.         x='feature_1',
  905.         y='feature_2',
  906.         size='sales',
  907.         color='region',
  908.         title='Feature Relationships with Sales',
  909.         xlabel='Feature 1',
  910.         ylabel='Feature 2',
  911.         hover_name='month',
  912.         save_as='feature_sales_bubble.html'
  913.     )
  914.    
  915.     # 10. 绘制交互式3D散点图 - 三个特征的关系
  916.     scatter_3d_fig = visualizer.plot_3d_scatter(
  917.         data=data,
  918.         x='feature_1',
  919.         y='feature_2',
  920.         z='feature_3',
  921.         color='sales',
  922.         size='profit',
  923.         title='3D Relationship between Features',
  924.         xlabel='Feature 1',
  925.         ylabel='Feature 2',
  926.         zlabel='Feature 3',
  927.         save_as='features_3d_scatter.html'
  928.     )
  929.    
  930.     # 11. 绘制交互式地理热力图 - 国家GDP分布
  931.     choropleth_fig = visualizer.plot_choropleth_map(
  932.         data=country_data,
  933.         locations='code',
  934.         color='gdp',
  935.         title='GDP by Country',
  936.         location_mode='ISO-3',
  937.         color_continuous_scale='Viridis',
  938.         save_as='country_gdp_map.html'
  939.     )
  940.    
  941.     # 12. 绘制多个子图 - 销售仪表盘
  942.     chart_configs = [
  943.         {
  944.             'data': monthly_sales,
  945.             'type': 'bar',
  946.             'x': 'month',
  947.             'y': 'sales',
  948.             'row': 1,
  949.             'col': 1,
  950.             'name': 'Monthly Sales',
  951.             'xlabel': 'Month',
  952.             'ylabel': 'Sales'
  953.         },
  954.         {
  955.             'data': data.sort_values('feature_1').iloc[:50],
  956.             'type': 'line',
  957.             'x': 'feature_1',
  958.             'y': 'sales',
  959.             'row': 1,
  960.             'col': 2,
  961.             'name': 'Sales Trend',
  962.             'xlabel': 'Feature 1',
  963.             'ylabel': 'Sales'
  964.         },
  965.         {
  966.             'data': data,
  967.             'type': 'scatter',
  968.             'x': 'customers',
  969.             'y': 'sales',
  970.             'row': 2,
  971.             'col': 1,
  972.             'name': 'Customers vs Sales',
  973.             'xlabel': 'Customers',
  974.             'ylabel': 'Sales'
  975.         },
  976.         {
  977.             'data': correlation_matrix,
  978.             'type': 'heatmap',
  979.             'row': 2,
  980.             'col': 2,
  981.             'name': 'Correlation'
  982.         }
  983.     ]
  984.    
  985.     subplot_titles = ['Monthly Sales', 'Sales Trend', 'Customers vs Sales', 'Correlation Matrix']
  986.    
  987.     dashboard_fig = visualizer.plot_multiple_charts(
  988.         chart_configs=chart_configs,
  989.         title='Sales Dashboard',
  990.         rows=2,
  991.         cols=2,
  992.         subplot_titles=subplot_titles,
  993.         save_as='sales_dashboard.html'
  994.     )
  995.    
  996.     print("交互式可视化示例完成,图表已保存到 'visualizations/interactive' 目录")
  997.    
  998.     return {
  999.         'data': data,
  1000.         'country_data': country_data,
  1001.         'visualizer': visualizer,
  1002.         'figures': {
  1003.             'bar': bar_fig,
  1004.             'line': line_fig,
  1005.             'pie': pie_fig,
  1006.             'donut': donut_fig,
  1007.             'hist': hist_fig,
  1008.             'scatter': scatter_fig,
  1009.             'heatmap': heatmap_fig,
  1010.             'box': box_fig,
  1011.             'bubble': bubble_fig,
  1012.             'scatter_3d': scatter_3d_fig,
  1013.             'choropleth': choropleth_fig,
  1014.             'dashboard': dashboard_fig
  1015.         }
  1016.     }
  1017. if __name__ == "__main__":
  1018.     interactive_visualization_example()
复制代码
交互式仪表盘功能

  1. # 交互式仪表盘模块
  2. import dash
  3. from dash import dcc, html
  4. from dash.dependencies import Input, Output
  5. import plotly.express as px
  6. import plotly.graph_objects as go
  7. from plotly.subplots import make_subplots
  8. import pandas as pd
  9. import numpy as np
  10. import os
  11. from pathlib import Path
  12. import logging
  13. class DashboardBuilder:
  14.     """交互式仪表盘构建器"""
  15.    
  16.     def __init__(self, title="数据分析仪表盘", theme="plotly_white"):
  17.         """初始化仪表盘构建器
  18.         
  19.         参数:
  20.             title: 仪表盘标题
  21.             theme: 仪表盘主题
  22.         """
  23.         self.title = title
  24.         self.theme = theme
  25.         self.app = dash.Dash(__name__, suppress_callback_exceptions=True)
  26.         self.app.title = title
  27.         self.visualizer = InteractiveVisualizer()
  28.         self.logger = logging.getLogger(__name__)
  29.         
  30.         # 设置Plotly主题
  31.         pio.templates.default = theme
  32.    
  33.     def create_layout(self, components):
  34.         """创建仪表盘布局
  35.         
  36.         参数:
  37.             components: 组件列表,每个组件是一个字典,包含:
  38.                 - 'type': 组件类型 ('graph', 'table', 'control', 等)
  39.                 - 'id': 组件ID
  40.                 - 'title': 组件标题
  41.                 - 'width': 组件宽度 (1-12)
  42.                 - 其他特定组件类型的参数
  43.                
  44.         返回:
  45.             Dash应用布局
  46.         """
  47.         try:
  48.             # 创建页面布局
  49.             layout = html.Div([
  50.                 # 标题
  51.                 html.H1(self.title, style={'textAlign': 'center', 'marginBottom': 30}),
  52.                
  53.                 # 内容容器
  54.                 html.Div([
  55.                     # 为每个组件创建一个Div
  56.                     html.Div([
  57.                         # 组件标题
  58.                         html.H3(component.get('title', f"Component {i+1}"),
  59.                               style={'marginBottom': 15}),
  60.                         
  61.                         # 根据组件类型创建不同的内容
  62.                         self._create_component(component)
  63.                     ], className=f"col-{component.get('width', 12)}",
  64.                        style={'padding': '10px'})
  65.                     
  66.                     for i, component in enumerate(components)
  67.                 ], className='row')
  68.             ], className='container-fluid')
  69.             
  70.             self.app.layout = layout
  71.             return layout
  72.             
  73.         except Exception as e:
  74.             self.logger.error(f"创建仪表盘布局时出错: {e}")
  75.             return html.Div(f"创建仪表盘布局时出错: {e}")
  76.    
  77.     def _create_component(self, component):
  78.         """根据组件类型创建组件
  79.         
  80.         参数:
  81.             component: 组件配置字典
  82.             
  83.         返回:
  84.             Dash组件
  85.         """
  86.         try:
  87.             component_type = component.get('type', '').lower()
  88.             component_id = component.get('id', f"component-{id(component)}")
  89.             
  90.             if component_type == 'graph':
  91.                 # 创建图表组件
  92.                 return dcc.Graph(
  93.                     id=component_id,
  94.                     figure=component.get('figure', {}),
  95.                     style={'height': component.get('height', 400)}
  96.                 )
  97.                
  98.             elif component_type == 'table':
  99.                 # 创建表格组件
  100.                 data = component.get('data', pd.DataFrame())
  101.                 return html.Div([
  102.                     dash.dash_table.DataTable(
  103.                         id=component_id,
  104.                         columns=[{"name": i, "id": i} for i in data.columns],
  105.                         data=data.to_dict('records'),
  106.                         page_size=component.get('page_size', 10),
  107.                         style_table={'overflowX': 'auto'},
  108.                         style_cell={
  109.                             'textAlign': 'left',
  110.                             'padding': '10px',
  111.                             'minWidth': '100px', 'width': '150px', 'maxWidth': '300px',
  112.                             'whiteSpace': 'normal',
  113.                             'height': 'auto'
  114.                         },
  115.                         style_header={
  116.                             'backgroundColor': 'rgb(230, 230, 230)',
  117.                             'fontWeight': 'bold'
  118.                         }
  119.                     )
  120.                 ])
  121.                
  122.             elif component_type == 'control':
  123.                 # 创建控制组件
  124.                 control_subtype = component.get('control_type', '').lower()
  125.                
  126.                 if control_subtype == 'dropdown':
  127.                     return dcc.Dropdown(
  128.                         id=component_id,
  129.                         options=[{'label': str(opt), 'value': opt}
  130.                                 for opt in component.get('options', [])],
  131.                         value=component.get('value'),
  132.                         multi=component.get('multi', False),
  133.                         placeholder=component.get('placeholder', 'Select an option')
  134.                     )
  135.                     
  136.                 elif control_subtype == 'slider':
  137.                     return dcc.Slider(
  138.                         id=component_id,
  139.                         min=component.get('min', 0),
  140.                         max=component.get('max', 100),
  141.                         step=component.get('step', 1),
  142.                         value=component.get('value', 50),
  143.                         marks={i: str(i) for i in range(
  144.                             component.get('min', 0),
  145.                             component.get('max', 100) + 1,
  146.                             component.get('mark_step', 10)
  147.                         )}
  148.                     )
  149.                     
  150.                 elif control_subtype == 'radio':
  151.                     return dcc.RadioItems(
  152.                         id=component_id,
  153.                         options=[{'label': str(opt), 'value': opt}
  154.                                 for opt in component.get('options', [])],
  155.                         value=component.get('value'),
  156.                         inline=component.get('inline', True)
  157.                     )
  158.                     
  159.                 elif control_subtype == 'checklist':
  160.                     return dcc.Checklist(
  161.                         id=component_id,
  162.                         options=[{'label': str(opt), 'value': opt}
  163.                                 for opt in component.get('options', [])],
  164.                         value=component.get('value', []),
  165.                         inline=component.get('inline', True)
  166.                     )
  167.                     
  168.                 elif control_subtype == 'date':
  169.                     return dcc.DatePickerSingle(
  170.                         id=component_id,
  171.                         date=component.get('date'),
  172.                         min_date_allowed=component.get('min_date'),
  173.                         max_date_allowed=component.get('max_date')
  174.                     )
  175.                     
  176.                 elif control_subtype == 'daterange':
  177.                     return dcc.DatePickerRange(
  178.                         id=component_id,
  179.                         start_date=component.get('start_date'),
  180.                         end_date=component.get('end_date'),
  181.                         min_date_allowed=component.get('min_date'),
  182.                         max_date_allowed=component.get('max_date')
  183.                     )
  184.                
  185.                 else:
  186.                     return html.Div(f"未知的控制类型: {control_subtype}")
  187.                
  188.             elif component_type == 'text':
  189.                 # 创建文本组件
  190.                 return html.Div([
  191.                     html.P(component.get('text', ''),
  192.                           style={'fontSize': component.get('font_size', 16)})
  193.                 ])
  194.                
  195.             elif component_type == 'html':
  196.                 # 创建自定义HTML组件
  197.                 return html.Div([
  198.                     html.Div(component.get('html', ''),
  199.                             dangerously_set_inner_html=True)
  200.                 ])
  201.                
  202.             else:
  203.                 return html.Div(f"未知的组件类型: {component_type}")
  204.                
  205.         except Exception as e:
  206.             self.logger.error(f"创建组件时出错: {e}")
  207.             return html.Div(f"创建组件时出错: {e}")
  208.    
  209.     def add_callback(self, outputs, inputs, state=None):
  210.         """添加回调函数
  211.         
  212.         参数:
  213.             outputs: 输出组件列表,每个元素是一个元组 (component_id, component_property)
  214.             inputs: 输入组件列表,每个元素是一个元组 (component_id, component_property)
  215.             state: 状态组件列表,每个元素是一个元组 (component_id, component_property)
  216.             
  217.         返回:
  218.             装饰器函数
  219.         """
  220.         try:
  221.             # 转换为Dash输出格式
  222.             dash_outputs = [Output(component_id, component_property)
  223.                           for component_id, component_property in outputs]
  224.             
  225.             # 转换为Dash输入格式
  226.             dash_inputs = [Input(component_id, component_property)
  227.                          for component_id, component_property in inputs]
  228.             
  229.             # 转换为Dash状态格式
  230.             dash_state = []
  231.             if state:
  232.                 dash_state = [dash.dependencies.State(component_id, component_property)
  233.                             for component_id, component_property in state]
  234.             
  235.             # 返回Dash回调装饰器
  236.             return self.app.callback(dash_outputs, dash_inputs, dash_state)
  237.             
  238.         except Exception as e:
  239.             self.logger.error(f"添加回调函数时出错: {e}")
  240.             return None
  241.    
  242.     def run_server(self, debug=True, port=8050, host='0.0.0.0'):
  243.         """运行仪表盘服务器
  244.         
  245.         参数:
  246.             debug: 是否启用调试模式
  247.             port: 服务器端口
  248.             host: 服务器主机
  249.         """
  250.         try:
  251.             self.app.run_server(debug=debug, port=port, host=host)
  252.         except Exception as e:
  253.             self.logger.error(f"运行仪表盘服务器时出错: {e}")
  254. # 使用示例
  255. def interactive_dashboard_example():
  256.     """交互式仪表盘示例"""
  257.     # 创建示例数据
  258.     np.random.seed(42)
  259.     n_samples = 200
  260.    
  261.     # 生成特征
  262.     X = np.random.randn(n_samples, 3)  # 3个特征
  263.    
  264.     # 生成目标变量(回归)
  265.     y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  266.    
  267.     # 创建DataFrame
  268.     data = pd.DataFrame(
  269.         X,
  270.         columns=['feature_1', 'feature_2', 'feature_3']
  271.     )
  272.     data['target'] = y
  273.    
  274.     # 添加一些派生列
  275.     data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
  276.     data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
  277.     data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
  278.     data['sales'] = data['target'] * 100 + 500
  279.     data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
  280.     data['customers'] = np.random.poisson(50, n_samples)
  281.     data['sales_per_customer'] = data['sales'] / data['customers']
  282.    
  283.     # 创建可视化器
  284.     visualizer = InteractiveVisualizer()
  285.    
  286.     # 创建一些图表
  287.     monthly_sales = data.groupby('month')['sales'].sum().reset_index()
  288.     monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
  289.                                            categories=['Jan', 'Feb', 'Mar', 'Apr'],
  290.                                            ordered=True)
  291.     monthly_sales = monthly_sales.sort_values('month')
  292.    
  293.     bar_fig = visualizer.plot_bar_chart(
  294.         data=monthly_sales,
  295.         x='month',
  296.         y='sales',
  297.         title='Monthly Sales',
  298.         xlabel='Month',
  299.         ylabel='Total Sales'
  300.     )
  301.    
  302.     region_sales = data.groupby('region')['sales'].sum().reset_index()
  303.     pie_fig = visualizer.plot_pie_chart(
  304.         data=region_sales,
  305.         values='sales',
  306.         names='region',
  307.         title='Sales Distribution by Region'
  308.     )
  309.    
  310.     scatter_fig = visualizer.plot_scatter(
  311.         data=data,
  312.         x='customers',
  313.         y='sales',
  314.         title='Relationship between Number of Customers and Sales',
  315.         xlabel='Number of Customers',
  316.         ylabel='Sales',
  317.         color='region',
  318.         size='profit'
  319.     )
  320.    
  321.     correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
  322.     heatmap_fig = visualizer.plot_heatmap(
  323.         data=correlation_matrix,
  324.         title='Correlation Matrix'
  325.     )
  326.    
  327.     # 创建仪表盘构建器
  328.     dashboard = DashboardBuilder(title="销售数据分析仪表盘")
  329.    
  330.     # 定义仪表盘组件
  331.     components = [
  332.         {
  333.             'type': 'control',
  334.             'id': 'region-filter',
  335.             'title': '区域筛选',
  336.             'control_type': 'dropdown',
  337.             'options': ['All'] + list(data['region'].unique()),
  338.             'value': 'All',
  339.             'width': 3
  340.         },
  341.         {
  342.             'type': 'control',
  343.             'id': 'month-filter',
  344.             'title': '月份筛选',
  345.             'control_type': 'checklist',
  346.             'options': list(data['month'].unique()),
  347.             'value': list(data['month'].unique()),
  348.             'width': 9
  349.         },
  350.         {
  351.             'type': 'graph',
  352.             'id': 'monthly-sales-chart',
  353.             'title': '月度销售额',
  354.             'figure': bar_fig,
  355.             'width': 6,
  356.             'height': 400
  357.         },
  358.         {
  359.             'type': 'graph',
  360.             'id': 'region-sales-chart',
  361.             'title': '区域销售额分布',
  362.             'figure': pie_fig,
  363.             'width': 6,
  364.             'height': 400
  365.         },
  366.         {
  367.             'type': 'graph',
  368.             'id': 'customer-sales-chart',
  369.             'title': '客户数量与销售额关系',
  370.             'figure': scatter_fig,
  371.             'width': 6,
  372.             'height': 400
  373.         },
  374.         {
  375.             'type': 'graph',
  376.             'id': 'correlation-matrix',
  377.             'title': '相关性矩阵',
  378.             'figure': heatmap_fig,
  379.             'width': 6,
  380.             'height': 400
  381.         },
  382.         {
  383.             'type': 'table',
  384.             'id': 'sales-table',
  385.             'title': '销售数据表',
  386.             'data': data[['month', 'region', 'sales', 'profit', 'customers']].head(10),
  387.             'width': 12,
  388.             'page_size': 10
  389.         }
  390.     ]
  391.    
  392.     # 创建仪表盘布局
  393.     dashboard.create_layout(components)
  394.    
  395.     # 添加回调函数 - 区域筛选
  396.     @dashboard.add_callback(
  397.         outputs=[('sales-table', 'data')],
  398.         inputs=[('region-filter', 'value'), ('month-filter', 'value')]
  399.     )
  400.     def update_table(region, months):
  401.         filtered_data = data.copy()
  402.         
  403.         # 筛选区域
  404.         if region != 'All':
  405.             filtered_data = filtered_data[filtered_data['region'] == region]
  406.         
  407.         # 筛选月份
  408.         if months:
  409.             filtered_data = filtered_data[filtered_data['month'].isin(months)]
  410.         
  411.         return [filtered_data[['month', 'region', 'sales', 'profit', 'customers']].head(10).to_dict('records')]
  412.    
  413.     # 添加回调函数 - 更新图表
  414.     @dashboard.add_callback(
  415.         outputs=[
  416.             ('monthly-sales-chart', 'figure'),
  417.             ('region-sales-chart', 'figure'),
  418.             ('customer-sales-chart', 'figure')
  419.         ],
  420.         inputs=[('region-filter', 'value'), ('month-filter', 'value')]
  421.     )
  422.     def update_charts(region, months):
  423.         filtered_data = data.copy()
  424.         
  425.         # 筛选区域
  426.         if region != 'All':
  427.             filtered_data = filtered_data[filtered_data['region'] == region]
  428.         
  429.         # 筛选月份
  430.         if months:
  431.             filtered_data = filtered_data[filtered_data['month'].isin(months)]
  432.         
  433.         # 更新月度销售额图表
  434.         monthly_sales = filtered_data.groupby('month')['sales'].sum().reset_index()
  435.         monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
  436.                                               categories=['Jan', 'Feb', 'Mar', 'Apr'],
  437.                                               ordered=True)
  438.         monthly_sales = monthly_sales.sort_values('month')
  439.         
  440.         bar_fig = visualizer.plot_bar_chart(
  441.             data=monthly_sales,
  442.             x='month',
  443.             y='sales',
  444.             title='Monthly Sales',
  445.             xlabel='Month',
  446.             ylabel='Total Sales'
  447.         )
  448.         
  449.         # 更新区域销售额分布图表
  450.         region_sales = filtered_data.groupby('region')['sales'].sum().reset_index()
  451.         pie_fig = visualizer.plot_pie_chart(
  452.             data=region_sales,
  453.             values='sales',
  454.             names='region',
  455.             title='Sales Distribution by Region'
  456.         )
  457.         
  458.         # 更新客户数量与销售额关系图表
  459.         scatter_fig = visualizer.plot_scatter(
  460.             data=filtered_data,
  461.             x='customers',
  462.             y='sales',
  463.             title='Relationship between Number of Customers and Sales',
  464.             xlabel='Number of Customers',
  465.             ylabel='Sales',
  466.             color='region',
  467.             size='profit'
  468.         )
  469.         
  470.         return [bar_fig, pie_fig, scatter_fig]
  471.    
  472.     # 运行仪表盘
  473.     print("启动交互式仪表盘,请访问 http://127.0.0.1:8050/")
  474.     dashboard.run_server(debug=True)
  475. if __name__ == "__main__":
  476.     interactive_dashboard_example()
  477. # 可视化模块整合
  478. class VisualizationManager:
  479.     """可视化管理器,整合静态和交互式可视化"""
  480.    
  481.     def __init__(self, output_dir='visualizations'):
  482.         """初始化可视化管理器
  483.         
  484.         参数:
  485.             output_dir: 输出目录
  486.         """
  487.         # 创建静态和交互式可视化器
  488.         self.static_visualizer = StaticVisualizer(output_dir=os.path.join(output_dir, 'static'))
  489.         self.interactive_visualizer = InteractiveVisualizer(output_dir=os.path.join(output_dir, 'interactive'))
  490.         self.output_dir = output_dir
  491.         self.logger = logging.getLogger(__name__)
  492.         
  493.         # 确保输出目录存在
  494.         os.makedirs(output_dir, exist_ok=True)
  495.    
  496.     def create_visualization(self, data, chart_type, static=True, interactive=True, **kwargs):
  497.         """创建可视化图表
  498.         
  499.         参数:
  500.             data: 输入数据
  501.             chart_type: 图表类型 ('bar', 'line', 'scatter', 'pie', 'box', 'heatmap', 等)
  502.             static: 是否创建静态图表
  503.             interactive: 是否创建交互式图表
  504.             **kwargs: 其他参数
  505.             
  506.         返回:
  507.             字典,包含静态和交互式图表对象
  508.         """
  509.         try:
  510.             result = {}
  511.             
  512.             # 根据图表类型选择相应的方法
  513.             method_name = f"plot_{chart_type}"
  514.             
  515.             # 创建静态图表
  516.             if static and hasattr(self.static_visualizer, method_name):
  517.                 static_method = getattr(self.static_visualizer, method_name)
  518.                 static_fig = static_method(data=data, **kwargs)
  519.                 result['static'] = static_fig
  520.             
  521.             # 创建交互式图表
  522.             if interactive and hasattr(self.interactive_visualizer, method_name):
  523.                 interactive_method = getattr(self.interactive_visualizer, method_name)
  524.                 interactive_fig = interactive_method(data=data, **kwargs)
  525.                 result['interactive'] = interactive_fig
  526.             
  527.             return result
  528.             
  529.         except Exception as e:
  530.             self.logger.error(f"创建可视化图表时出错: {e}")
  531.             return {}
  532.    
  533.     def create_dashboard(self, data, config, title="数据可视化仪表盘"):
  534.         """创建交互式仪表盘
  535.         
  536.         参数:
  537.             data: 输入数据
  538.             config: 仪表盘配置,包含组件列表
  539.             title: 仪表盘标题
  540.             
  541.         返回:
  542.             DashboardBuilder对象
  543.         """
  544.         try:
  545.             # 创建仪表盘构建器
  546.             dashboard = DashboardBuilder(title=title)
  547.             
  548.             # 创建组件
  549.             components = []
  550.             
  551.             for component_config in config:
  552.                 component_type = component_config.get('type')
  553.                
  554.                 if component_type == 'graph':
  555.                     # 创建图表组件
  556.                     chart_type = component_config.get('chart_type')
  557.                     chart_params = component_config.get('params', {})
  558.                     
  559.                     # 创建图表
  560.                     chart_result = self.create_visualization(
  561.                         data=component_config.get('data', data),
  562.                         chart_type=chart_type,
  563.                         static=False,
  564.                         interactive=True,
  565.                         **chart_params
  566.                     )
  567.                     
  568.                     # 添加到组件列表
  569.                     if 'interactive' in chart_result:
  570.                         components.append({
  571.                             'type': 'graph',
  572.                             'id': component_config.get('id', f"graph-{len(components)}"),
  573.                             'title': component_config.get('title', f"{chart_type.capitalize()} Chart"),
  574.                             'figure': chart_result['interactive'],
  575.                             'width': component_config.get('width', 6),
  576.                             'height': component_config.get('height', 400)
  577.                         })
  578.                
  579.                 elif component_type in ['control', 'table', 'text', 'html']:
  580.                     # 直接添加其他类型的组件
  581.                     components.append(component_config)
  582.             
  583.             # 创建仪表盘布局
  584.             dashboard.create_layout(components)
  585.             
  586.             return dashboard
  587.             
  588.         except Exception as e:
  589.             self.logger.error(f"创建仪表盘时出错: {e}")
  590.             return None
  591.    
  592.     def export_visualizations(self, visualizations, format='html'):
  593.         """导出可视化图表
  594.         
  595.         参数:
  596.             visualizations: 可视化图表字典
  597.             format: 导出格式 ('html', 'png', 'pdf', 等)
  598.             
  599.         返回:
  600.             导出文件路径列表
  601.         """
  602.         try:
  603.             export_paths = []
  604.             
  605.             for name, viz_dict in visualizations.items():
  606.                 # 导出静态图表
  607.                 if 'static' in viz_dict and viz_dict['static'] is not None:
  608.                     static_path = os.path.join(self.output_dir, 'exports', 'static', f"{name}.{format}")
  609.                     os.makedirs(os.path.dirname(static_path), exist_ok=True)
  610.                     
  611.                     if format == 'html':
  612.                         # 对于Matplotlib图表,需要先保存为图像
  613.                         temp_path = os.path.join(self.output_dir, 'exports', 'static', f"{name}.png")
  614.                         viz_dict['static'].savefig(temp_path)
  615.                         
  616.                         # 创建HTML包装
  617.                         with open(static_path, 'w') as f:
  618.                             f.write(f"""
  619.                             <html>
  620.                             <head><title>{name} - Static Visualization</title></head>
  621.                             <body>
  622.                                 <h1>{name}</h1>
  623.                                 <img src="{os.path.basename(temp_path)}" alt="{name}">
  624.                             </body>
  625.                             </html>
  626.                             """)
  627.                     else:
  628.                         viz_dict['static'].savefig(static_path)
  629.                     
  630.                     export_paths.append(static_path)
  631.                
  632.                 # 导出交互式图表
  633.                 if 'interactive' in viz_dict and viz_dict['interactive'] is not None:
  634.                     interactive_path = os.path.join(self.output_dir, 'exports', 'interactive', f"{name}.html")
  635.                     os.makedirs(os.path.dirname(interactive_path), exist_ok=True)
  636.                     
  637.                     # 保存Plotly图表
  638.                     viz_dict['interactive'].write_html(interactive_path)
  639.                     export_paths.append(interactive_path)
  640.             
  641.             return export_paths
  642.             
  643.         except Exception as e:
  644.             self.logger.error(f"导出可视化图表时出错: {e}")
  645.             return []
  646. # 使用示例
  647. def visualization_manager_example():
  648.     """可视化管理器示例"""
  649.     # 创建示例数据
  650.     np.random.seed(42)
  651.     n_samples = 200
  652.    
  653.     # 生成特征
  654.     X = np.random.randn(n_samples, 3)  # 3个特征
  655.    
  656.     # 生成目标变量(回归)
  657.     y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
  658.    
  659.     # 创建DataFrame
  660.     data = pd.DataFrame(
  661.         X,
  662.         columns=['feature_1', 'feature_2', 'feature_3']
  663.     )
  664.     data['target'] = y
  665.    
  666.     # 添加一些派生列
  667.     data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
  668.     data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
  669.     data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
  670.     data['sales'] = data['target'] * 100 + 500
  671.     data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
  672.     data['customers'] = np.random.poisson(50, n_samples)
  673.     data['sales_per_customer'] = data['sales'] / data['customers']
  674.    
  675.     # 创建可视化管理器
  676.     viz_manager = VisualizationManager(output_dir='visualizations')
  677.    
  678.     # 创建各种图表
  679.     visualizations = {}
  680.    
  681.     # 1. 条形图
  682.     monthly_sales = data.groupby('month')['sales'].sum().reset_index()
  683.     monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
  684.                                            categories=['Jan', 'Feb', 'Mar', 'Apr'],
  685.                                            ordered=True)
  686.     monthly_sales = monthly_sales.sort_values('month')
  687.    
  688.     bar_charts = viz_manager.create_visualization(
  689.         data=monthly_sales,
  690.         chart_type='bar_chart',
  691.         x='month',
  692.         y='sales',
  693.         title='Monthly Sales',
  694.         xlabel='Month',
  695.         ylabel='Total Sales',
  696.         save_as='monthly_sales'
  697.     )
  698.     visualizations['monthly_sales'] = bar_charts
  699.    
  700.     # 2. 散点图
  701.     scatter_charts = viz_manager.create_visualization(
  702.         data=data,
  703.         chart_type='scatter',
  704.         x='customers',
  705.         y='sales',
  706.         title='Relationship between Number of Customers and Sales',
  707.         xlabel='Number of Customers',
  708.         ylabel='Sales',
  709.         color='region',
  710.         size='profit',
  711.         save_as='customers_sales'
  712.     )
  713.     visualizations['customers_sales'] = scatter_charts
  714.    
  715.     # 3. 热力图
  716.     correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
  717.     heatmap_charts = viz_manager.create_visualization(
  718.         data=correlation_matrix,
  719.         chart_type='heatmap',
  720.         title='Correlation Matrix',
  721.         save_as='correlation_matrix'
  722.     )
  723.     visualizations['correlation_matrix'] = heatmap_charts
  724.    
  725.     # 导出可视化图表
  726.     export_paths = viz_manager.export_visualizations(visualizations)
  727.     print(f"导出的可视化图表: {export_paths}")
  728.    
  729.     # 创建仪表盘
  730.     dashboard_config = [
  731.         {
  732.             'type': 'control',
  733.             'id': 'region-filter',
  734.             'title': '区域筛选',
  735.             'control_type': 'dropdown',
  736.             'options': ['All'] + list(data['region'].unique()),
  737.             'value': 'All',
  738.             'width': 3
  739.         },
  740.         {
  741.             'type': 'graph',
  742.             'id': 'monthly-sales-chart',
  743.             'title': '月度销售额',
  744.             'chart_type': 'bar_chart',
  745.             'data': monthly_sales,
  746.             'params': {
  747.                 'x': 'month',
  748.                 'y': 'sales',
  749.                 'title': 'Monthly Sales',
  750.                 'xlabel': 'Month',
  751.                 'ylabel': 'Total Sales'
  752.             },
  753.             'width': 6
  754.         },
  755.         {
  756.             'type': 'graph',
  757.             'id': 'customer-sales-chart',
  758.             'title': '客户数量与销售额关系',
  759.             'chart_type': 'scatter',
  760.             'params': {
  761.                 'x': 'customers',
  762.                 'y': 'sales',
  763.                 'title': 'Relationship between Number of Customers and Sales',
  764.                 'xlabel': 'Number of Customers',
  765.                 'ylabel': 'Sales',
  766.                 'color': 'region',
  767.                 'size': 'profit'
  768.             },
  769.             'width': 6
  770.         },
  771.         {
  772.             'type': 'graph',
  773.             'id': 'correlation-matrix',
  774.             'title': '相关性矩阵',
  775.             'chart_type': 'heatmap',
  776.             'data': correlation_matrix,
  777.             'params': {
  778.                 'title': 'Correlation Matrix'
  779.             },
  780.             'width': 12
  781.         }
  782.     ]
  783.    
  784.     dashboard = viz_manager.create_dashboard(data, dashboard_config, title="销售数据分析仪表盘")
  785.    
  786.     if dashboard:
  787.         print("创建仪表盘成功,运行 dashboard.run_server() 启动仪表盘")
  788.    
  789.     return {
  790.         'data': data,
  791.         'visualizations': visualizations,
  792.         'viz_manager': viz_manager,
  793.         'dashboard': dashboard
  794.     }
  795. if __name__ == "__main__":
  796.     visualization_manager_example()   
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

反转基因福娃

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表