马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
1. 项目简介
在当今数据驱动的时代,网络爬虫和数据可视化已成为获取、分析和展示信息的重要工具。本文将详细介绍如何使用Python构建一个完整的网络爬虫与数据可视化系统,该系统能够自动从互联网收集数据,进行处理分析,并通过直观的图表展示结果。
2. 技术栈
- Python 3.8+:主要编程语言
- 网络爬虫:Requests、BeautifulSoup4、Scrapy、Selenium
- 数据处理:Pandas、NumPy
- 数据可视化:Matplotlib、Seaborn、Plotly、Dash
- 数据存储:SQLite、MongoDB
- 其他工具:Jupyter Notebook、Flask
3. 系统架构
- 网络爬虫与数据可视化系统
- ├── 爬虫模块
- │ ├── 数据采集器
- │ ├── 解析器
- │ └── 数据清洗器
- ├── 数据存储模块
- │ ├── 关系型数据库接口
- │ └── NoSQL数据库接口
- ├── 数据分析模块
- │ ├── 统计分析
- │ └── 数据挖掘
- └── 可视化模块
- ├── 静态图表生成器
- ├── 交互式图表生成器
- └── Web展示界面
复制代码 4. 爬虫模块实现
4.1 基本爬虫实现
首先,我们使用Requests和BeautifulSoup构建一个简单的爬虫:
- import requests
- from bs4 import BeautifulSoup
- import pandas as pd
class BasicScraper:
    """Basic web scraper built on requests and BeautifulSoup.

    Typical usage: ``BasicScraper().scrape(url, selectors)`` which returns
    a ``pandas.DataFrame`` with one column per CSS selector.
    """

    def __init__(self, user_agent=None):
        """Create a session and default request headers.

        Args:
            user_agent: Optional User-Agent string; a desktop Chrome UA
                is used when omitted.
        """
        self.session = requests.Session()
        default_ua = ('Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
                      'AppleWebKit/537.36 (KHTML, like Gecko) '
                      'Chrome/91.0.4472.124 Safari/537.36')
        self.headers = {'User-Agent': user_agent if user_agent else default_ua}

    def fetch_page(self, url, params=None, timeout=30):
        """Fetch a page and return its text, or None on any request error.

        Args:
            url: Target URL.
            params: Optional query parameters.
            timeout: Seconds before the request is aborted. (New, defaulted
                parameter: the original call had no timeout and could hang
                forever on an unresponsive server.)
        """
        try:
            response = self.session.get(url, headers=self.headers,
                                        params=params, timeout=timeout)
            response.raise_for_status()  # surface HTTP 4xx/5xx as exceptions
            return response.text
        except requests.exceptions.RequestException as e:
            print(f"请求错误: {e}")
            return None

    def parse_html(self, html, parser='html.parser'):
        """Parse HTML text into a BeautifulSoup object; None for empty input."""
        if html:
            return BeautifulSoup(html, parser)
        return None

    def extract_data(self, soup, selectors):
        """Extract stripped text for each CSS selector into a DataFrame.

        Args:
            soup: BeautifulSoup document (any object providing ``select``).
            selectors: Mapping of column name -> CSS selector.

        Returns:
            pandas.DataFrame whose shorter columns are padded with None to
            the length of the longest result list.
        """
        data = {}
        for key, selector in selectors.items():
            elements = soup.select(selector)
            data[key] = [element.text.strip() for element in elements]

        # Pad shorter columns so DataFrame construction cannot fail on
        # mismatched list lengths.
        max_length = max((len(v) for v in data.values()), default=0)
        for key in data:
            if len(data[key]) < max_length:
                data[key].extend([None] * (max_length - len(data[key])))

        return pd.DataFrame(data)

    def scrape(self, url, selectors, params=None):
        """Fetch, parse and extract in one call; empty DataFrame on failure."""
        html = self.fetch_page(url, params)
        if not html:
            return pd.DataFrame()

        soup = self.parse_html(html)
        if not soup:
            return pd.DataFrame()

        return self.extract_data(soup, selectors)
# Usage example
def scrape_books_example():
    """Scrape the demo bookstore and return a cleaned DataFrame."""
    scraper = BasicScraper()
    url = "http://books.toscrape.com/"
    selectors = {
        "title": ".product_pod h3 a",
        "price": ".price_color",
        "rating": ".star-rating",
        "availability": ".availability",
    }

    books_data = scraper.scrape(url, selectors)

    if books_data.empty:
        return books_data

    # Price: strip the currency symbol, keep a float.
    books_data['price'] = books_data['price'].str.replace('£', '').astype(float)
    # Rating: second word of the CSS class is the star count.
    books_data['rating'] = books_data['rating'].apply(
        lambda x: x.split()[1] + ' stars' if x else None)
    # Availability: trim surrounding whitespace.
    books_data['availability'] = books_data['availability'].str.strip()

    return books_data


if __name__ == "__main__":
    books = scrape_books_example()
    print(f"爬取到 {len(books)} 本书的信息")
    print(books.head())
复制代码 4.2 使用Scrapy框架构建爬虫
对于更复杂的爬虫需求,我们可以使用Scrapy框架:
- # 文件结构:
- # my_scraper/
- # ├── scrapy.cfg
- # └── my_scraper/
- # ├── __init__.py
- # ├── items.py
- # ├── middlewares.py
- # ├── pipelines.py
- # ├── settings.py
- # └── spiders/
- # ├── __init__.py
- # └── book_spider.py
- # items.py
- import scrapy
class BookItem(scrapy.Item):
    """Container for one scraped book record."""
    title = scrapy.Field()         # book title text
    price = scrapy.Field()         # raw price string, e.g. "£51.77"
    rating = scrapy.Field()        # star rating word/number
    availability = scrapy.Field()  # stock status text or count
    category = scrapy.Field()      # product type from the info table
    description = scrapy.Field()   # description paragraph
    upc = scrapy.Field()           # universal product code
    image_url = scrapy.Field()     # absolute cover-image URL
    url = scrapy.Field()           # detail-page URL
- # book_spider.py
- import scrapy
- from ..items import BookItem
class BookSpider(scrapy.Spider):
    """Spider that walks books.toscrape.com and yields BookItem objects."""

    name = 'bookspider'
    allowed_domains = ['books.toscrape.com']
    start_urls = ['http://books.toscrape.com/']

    def parse(self, response):
        """Handle a listing page: queue detail pages, then follow pagination."""
        for book in response.css('article.product_pod'):
            book_url = book.css('h3 a::attr(href)').get()
            if not book_url:
                continue
            # Links on inner pages come without the catalogue/ prefix.
            if 'catalogue/' not in book_url:
                book_url = 'catalogue/' + book_url
            yield scrapy.Request(response.urljoin(book_url),
                                 callback=self.parse_book)

        # Pagination: follow the "next" link if present.
        next_page = response.css('li.next a::attr(href)').get()
        if next_page:
            yield response.follow(next_page, self.parse)

    def parse_book(self, response):
        """Handle a detail page: populate and yield one BookItem."""
        book = BookItem()

        book['title'] = response.css('div.product_main h1::text').get()
        book['price'] = response.css('p.price_color::text').get()
        # The second text node of the availability paragraph holds the
        # human-readable stock text.
        book['availability'] = (
            response.css('p.availability::text').extract()[1].strip())

        # Class attribute looks like "star-rating Three" -> keep the word.
        rating_class = response.css('p.star-rating::attr(class)').get()
        if rating_class:
            book['rating'] = rating_class.split()[1]

        # Product information table: pick out UPC and product type.
        for row in response.css('table.table-striped tr'):
            header = row.css('th::text').get()
            if header == 'UPC':
                book['upc'] = row.css('td::text').get()
            elif header == 'Product Type':
                book['category'] = row.css('td::text').get()

        book['description'] = response.css(
            'div#product_description + p::text').get()

        image_url = response.css('div.item.active img::attr(src)').get()
        if image_url:
            book['image_url'] = response.urljoin(image_url)

        book['url'] = response.url

        yield book
- # pipelines.py (数据处理管道)
- import re
- from itemadapter import ItemAdapter
class BookPipeline:
    """Normalises price, rating and availability on each scraped item."""

    def process_item(self, item, spider):
        """Clean one item in place and return it."""
        adapter = ItemAdapter(item)

        # Price: "£51.77" -> 51.77 (leave untouched if no decimal found).
        if adapter.get('price'):
            price_str = adapter['price']
            price_match = re.search(r'(\d+\.\d+)', price_str)
            if price_match:
                adapter['price'] = float(price_match.group(1))

        # Rating word -> integer; unknown words map to 0.
        rating_map = {
            'One': 1,
            'Two': 2,
            'Three': 3,
            'Four': 4,
            'Five': 5,
        }
        if adapter.get('rating'):
            adapter['rating'] = rating_map.get(adapter['rating'], 0)

        # Availability -> stock count when a number is present, otherwise a
        # coarse status string.
        if adapter.get('availability'):
            if 'In stock' in adapter['availability']:
                stock_match = re.search(r'(\d+)', adapter['availability'])
                if stock_match:
                    adapter['availability'] = int(stock_match.group(1))
                else:
                    adapter['availability'] = 'In stock'
            else:
                adapter['availability'] = 'Out of stock'

        return item
- # 运行爬虫的脚本 (run_spider.py)
- from scrapy.crawler import CrawlerProcess
- from scrapy.utils.project import get_project_settings
def run_spider():
    """Start the bookspider crawl with project settings and block until done."""
    settings = get_project_settings()
    process = CrawlerProcess(settings)
    process.crawl('bookspider')
    process.start()  # blocks until the crawl finishes


if __name__ == '__main__':
    run_spider()
复制代码 4.3 处理动态网页的爬虫
对于JavaScript渲染的网页,我们需要使用Selenium:
- from selenium import webdriver
- from selenium.webdriver.chrome.options import Options
- from selenium.webdriver.chrome.service import Service
- from selenium.webdriver.common.by import By
- from selenium.webdriver.support.ui import WebDriverWait
- from selenium.webdriver.support import expected_conditions as EC
- from webdriver_manager.chrome import ChromeDriverManager
- import pandas as pd
- import time
- import logging
class DynamicScraper:
    """Selenium-based scraper for JavaScript-rendered pages."""

    def __init__(self, headless=True, wait_time=10):
        """Build the logger and a Chrome WebDriver.

        Args:
            headless: Run Chrome without a visible window.
            wait_time: Maximum seconds to wait for elements to appear.
        """
        self.wait_time = wait_time
        self.logger = self._setup_logger()
        self.driver = self._setup_driver(headless)

    def _setup_logger(self):
        """Return a configured logger, avoiding duplicate handlers."""
        logger = logging.getLogger('DynamicScraper')
        logger.setLevel(logging.INFO)
        if not logger.handlers:  # guard against repeated instantiation
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(handler)
        return logger

    def _setup_driver(self, headless):
        """Create the Chrome driver, logging and re-raising on failure."""
        try:
            chrome_options = Options()
            if headless:
                chrome_options.add_argument("--headless")
            # Common compatibility flags for CI containers and headless runs.
            for flag in ("--disable-gpu", "--no-sandbox",
                         "--disable-dev-shm-usage", "--window-size=1920,1080"):
                chrome_options.add_argument(flag)
            # webdriver_manager downloads a matching ChromeDriver binary.
            service = Service(ChromeDriverManager().install())
            return webdriver.Chrome(service=service, options=chrome_options)
        except Exception as e:
            self.logger.error(f"设置WebDriver时出错: {e}")
            raise

    def navigate_to(self, url):
        """Load *url*; return True on success, False otherwise."""
        try:
            self.logger.info(f"正在导航到: {url}")
            self.driver.get(url)
            return True
        except Exception as e:
            self.logger.error(f"导航到 {url} 时出错: {e}")
            return False

    def wait_for_element(self, by, value):
        """Wait for a single element to be present.

        Args:
            by: Locator strategy (By.ID, By.CSS_SELECTOR, ...).
            value: Locator value.

        Returns:
            The element, or None on timeout.
        """
        try:
            return WebDriverWait(self.driver, self.wait_time).until(
                EC.presence_of_element_located((by, value)))
        except Exception as e:
            self.logger.warning(f"等待元素 {value} 超时: {e}")
            return None

    def wait_for_elements(self, by, value):
        """Wait for all matching elements; return them or [] on timeout."""
        try:
            return WebDriverWait(self.driver, self.wait_time).until(
                EC.presence_of_all_elements_located((by, value)))
        except Exception as e:
            self.logger.warning(f"等待元素 {value} 超时: {e}")
            return []

    def extract_data(self, selectors):
        """Collect element text from the current page into a DataFrame.

        Args:
            selectors: Mapping of column name -> (locator strategy, value).

        Returns:
            pandas.DataFrame; shorter columns are padded with None.
        """
        data = {}
        for key, (by, value) in selectors.items():
            try:
                elements = self.driver.find_elements(by, value)
                data[key] = [element.text for element in elements]
                self.logger.info(f"提取了 {len(elements)} 个 '{key}' 元素")
            except Exception as e:
                self.logger.error(f"提取 '{key}' 数据时出错: {e}")
                data[key] = []

        # Equalise column lengths before building the frame.
        widest = max((len(v) for v in data.values()), default=0)
        for column in data.values():
            column.extend([None] * (widest - len(column)))

        return pd.DataFrame(data)

    def scroll_to_bottom(self, scroll_pause_time=1.0):
        """Scroll repeatedly until the page height stops growing."""
        self.logger.info("开始滚动页面以加载更多内容")

        last_height = self.driver.execute_script(
            "return document.body.scrollHeight")
        while True:
            self.driver.execute_script(
                "window.scrollTo(0, document.body.scrollHeight);")
            time.sleep(scroll_pause_time)  # give lazy content time to load
            new_height = self.driver.execute_script(
                "return document.body.scrollHeight")
            if new_height == last_height:
                break  # height stable -> bottom reached
            last_height = new_height

        self.logger.info("页面滚动完成")

    def click_element(self, by, value):
        """Wait for an element and click it; return True on success."""
        try:
            element = self.wait_for_element(by, value)
            if element is None:
                return False
            element.click()
            return True
        except Exception as e:
            self.logger.error(f"点击元素 {value} 时出错: {e}")
            return False

    def close(self):
        """Quit the browser if it is running."""
        if self.driver:
            self.driver.quit()
            self.logger.info("浏览器已关闭")
# Usage example
def scrape_dynamic_website_example():
    """Scrape a JS-rendered product list and return a cleaned DataFrame."""
    scraper = DynamicScraper(headless=True)

    try:
        # Target site (an SPA e-commerce page is assumed here).
        url = "https://www.example-dynamic-site.com/products"
        if not scraper.navigate_to(url):
            return pd.DataFrame()

        # Wait for the product grid, then force lazy content to load.
        scraper.wait_for_element(By.CSS_SELECTOR, ".product-grid")
        scraper.scroll_to_bottom(scroll_pause_time=2.0)

        selectors = {
            "product_name": (By.CSS_SELECTOR, ".product-item .product-name"),
            "price": (By.CSS_SELECTOR, ".product-item .product-price"),
            "rating": (By.CSS_SELECTOR, ".product-item .product-rating"),
            "reviews_count": (By.CSS_SELECTOR, ".product-item .reviews-count"),
        }
        products_data = scraper.extract_data(selectors)

        if not products_data.empty:
            # Price: "$1,299.00" -> 1299.0
            products_data['price'] = (products_data['price']
                                      .str.replace('$', '')
                                      .str.replace(',', '')
                                      .astype(float))
            # Rating: keep the "d.d" number.
            products_data['rating'] = (products_data['rating']
                                       .str.extract(r'(\d\.\d)')
                                       .astype(float))
            # Review count: keep the integer.
            products_data['reviews_count'] = (products_data['reviews_count']
                                              .str.extract(r'(\d+)')
                                              .astype(int))

        return products_data

    finally:
        scraper.close()  # always release the browser


if __name__ == "__main__":
    products = scrape_dynamic_website_example()
    print(f"爬取到 {len(products)} 个产品的信息")
    print(products.head())
复制代码 4.4 爬虫管理器
创建一个爬虫管理器来统一调用不同类型的爬虫:
class ScraperManager:
    """Registry that instantiates and runs named scrapers on demand."""

    def __init__(self):
        self.scrapers = {}  # name -> (scraper class, constructor kwargs)

    def register_scraper(self, name, scraper_class, **kwargs):
        """Register *scraper_class* under *name*.

        Args:
            name: Registry key.
            scraper_class: Class to instantiate later.
            **kwargs: Keyword arguments forwarded to the constructor.
        """
        self.scrapers[name] = (scraper_class, kwargs)
        print(f"已注册爬虫: {name}")

    def get_scraper(self, name):
        """Instantiate and return the scraper registered as *name*.

        Raises:
            ValueError: If *name* was never registered.
        """
        try:
            scraper_class, kwargs = self.scrapers[name]
        except KeyError:
            raise ValueError(f"未找到名为 '{name}' 的爬虫") from None
        return scraper_class(**kwargs)

    def run_scraper(self, name, *args, **kwargs):
        """Run *name*'s ``scrape`` (preferred) or ``run`` method.

        Args:
            name: Registered scraper name.
            *args: Positional arguments for the scraper method.
            **kwargs: Keyword arguments for the scraper method.

        Returns:
            Whatever the scraper method returns.

        Raises:
            AttributeError: If the scraper exposes neither method.
        """
        scraper = self.get_scraper(name)
        for method_name in ('scrape', 'run'):
            method = getattr(scraper, method_name, None)
            if method is not None:
                return method(*args, **kwargs)
        raise AttributeError(f"爬虫 '{name}' 没有 'scrape' 或 'run' 方法")
# Usage example
def scraper_manager_example():
    """Register both scrapers and run the basic one against the demo site."""
    manager = ScraperManager()

    manager.register_scraper('basic', BasicScraper)
    manager.register_scraper('dynamic', DynamicScraper,
                             headless=True, wait_time=15)

    url = "http://books.toscrape.com/"
    selectors = {
        "title": ".product_pod h3 a",
        "price": ".price_color",
        "rating": ".star-rating",
    }
    books_data = manager.run_scraper('basic', url, selectors)

    print(f"使用基础爬虫爬取到 {len(books_data)} 本书的信息")
    return books_data


if __name__ == "__main__":
    data = scraper_manager_example()
    print(data.head())
复制代码 4.5 代理IP和请求头轮换
为了避免被目标网站封锁,我们可以实现代理IP和请求头轮换功能:
- import random
- import time
- from fake_useragent import UserAgent
class ProxyRotator:
    """Round-robin rotation over a pool of proxy configurations."""

    def __init__(self, proxies=None):
        """Initialise the rotator.

        Args:
            proxies: Optional list of requests-style proxy dicts, e.g.
                ``[{'http': 'http://ip:port', 'https': 'https://ip:port'}]``.
        """
        self.proxies = proxies or []
        self.current_index = 0  # position of the next proxy to hand out

    def add_proxy(self, proxy):
        """Append a proxy dict to the pool."""
        self.proxies.append(proxy)

    def get_proxy(self):
        """Return the next proxy in round-robin order, or None if empty."""
        if not self.proxies:
            return None
        proxy = self.proxies[self.current_index]
        self.current_index = (self.current_index + 1) % len(self.proxies)
        return proxy

    def remove_proxy(self, proxy):
        """Drop a dead proxy and keep the cursor within bounds."""
        if proxy in self.proxies:
            self.proxies.remove(proxy)
            # max(1, ...) avoids modulo-by-zero when the pool empties.
            self.current_index %= max(1, len(self.proxies))
class UserAgentRotator:
    """Supplies random User-Agent strings, with a static fallback list."""

    def __init__(self, use_fake_ua=True):
        """Initialise the rotator.

        Args:
            use_fake_ua: When True, prefer strings from ``fake_useragent``;
                the built-in list below is used as a fallback either way.
        """
        self.use_fake_ua = use_fake_ua
        self.ua = UserAgent() if use_fake_ua else None

        # Fallback pool covering the major desktop/mobile browsers.
        self.user_agents = [
            'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.1.1 Safari/605.1.15',
            'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0',
            'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36',
            'Mozilla/5.0 (iPhone; CPU iPhone OS 14_6 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0 Mobile/15E148 Safari/604.1'
        ]

    def get_random_ua(self):
        """Return a random User-Agent string.

        Falls back to the static list when fake_useragent is disabled or
        its lookup raises (e.g. its remote data source is unreachable).
        """
        if self.use_fake_ua and self.ua:
            try:
                return self.ua.random
            # Fix: was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit.
            except Exception:
                pass

        return random.choice(self.user_agents)
class EnhancedScraper(BasicScraper):
    """BasicScraper plus proxy rotation, UA rotation and retries."""

    def __init__(self, proxy_rotator=None, ua_rotator=None,
                 retry_times=3, retry_delay=2):
        """Initialise the enhanced scraper.

        Args:
            proxy_rotator: ProxyRotator to draw proxies from (a fresh,
                empty one is created when omitted).
            ua_rotator: UserAgentRotator for per-request UA changes.
            retry_times: Attempts per URL before giving up.
            retry_delay: Seconds to sleep between attempts.
        """
        super().__init__()
        self.proxy_rotator = proxy_rotator or ProxyRotator()
        self.ua_rotator = ua_rotator or UserAgentRotator()
        self.retry_times = retry_times
        self.retry_delay = retry_delay

    def fetch_page(self, url, params=None):
        """Fetch *url* with rotating proxy/UA and retries; None on failure."""
        for attempt in range(self.retry_times):
            # Rotate identity for this attempt.
            proxy = self.proxy_rotator.get_proxy()
            self.headers['User-Agent'] = self.ua_rotator.get_random_ua()

            try:
                response = self.session.get(
                    url,
                    headers=self.headers,
                    params=params,
                    proxies=proxy,
                    timeout=10,
                )
                response.raise_for_status()
                return response.text

            except requests.exceptions.RequestException as e:
                print(f"请求错误 (尝试 {attempt+1}/{self.retry_times}): {e}")

                # Connection-level failures usually mean the proxy is dead.
                if proxy and isinstance(e, (requests.exceptions.ProxyError,
                                            requests.exceptions.ConnectTimeout)):
                    self.proxy_rotator.remove_proxy(proxy)

                if attempt == self.retry_times - 1:
                    return None  # out of attempts

                time.sleep(self.retry_delay)

        return None
# Usage example
def enhanced_scraper_example():
    """Scrape the demo bookstore through rotating proxies and UAs."""
    # Two placeholder proxies; replace with live ones in production.
    proxy_rotator = ProxyRotator([
        {'http': 'http://proxy1.example.com:8080',
         'https': 'https://proxy1.example.com:8080'},
        {'http': 'http://proxy2.example.com:8080',
         'https': 'https://proxy2.example.com:8080'},
    ])
    ua_rotator = UserAgentRotator()

    scraper = EnhancedScraper(proxy_rotator, ua_rotator, retry_times=3)

    url = "http://books.toscrape.com/"
    selectors = {
        "title": ".product_pod h3 a",
        "price": ".price_color",
        "rating": ".star-rating",
    }
    return scraper.scrape(url, selectors)


if __name__ == "__main__":
    data = enhanced_scraper_example()
    print(f"爬取到 {len(data)} 本书的信息")
    print(data.head())
复制代码 5. 数据存储模块
数据存储模块负责将爬取的数据保存到不同类型的存储系统中,包括关系型数据库、NoSQL数据库和文件系统。
5.1 SQLite数据库存储
SQLite是一种轻量级的关系型数据库,适合单机应用和原型开发:
- import sqlite3
- import pandas as pd
- import os
- import logging
- import csv
- from datetime import datetime
class SQLiteStorage:
    """Thin convenience wrapper around a SQLite database.

    Supports table creation, DataFrame/record insertion, querying into
    pandas, and use as a context manager.

    NOTE: Table and column names are interpolated directly into SQL, so
    they must come from trusted code, never from user input.
    """

    def __init__(self, db_path):
        """Open (creating if needed) the database at *db_path*.

        Args:
            db_path: Path to the SQLite file; parent directories are
                created automatically.

        Raises:
            sqlite3.Error: If the connection cannot be established.
        """
        self.db_path = db_path
        self.logger = self._setup_logger()

        # Make sure the directory holding the database file exists.
        os.makedirs(os.path.dirname(os.path.abspath(db_path)), exist_ok=True)

        try:
            self.conn = sqlite3.connect(db_path)
            self.cursor = self.conn.cursor()
            self.logger.info(f"成功连接到SQLite数据库: {db_path}")
        except sqlite3.Error as e:
            self.logger.error(f"连接SQLite数据库时出错: {e}")
            raise

    def _setup_logger(self):
        """Return a configured logger, avoiding duplicate handlers."""
        logger = logging.getLogger('SQLiteStorage')
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(handler)
        return logger

    def create_table(self, table_name, columns):
        """Create *table_name* if it does not already exist.

        Args:
            table_name: Table name (trusted identifier).
            columns: Mapping of column name -> SQL type/constraints.

        Returns:
            bool: True on success; failures roll back and return False.
        """
        try:
            columns_str = ', '.join(f"{col} {dtype}"
                                    for col, dtype in columns.items())
            query = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_str})"
            self.cursor.execute(query)
            self.conn.commit()
            self.logger.info(f"成功创建表: {table_name}")
            return True
        except sqlite3.Error as e:
            self.logger.error(f"创建表 {table_name} 时出错: {e}")
            self.conn.rollback()
            return False

    def insert_data(self, table_name, data):
        """Insert rows into *table_name*.

        Args:
            table_name: Target table.
            data: A DataFrame, a list of dicts (columns taken from the
                first row), or a list of value sequences (positional).

        Returns:
            bool: True on success, also when there was nothing to insert.
        """
        try:
            if isinstance(data, pd.DataFrame):
                # pandas handles parameterisation and transaction commit.
                data.to_sql(table_name, self.conn, if_exists='append',
                            index=False)
                self.logger.info(f"成功插入 {len(data)} 行数据到表 {table_name}")
            elif isinstance(data, list) and len(data) > 0:
                # (An additional `if not data` check that used to live here
                # was unreachable: this branch already guarantees a
                # non-empty list. Removed.)
                if isinstance(data[0], dict):
                    # Dict rows: derive column order from the first row.
                    columns = list(data[0].keys())
                    placeholders = ', '.join(['?'] * len(columns))
                    columns_str = ', '.join(columns)
                    query = f"INSERT INTO {table_name} ({columns_str}) VALUES ({placeholders})"
                    values = [[row.get(col) for col in columns]
                              for row in data]
                    self.cursor.executemany(query, values)
                else:
                    # Plain value sequences: rely on positional columns.
                    placeholders = ', '.join(['?'] * len(data[0]))
                    query = f"INSERT INTO {table_name} VALUES ({placeholders})"
                    self.cursor.executemany(query, data)

                self.conn.commit()
                self.logger.info(f"成功插入 {len(data)} 行数据到表 {table_name}")
            else:
                self.logger.warning(f"没有数据可插入到表 {table_name}")

            return True
        except Exception as e:
            self.logger.error(f"插入数据到表 {table_name} 时出错: {e}")
            self.conn.rollback()
            return False

    def query_data(self, query, params=None):
        """Run a SELECT and return the result as a DataFrame.

        Args:
            query: SQL text (use ``?`` placeholders for values).
            params: Optional parameter sequence.

        Returns:
            pandas.DataFrame (empty on error).
        """
        try:
            if params:
                return pd.read_sql_query(query, self.conn, params=params)
            return pd.read_sql_query(query, self.conn)
        except Exception as e:
            self.logger.error(f"执行查询时出错: {e}")
            return pd.DataFrame()

    def execute_query(self, query, params=None):
        """Execute an arbitrary statement and commit.

        Args:
            query: SQL text.
            params: Optional parameter sequence.

        Returns:
            bool: True on success; failures roll back and return False.
        """
        try:
            if params:
                self.cursor.execute(query, params)
            else:
                self.cursor.execute(query)
            self.conn.commit()
            return True
        except Exception as e:
            self.logger.error(f"执行查询时出错: {e}")
            self.conn.rollback()
            return False

    def table_exists(self, table_name):
        """Return True if *table_name* exists (parameterised lookup)."""
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        self.cursor.execute(query, (table_name,))
        return self.cursor.fetchone() is not None

    def get_table_info(self, table_name):
        """Return PRAGMA table_info rows for *table_name* ([] if absent)."""
        if not self.table_exists(table_name):
            return []
        query = f"PRAGMA table_info({table_name})"
        return self.cursor.execute(query).fetchall()

    def close(self):
        """Close the connection if it was ever opened."""
        if hasattr(self, 'conn') and self.conn:
            self.conn.close()
            self.logger.info("数据库连接已关闭")

    def __enter__(self):
        """Enter context manager: return self."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager: close the connection."""
        self.close()
# Usage example
def sqlite_example():
    """Create a books table, insert sample rows and query them back."""
    db = SQLiteStorage('data/books.db')

    try:
        db.create_table('books', {
            'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
            'title': 'TEXT NOT NULL',
            'price': 'REAL',
            'rating': 'INTEGER',
            'category': 'TEXT',
            'description': 'TEXT',
            'created_at': 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP',
        })

        # Sample rows to demonstrate DataFrame insertion.
        books_data = pd.DataFrame({
            'title': ['Python编程', '数据科学入门', '机器学习实战'],
            'price': [59.9, 69.9, 79.9],
            'rating': [5, 4, 5],
            'category': ['编程', '数据科学', '机器学习'],
            'description': ['Python基础教程', '数据分析入门', '机器学习算法详解'],
        })
        db.insert_data('books', books_data)

        results = db.query_data("SELECT * FROM books WHERE rating >= ?", (4,))
        print(f"查询结果: {len(results)} 行")
        print(results)
        return results

    finally:
        db.close()  # always release the connection


if __name__ == "__main__":
    sqlite_example()
复制代码 5.2 MongoDB数据库存储
MongoDB是一种流行的NoSQL数据库,适合存储非结构化或半结构化数据:
- import pymongo
- import pandas as pd
- import json
- import logging
- from bson import ObjectId
- from datetime import datetime
class MongoDBStorage:
    """Convenience wrapper around one MongoDB database via pymongo."""

    def __init__(self, connection_string, database_name):
        """Connect and verify the server is reachable.

        Args:
            connection_string: MongoDB URI.
            database_name: Database to operate on.

        Raises:
            Exception: If the server cannot be reached.
        """
        self.connection_string = connection_string
        self.database_name = database_name
        self.logger = self._setup_logger()

        try:
            self.client = pymongo.MongoClient(connection_string)
            self.db = self.client[database_name]
            self.client.server_info()  # force a round-trip to validate
            self.logger.info(f"成功连接到MongoDB数据库: {database_name}")
        except Exception as e:
            self.logger.error(f"连接MongoDB数据库时出错: {e}")
            raise

    def _setup_logger(self):
        """Return a configured logger, avoiding duplicate handlers."""
        logger = logging.getLogger('MongoDBStorage')
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(handler)
        return logger

    def _convert_to_json_serializable(self, data):
        """Recursively convert ObjectId/datetime values into strings."""
        if isinstance(data, dict):
            return {key: self._convert_to_json_serializable(value)
                    for key, value in data.items()}
        if isinstance(data, list):
            return [self._convert_to_json_serializable(item) for item in data]
        if isinstance(data, (ObjectId, datetime)):
            return str(data)
        return data

    @staticmethod
    def _ensure_update_operator(update_data):
        """Wrap plain field dicts in ``$set`` so pymongo accepts them."""
        if any(key.startswith('$') for key in update_data.keys()):
            return update_data
        return {'$set': update_data}

    def insert_document(self, collection_name, document):
        """Insert one document; return its id, or None on error."""
        try:
            result = self.db[collection_name].insert_one(document)
            self.logger.info(f"成功插入文档到集合 {collection_name}, ID: {result.inserted_id}")
            return result.inserted_id
        except Exception as e:
            self.logger.error(f"插入文档到集合 {collection_name} 时出错: {e}")
            return None

    def insert_many(self, collection_name, documents):
        """Insert a list of documents; return their ids ([] on error)."""
        try:
            result = self.db[collection_name].insert_many(documents)
            self.logger.info(f"成功插入 {len(result.inserted_ids)} 个文档到集合 {collection_name}")
            return result.inserted_ids
        except Exception as e:
            self.logger.error(f"插入多个文档到集合 {collection_name} 时出错: {e}")
            return []

    def insert_dataframe(self, collection_name, df):
        """Insert every row of *df* as a document.

        Args:
            collection_name: Target collection.
            df: pandas DataFrame; an empty frame is a no-op success.

        Returns:
            bool: True on success.
        """
        try:
            if df.empty:
                self.logger.warning(f"DataFrame为空,未插入数据到集合 {collection_name}")
                return True

            records = df.to_dict('records')  # one dict per row
            result = self.db[collection_name].insert_many(records)

            self.logger.info(f"成功插入 {len(result.inserted_ids)} 行数据到集合 {collection_name}")
            return True
        except Exception as e:
            self.logger.error(f"插入DataFrame到集合 {collection_name} 时出错: {e}")
            return False

    def find_documents(self, collection_name, query=None, projection=None, limit=0):
        """Query documents into a DataFrame.

        Args:
            collection_name: Collection to search.
            query: Filter dict (defaults to match-all).
            projection: Optional field projection.
            limit: Maximum number of results (0 = unlimited).

        Returns:
            pandas.DataFrame with `_id` stringified (empty on error).
        """
        try:
            collection = self.db[collection_name]
            cursor = collection.find(query if query is not None else {},
                                     projection)
            if limit > 0:
                cursor = cursor.limit(limit)

            results = list(cursor)
            # ObjectId is not DataFrame-friendly; stringify it.
            for doc in results:
                if '_id' in doc:
                    doc['_id'] = str(doc['_id'])

            return pd.DataFrame(results) if results else pd.DataFrame()
        except Exception as e:
            self.logger.error(f"查询集合 {collection_name} 时出错: {e}")
            return pd.DataFrame()

    def update_document(self, collection_name, query, update_data, upsert=False):
        """Update the first matching document.

        Args:
            collection_name: Target collection.
            query: Filter dict.
            update_data: Update spec; plain dicts are wrapped in ``$set``.
            upsert: Insert when no document matches.

        Returns:
            int: Number of modified documents (0 on error).
        """
        try:
            collection = self.db[collection_name]
            update_data = self._ensure_update_operator(update_data)
            result = collection.update_one(query, update_data, upsert=upsert)
            self.logger.info(f"更新集合 {collection_name} 中的文档: 匹配 {result.matched_count}, 修改 {result.modified_count}")
            return result.modified_count
        except Exception as e:
            self.logger.error(f"更新集合 {collection_name} 中的文档时出错: {e}")
            return 0

    def update_many(self, collection_name, query, update_data):
        """Update all matching documents; return the modified count."""
        try:
            collection = self.db[collection_name]
            update_data = self._ensure_update_operator(update_data)
            result = collection.update_many(query, update_data)
            self.logger.info(f"更新集合 {collection_name} 中的多个文档: 匹配 {result.matched_count}, 修改 {result.modified_count}")
            return result.modified_count
        except Exception as e:
            self.logger.error(f"更新集合 {collection_name} 中的多个文档时出错: {e}")
            return 0

    def delete_document(self, collection_name, query):
        """Delete the first matching document; return the deleted count."""
        try:
            result = self.db[collection_name].delete_one(query)
            self.logger.info(f"从集合 {collection_name} 中删除了 {result.deleted_count} 个文档")
            return result.deleted_count
        except Exception as e:
            self.logger.error(f"从集合 {collection_name} 中删除文档时出错: {e}")
            return 0

    def delete_many(self, collection_name, query):
        """Delete all matching documents; return the deleted count."""
        try:
            result = self.db[collection_name].delete_many(query)
            self.logger.info(f"从集合 {collection_name} 中删除了 {result.deleted_count} 个文档")
            return result.deleted_count
        except Exception as e:
            self.logger.error(f"从集合 {collection_name} 中删除多个文档时出错: {e}")
            return 0

    def create_index(self, collection_name, keys, **kwargs):
        """Create an index; return its name, or None on error.

        Args:
            collection_name: Target collection.
            keys: Index key specification.
            **kwargs: Extra index options (e.g. ``unique=True``).
        """
        try:
            index_name = self.db[collection_name].create_index(keys, **kwargs)
            self.logger.info(f"在集合 {collection_name} 上创建索引: {index_name}")
            return index_name
        except Exception as e:
            self.logger.error(f"在集合 {collection_name} 上创建索引时出错: {e}")
            return None

    def drop_collection(self, collection_name):
        """Drop a whole collection; return True on success."""
        try:
            self.db.drop_collection(collection_name)
            self.logger.info(f"成功删除集合: {collection_name}")
            return True
        except Exception as e:
            self.logger.error(f"删除集合 {collection_name} 时出错: {e}")
            return False

    def close(self):
        """Close the client connection if it was ever opened."""
        if hasattr(self, 'client') and self.client:
            self.client.close()
            self.logger.info("MongoDB连接已关闭")

    def __enter__(self):
        """Enter context manager: return self."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager: close the connection."""
        self.close()
# Usage example
def mongodb_example():
    """Insert sample product/review documents, query and update them."""
    mongo = MongoDBStorage('mongodb://localhost:27017', 'web_scraping_db')

    try:
        # Sample catalogue data, including nested list fields.
        products_data = pd.DataFrame({
            'name': ['智能手机', '笔记本电脑', '平板电脑'],
            'price': [2999, 4999, 3999],
            'brand': ['品牌A', '品牌B', '品牌A'],
            'features': [
                ['5G', '高清摄像头', '快速充电'],
                ['高性能CPU', '大内存', 'SSD'],
                ['触控屏', '长续航', '轻薄'],
            ],
            'in_stock': [True, False, True],
            'last_updated': [datetime.now() for _ in range(3)],
        })
        mongo.insert_dataframe('products', products_data)

        # A single free-form document.
        review = {
            'product_id': '123456',
            'user': '用户A',
            'rating': 5,
            'comment': '非常好用的产品',
            'date': datetime.now(),
        }
        review_id = mongo.insert_document('reviews', review)

        results = mongo.find_documents('products', {'brand': '品牌A'})
        print(f"查询结果: {len(results)} 行")
        print(results)

        mongo.update_document('products', {'name': '智能手机'},
                              {'$set': {'price': 2899}})
        mongo.create_index('products', [('name', pymongo.ASCENDING)],
                           unique=True)
        return results

    finally:
        mongo.close()  # always release the connection


if __name__ == "__main__":
    mongodb_example()
复制代码 5.3 CSV文件存储
CSV是一种常用的数据交换格式,适合存储表格数据:
- import pandas as pd
- import os
- import logging
- import csv
- from datetime import datetime
class CSVStorage:
    """CSV-file storage helper rooted at a base directory."""

    def __init__(self, base_dir='data/csv'):
        """Create the storage directory if needed.

        Args:
            base_dir: Directory that will hold all CSV files.
        """
        self.base_dir = base_dir
        self.logger = self._setup_logger()

        os.makedirs(base_dir, exist_ok=True)  # idempotent
        self.logger.info(f"CSV存储目录: {base_dir}")
-
- def _setup_logger(self):
- """设置日志记录器"""
- logger = logging.getLogger('CSVStorage')
- logger.setLevel(logging.INFO)
-
- if not logger.handlers:
- handler = logging.StreamHandler()
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- handler.setFormatter(formatter)
- logger.addHandler(handler)
-
- return logger
-
- def _get_file_path(self, file_name):
- """获取文件完整路径
-
- 参数:
- file_name: 文件名
-
- 返回:
- str: 文件完整路径
- """
- # 确保文件名有.csv后缀
- if not file_name.endswith('.csv'):
- file_name += '.csv'
-
- return os.path.join(self.base_dir, file_name)
-
- def save_dataframe(self, df, file_name, index=False):
- """保存DataFrame到CSV文件
-
- 参数:
- df: 要保存的DataFrame
- file_name: 文件名
- index: 是否保存索引
-
- 返回:
- bool: 是否成功
- """
- try:
- file_path = self._get_file_path(file_name)
- df.to_csv(file_path, index=index, encoding='utf-8')
- self.logger.info(f"成功保存 {len(df)} 行数据到文件: {file_path}")
- return True
- except Exception as e:
- self.logger.error(f"保存数据到文件 {file_name} 时出错: {e}")
- return False
-
- def append_dataframe(self, df, file_name, index=False):
- """追加DataFrame到CSV文件
-
- 参数:
- df: 要追加的DataFrame
- file_name: 文件名
- index: 是否保存索引
-
- 返回:
- bool: 是否成功
- """
- try:
- file_path = self._get_file_path(file_name)
-
- # 检查文件是否存在
- file_exists = os.path.isfile(file_path)
-
- # 如果文件存在,追加数据;否则创建新文件
- df.to_csv(file_path, mode='a', header=not file_exists, index=index, encoding='utf-8')
-
- self.logger.info(f"成功追加 {len(df)} 行数据到文件: {file_path}")
- return True
- except Exception as e:
- self.logger.error(f"追加数据到文件 {file_name} 时出错: {e}")
- return False
-
- def load_csv(self, file_name, **kwargs):
- """加载CSV文件到DataFrame
-
- 参数:
- file_name: 文件名
- **kwargs: 传递给pd.read_csv的参数
-
- 返回:
- pandas.DataFrame: 加载的数据
- """
- try:
- file_path = self._get_file_path(file_name)
-
- if not os.path.isfile(file_path):
- self.logger.warning(f"文件不存在: {file_path}")
- return pd.DataFrame()
-
- df = pd.read_csv(file_path, **kwargs)
- self.logger.info(f"成功从文件 {file_path} 加载 {len(df)} 行数据")
- return df
- except Exception as e:
- self.logger.error(f"从文件 {file_name} 加载数据时出错: {e}")
- return pd.DataFrame()
-
- def save_records(self, records, file_name, fieldnames=None):
- """保存记录列表到CSV文件
-
- 参数:
- records: 字典列表
- file_name: 文件名
- fieldnames: 字段名列表(可选)
-
- 返回:
- bool: 是否成功
- """
- try:
- file_path = self._get_file_path(file_name)
-
- if not records:
- self.logger.warning(f"没有记录可保存到文件: {file_path}")
- return True
-
- # 如果未提供字段名,使用第一条记录的键
- if fieldnames is None:
- fieldnames = list(records[0].keys())
-
- with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
- writer.writeheader()
- writer.writerows(records)
-
- self.logger.info(f"成功保存 {len(records)} 条记录到文件: {file_path}")
- return True
- except Exception as e:
- self.logger.error(f"保存记录到文件 {file_name} 时出错: {e}")
- return False
-
- def append_records(self, records, file_name, fieldnames=None):
- """追加记录列表到CSV文件
-
- 参数:
- records: 字典列表
- file_name: 文件名
- fieldnames: 字段名列表(可选)
-
- 返回:
- bool: 是否成功
- """
- try:
- file_path = self._get_file_path(file_name)
-
- if not records:
- self.logger.warning(f"没有记录可追加到文件: {file_path}")
- return True
-
- # 检查文件是否存在
- file_exists = os.path.isfile(file_path)
-
- # 如果未提供字段名,使用第一条记录的键
- if fieldnames is None:
- fieldnames = list(records[0].keys())
-
- with open(file_path, 'a', newline='', encoding='utf-8') as csvfile:
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
-
- # 如果文件不存在,写入表头
- if not file_exists:
- writer.writeheader()
-
- writer.writerows(records)
-
- self.logger.info(f"成功追加 {len(records)} 条记录到文件: {file_path}")
- return True
- except Exception as e:
- self.logger.error(f"追加记录到文件 {file_name} 时出错: {e}")
- return False
-
- def file_exists(self, file_name):
- """检查文件是否存在
-
- 参数:
- file_name: 文件名
-
- 返回:
- bool: 文件是否存在
- """
- file_path = self._get_file_path(file_name)
- return os.path.isfile(file_path)
-
- def list_files(self):
- """列出所有CSV文件
-
- 返回:
- list: CSV文件列表
- """
- try:
- files = [f for f in os.listdir(self.base_dir) if f.endswith('.csv')]
- self.logger.info(f"找到 {len(files)} 个CSV文件")
- return files
- except Exception as e:
- self.logger.error(f"列出CSV文件时出错: {e}")
- return []
-
- def delete_file(self, file_name):
- """删除CSV文件
-
- 参数:
- file_name: 文件名
-
- 返回:
- bool: 是否成功
- """
- try:
- file_path = self._get_file_path(file_name)
-
- if not os.path.isfile(file_path):
- self.logger.warning(f"文件不存在,无法删除: {file_path}")
- return False
-
- os.remove(file_path)
- self.logger.info(f"成功删除文件: {file_path}")
- return True
- except Exception as e:
- self.logger.error(f"删除文件 {file_name} 时出错: {e}")
- return False
# 使用示例
def csv_example():
    """Demonstrate CSVStorage: save, reload, then append inventory data."""
    store = CSVStorage('data/csv')

    # Initial inventory snapshot (one timestamp call per row, as before).
    inventory = pd.DataFrame({
        'date': [datetime.now().strftime('%Y-%m-%d %H:%M:%S') for _ in range(3)],
        'category': ['电子产品', '家居', '食品'],
        'item_count': [120, 85, 200],
        'average_price': [1500.75, 350.25, 45.50],
    })
    store.save_dataframe(inventory, 'inventory')

    # Read it back and show what landed on disk.
    loaded = store.load_csv('inventory')
    print(f"加载的数据: {len(loaded)} 行")
    print(loaded)

    # Append one more category to the same file.
    extra = pd.DataFrame({
        'date': [datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
        'category': ['服装'],
        'item_count': [150],
        'average_price': [250.00],
    })
    store.append_dataframe(extra, 'inventory')

    return loaded


if __name__ == "__main__":
    csv_example()
复制代码 5.4 存储工厂
创建一个存储工厂类,用于统一管理不同类型的存储:
class StorageFactory:
    """Factory that creates and caches storage backends by type name.

    Storage classes are registered under a string key; instances are
    cached per (type, constructor-kwargs) so repeated requests with the
    same configuration return the same object.
    """

    def __init__(self):
        self.storage_classes = {}    # storage_type -> storage class
        self.storage_instances = {}  # cache key -> live instance

    def register_storage(self, storage_type, storage_class):
        """Register *storage_class* under the name *storage_type*.

        Args:
            storage_type: storage type name.
            storage_class: class to instantiate for that type.
        """
        self.storage_classes[storage_type] = storage_class
        print(f"已注册存储类型: {storage_type}")

    def _make_key(self, storage_type, kwargs):
        """Build a deterministic cache key for a (type, kwargs) pair.

        Uses repr of the key-sorted items instead of the previous
        hash(frozenset(kwargs.items())): that raised TypeError for
        unhashable values (lists, dicts) and, for strings, varied
        between interpreter runs due to hash randomization.
        """
        return f"{storage_type}_{sorted(kwargs.items())!r}"

    def get_storage(self, storage_type, **kwargs):
        """Return a cached storage instance, creating it on first use.

        Args:
            storage_type: registered storage type name.
            **kwargs: forwarded to the storage class constructor.

        Returns:
            The storage instance.

        Raises:
            ValueError: if *storage_type* was never registered.
        """
        if storage_type not in self.storage_classes:
            raise ValueError(f"未注册的存储类型: {storage_type}")

        instance_key = self._make_key(storage_type, kwargs)

        # Lazily instantiate and memoize per configuration.
        if instance_key not in self.storage_instances:
            storage_class = self.storage_classes[storage_type]
            self.storage_instances[instance_key] = storage_class(**kwargs)

        return self.storage_instances[instance_key]

    def close_all(self):
        """Close every cached storage and clear the cache.

        A failing close() no longer prevents the remaining storages
        from being closed.
        """
        for storage in self.storage_instances.values():
            close = getattr(storage, 'close', None)
            if close is not None:
                try:
                    close()
                except Exception as e:
                    print(f"关闭存储连接时出错: {e}")

        self.storage_instances.clear()
        print("已关闭所有存储连接")
# 使用示例
def storage_factory_example():
    """Wire the factory up with all three backends, then tear everything down."""
    factory = StorageFactory()

    # Register every available backend under a short name (same order as before).
    for name, cls in (('sqlite', SQLiteStorage),
                      ('mongodb', MongoDBStorage),
                      ('csv', CSVStorage)):
        factory.register_storage(name, cls)

    # Request one instance of each; the factory creates and caches them.
    sqlite_storage = factory.get_storage('sqlite', db_path='data/example.db')
    mongo_storage = factory.get_storage('mongodb',
                                        connection_string='mongodb://localhost:27017',
                                        database_name='example_db')
    csv_storage = factory.get_storage('csv', base_dir='data/csv_files')

    # ... use the storage instances here ...

    # Release every connection in one call.
    factory.close_all()

    return "存储工厂示例完成"


if __name__ == "__main__":
    storage_factory_example()
复制代码 6. 数据分析模块
数据分析模块负责对爬取的数据进行清洗、转换、分析和挖掘,从而提取有价值的信息和洞察。
6.1 数据清洗与预处置处罚
数据清洗是数据分析的第一步,用于处理缺失值、异常值和格式不一致的数据:
- import pandas as pd
- import numpy as np
- import re
- from datetime import datetime
- import logging
class DataCleaner:
    """Cleans scraped DataFrames: missing values, duplicates, outliers,
    text normalization, dtype conversion and date parsing.

    Every step logs what it changed and returns a new DataFrame
    (inputs are copied, never mutated in place).
    """

    def __init__(self):
        """Create the cleaner and attach its logger."""
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Create (once) and return the class logger at INFO level."""
        logger = logging.getLogger('DataCleaner')
        logger.setLevel(logging.INFO)

        # Avoid duplicate handlers if several cleaners are created.
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def handle_missing_values(self, df, strategy='drop', fill_value=None):
        """Handle missing values.

        Args:
            df: input DataFrame.
            strategy: 'drop' removes rows with any NaN; 'fill' replaces
                NaNs with *fill_value*.
            fill_value: scalar, or dict of per-column values, used when
                strategy == 'fill'.

        Returns:
            The processed DataFrame (unchanged when nothing is missing
            or the strategy is unknown).
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df

        missing_count = df.isnull().sum().sum()
        self.logger.info(f"检测到 {missing_count} 个缺失值")

        if missing_count == 0:
            return df

        if strategy == 'drop':
            # Drop any row containing at least one NaN.
            result = df.dropna()
            self.logger.info(f"删除了 {len(df) - len(result)} 行含有缺失值的数据")
            return result

        elif strategy == 'fill':
            # Both branches call fillna identically; they differ only in
            # the log message (dict = per-column values, scalar = one value).
            if isinstance(fill_value, dict):
                result = df.fillna(fill_value)
                self.logger.info(f"使用指定值填充了缺失值: {fill_value}")
            else:
                result = df.fillna(fill_value)
                self.logger.info(f"使用 {fill_value} 填充了所有缺失值")
            return result

        else:
            self.logger.error(f"未知的缺失值处理策略: {strategy}")
            return df

    def remove_duplicates(self, df, subset=None):
        """Drop duplicate rows (first occurrence kept).

        Args:
            df: input DataFrame.
            subset: columns used to decide duplication; all columns by default.

        Returns:
            DataFrame without duplicates.
        """
        if df.empty:
            return df

        result = df.drop_duplicates(subset=subset)

        removed_count = len(df) - len(result)
        self.logger.info(f"删除了 {removed_count} 行重复数据")

        return result

    def handle_outliers(self, df, columns, method='zscore', threshold=3.0):
        """Detect outliers in numeric columns and replace them with NaN.

        Args:
            df: input DataFrame.
            columns: columns to scan; non-numeric or missing ones are skipped.
            method: 'zscore' (|z| > threshold) or 'iqr' (1.5*IQR fences).
            threshold: z-score cutoff, used only by the 'zscore' method.

        Returns:
            A copy of *df* with detected outliers set to NaN.
        """
        if df.empty:
            return df

        result = df.copy()
        outliers_count = 0

        for col in columns:
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue

            if not pd.api.types.is_numeric_dtype(df[col]):
                self.logger.warning(f"列 {col} 不是数值类型,跳过异常值检测")
                continue

            # Statistics are computed on the non-missing values only.
            values = df[col].dropna()

            if method == 'zscore':
                mean = values.mean()
                std = values.std()
                if std == 0:
                    # A constant column has no outliers by this definition.
                    self.logger.warning(f"列 {col} 的标准差为0,跳过异常值检测")
                    continue

                z_scores = np.abs((values - mean) / std)
                outliers = values[z_scores > threshold].index

            elif method == 'iqr':
                # Tukey fences at 1.5 * IQR.
                q1 = values.quantile(0.25)
                q3 = values.quantile(0.75)
                iqr = q3 - q1
                lower_bound = q1 - 1.5 * iqr
                upper_bound = q3 + 1.5 * iqr
                outliers = values[(values < lower_bound) | (values > upper_bound)].index

            else:
                self.logger.error(f"未知的异常值检测方法: {method}")
                continue

            # Mask outliers as NaN so a later fill/drop step can handle them.
            result.loc[outliers, col] = np.nan
            outliers_count += len(outliers)

        self.logger.info(f"检测并处理了 {outliers_count} 个异常值")
        return result

    def normalize_text(self, df, text_columns):
        """Normalize text columns: strip whitespace, lowercase, drop punctuation.

        Args:
            df: input DataFrame.
            text_columns: names of the columns to normalize.

        Returns:
            A copy of *df* with normalized text.
        """
        if df.empty:
            return df

        result = df.copy()

        for col in text_columns:
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue

            if not pd.api.types.is_string_dtype(df[col]):
                self.logger.warning(f"列 {col} 不是文本类型")
                continue

            # Trim surrounding whitespace and fold case.
            result[col] = df[col].str.strip().str.lower()

            # Remove anything that is not a word character or whitespace;
            # NaN entries are passed through unchanged.
            result[col] = result[col].apply(lambda x: re.sub(r'[^\w\s]', '', str(x)) if pd.notna(x) else x)

            self.logger.info(f"完成列 {col} 的文本标准化处理")

        return result

    def convert_data_types(self, df, type_dict):
        """Cast columns to the requested dtypes.

        Args:
            df: input DataFrame.
            type_dict: mapping of column name -> target dtype.

        Returns:
            A copy of *df*; a failed cast is logged and the column kept as-is.
        """
        if df.empty:
            return df

        result = df.copy()

        for col, dtype in type_dict.items():
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue

            try:
                result[col] = result[col].astype(dtype)
                self.logger.info(f"将列 {col} 的类型转换为 {dtype}")
            except Exception as e:
                self.logger.error(f"转换列 {col} 的类型时出错: {e}")

        return result

    def parse_dates(self, df, date_columns, date_format=None):
        """Parse the given columns into pandas datetimes.

        Args:
            df: input DataFrame.
            date_columns: columns to parse.
            date_format: optional strftime pattern; inferred when omitted.

        Returns:
            A copy of *df*; columns that fail to parse are logged and kept.
        """
        if df.empty:
            return df

        result = df.copy()

        for col in date_columns:
            if col not in df.columns:
                self.logger.warning(f"列 {col} 不存在")
                continue

            try:
                if date_format:
                    result[col] = pd.to_datetime(result[col], format=date_format)
                else:
                    result[col] = pd.to_datetime(result[col])

                self.logger.info(f"将列 {col} 转换为日期时间类型")
            except Exception as e:
                self.logger.error(f"转换列 {col} 为日期时间类型时出错: {e}")

        return result

    def clean_data(self, df, config=None):
        """Run the full cleaning pipeline driven by a config dict.

        Recognized config keys:
            'missing_values': {'strategy', 'fill_value'};
            'remove_duplicates': bool (default True);
            'duplicate_subset': columns for duplicate detection;
            'outliers': {'columns', 'method', 'threshold'};
            'text_columns': list of text columns to normalize;
            'type_conversions': column -> dtype mapping;
            'date_columns': list (infer format) or dict column -> format.

        Returns:
            The cleaned DataFrame.
        """
        if df.empty:
            return df

        if config is None:
            config = {}

        result = df.copy()

        # 1) Missing values.
        if 'missing_values' in config:
            missing_config = config['missing_values']
            result = self.handle_missing_values(
                result,
                strategy=missing_config.get('strategy', 'drop'),
                fill_value=missing_config.get('fill_value')
            )

        # 2) Duplicates (on by default).
        if config.get('remove_duplicates', True):
            subset = config.get('duplicate_subset')
            result = self.remove_duplicates(result, subset=subset)

        # 3) Outliers (masked as NaN, not dropped).
        if 'outliers' in config:
            outlier_config = config['outliers']
            result = self.handle_outliers(
                result,
                columns=outlier_config.get('columns', []),
                method=outlier_config.get('method', 'zscore'),
                threshold=outlier_config.get('threshold', 3.0)
            )

        # 4) Text normalization.
        if 'text_columns' in config:
            result = self.normalize_text(result, config['text_columns'])

        # 5) Dtype conversion.
        if 'type_conversions' in config:
            result = self.convert_data_types(result, config['type_conversions'])

        # 6) Date parsing (list = infer format, dict = explicit format per column).
        if 'date_columns' in config:
            date_config = config['date_columns']
            if isinstance(date_config, list):
                result = self.parse_dates(result, date_config)
            elif isinstance(date_config, dict):
                for col, format_str in date_config.items():
                    result = self.parse_dates(result, [col], date_format=format_str)

        self.logger.info(f"数据清洗完成,从 {len(df)} 行处理为 {len(result)} 行")
        return result
# 使用示例
def data_cleaning_example():
    """Run DataCleaner over a small product table with known flaws."""
    # Deliberately flawed sample: trailing spaces, a None, a duplicate row,
    # an outlier rating and numbers stored as strings.
    raw = pd.DataFrame({
        'product_name': ['iPhone 13 ', 'Samsung Galaxy', 'Xiaomi Mi 11', 'iPhone 13', None],
        'price': [5999, 4999, 3999, 5999, 2999],
        'rating': [4.8, 4.6, 4.5, 4.8, 10.0],
        'reviews_count': ['120', '98', '75', '120', '30'],
        'release_date': ['2021-09-15', '2021-08-20', '2021-03-10', '2021-09-15', '2022-01-01'],
    })

    cleaner = DataCleaner()

    # Pipeline configuration exercising every cleaning step.
    cleaning_plan = {
        'missing_values': {'strategy': 'drop'},
        'remove_duplicates': True,
        'outliers': {
            'columns': ['rating', 'price'],
            'method': 'zscore',
            'threshold': 2.5,
        },
        'text_columns': ['product_name'],
        'type_conversions': {'reviews_count': 'int'},
        'date_columns': {'release_date': '%Y-%m-%d'},
    }

    cleaned = cleaner.clean_data(raw, cleaning_plan)

    print("原始数据:")
    print(raw)
    print("\n清洗后的数据:")
    print(cleaned)

    return cleaned


if __name__ == "__main__":
    data_cleaning_example()
复制代码 6.2 统计分析
统计分析用于计算数据的基本统计量和分布特性:
- import pandas as pd
- import numpy as np
- import scipy.stats as stats
- import logging
class StatisticalAnalyzer:
    """Statistical analysis helpers: descriptive stats, correlation,
    frequency, grouping, time-series resampling and hypothesis tests.

    Methods return empty DataFrames/Series (or an empty dict) on bad
    input instead of raising; failures are logged.
    """

    def __init__(self):
        """Create the analyzer and attach its logger."""
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Create (once) and return the class logger at INFO level."""
        logger = logging.getLogger('StatisticalAnalyzer')
        logger.setLevel(logging.INFO)

        # Avoid duplicate handlers if several analyzers are created.
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def describe_data(self, df, include=None):
        """Return ``df.describe()`` with 10/25/50/75/90th percentiles.

        Args:
            df: input DataFrame.
            include: dtypes to include (forwarded to describe);
                None means numeric columns only.

        Returns:
            DataFrame of summary statistics; empty on failure.
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return pd.DataFrame()

        try:
            stats_df = df.describe(include=include, percentiles=[.1, .25, .5, .75, .9])
            self.logger.info("生成描述性统计完成")
            return stats_df
        except Exception as e:
            self.logger.error(f"生成描述性统计时出错: {e}")
            return pd.DataFrame()

    def correlation_analysis(self, df, method='pearson'):
        """Compute a correlation matrix over the numeric columns.

        Args:
            df: input DataFrame.
            method: 'pearson', 'spearman' or 'kendall'.

        Returns:
            Correlation matrix DataFrame; empty on failure or when no
            numeric columns exist.
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return pd.DataFrame()

        # Only numeric columns participate.
        numeric_df = df.select_dtypes(include=['number'])

        if numeric_df.empty:
            self.logger.warning("没有数值型列可进行相关性分析")
            return pd.DataFrame()

        try:
            corr_matrix = numeric_df.corr(method=method)
            self.logger.info(f"使用 {method} 方法完成相关性分析")
            return corr_matrix
        except Exception as e:
            self.logger.error(f"计算相关系数时出错: {e}")
            return pd.DataFrame()

    def frequency_analysis(self, df, column, normalize=False, bins=None):
        """Frequency (value-count) analysis of one column.

        Args:
            df: input DataFrame.
            column: column to analyze.
            normalize: return proportions instead of counts.
            bins: number of bins; numeric columns are bucketed with
                ``pd.cut`` when given.

        Returns:
            Series of counts/frequencies; empty Series on failure.
        """
        if df.empty or column not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含列 {column}")
            return pd.Series()

        try:
            # Numeric + bins: bucket first, then count per bucket.
            if pd.api.types.is_numeric_dtype(df[column]) and bins is not None:
                freq = pd.cut(df[column], bins=bins).value_counts(normalize=normalize)
                self.logger.info(f"对数值列 {column} 进行分箱频率分析,分箱数量: {bins}")
            else:
                # Categorical (or numeric without bins): count distinct values.
                freq = df[column].value_counts(normalize=normalize)
                self.logger.info(f"对列 {column} 进行频率分析")

            return freq
        except Exception as e:
            self.logger.error(f"进行频率分析时出错: {e}")
            return pd.Series()

    def group_analysis(self, df, group_by, agg_dict):
        """Group-by aggregation.

        Args:
            df: input DataFrame.
            group_by: column name or list of columns to group on.
            agg_dict: mapping column -> aggregation function(s).

        Returns:
            Aggregated DataFrame; empty on failure.
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return pd.DataFrame()

        try:
            result = df.groupby(group_by).agg(agg_dict)
            self.logger.info(f"按 {group_by} 完成分组分析")
            return result
        except Exception as e:
            self.logger.error(f"进行分组分析时出错: {e}")
            return pd.DataFrame()

    def time_series_analysis(self, df, date_column, value_column, freq='D'):
        """Resample *value_column* over *date_column* and take the mean.

        Args:
            df: input DataFrame.
            date_column: column used as the time axis (coerced to datetime
                if needed; the input frame is copied, not mutated).
            value_column: numeric column to resample.
            freq: pandas resample frequency, e.g. 'D', 'W', 'M'.

        Returns:
            DataFrame with the date column and the per-period mean;
            empty on failure.
        """
        if df.empty or date_column not in df.columns or value_column not in df.columns:
            self.logger.warning(f"输入DataFrame为空或缺少必要的列")
            return pd.DataFrame()

        try:
            # Coerce to datetime on a copy so the caller's frame is untouched.
            if not pd.api.types.is_datetime64_dtype(df[date_column]):
                df = df.copy()
                df[date_column] = pd.to_datetime(df[date_column])

            # Resampling requires a DatetimeIndex.
            ts_df = df.set_index(date_column)

            # Mean per resampling period.
            resampled = ts_df[value_column].resample(freq).mean()

            self.logger.info(f"完成时间序列分析,重采样频率: {freq}")
            return resampled.reset_index()
        except Exception as e:
            self.logger.error(f"进行时间序列分析时出错: {e}")
            return pd.DataFrame()

    def hypothesis_testing(self, df, column1, column2=None, test_type='ttest'):
        """Run a hypothesis test on one or two columns.

        Args:
            df: input DataFrame.
            column1: first column.
            column2: second column (required for two-sample tests).
            test_type: 'ttest' (one-sample vs 0, or Welch two-sample when
                *column2* is given) or 'chi2' (contingency of two columns).

        Returns:
            dict with the statistic, p-value and a significance flag at
            alpha = 0.05; empty dict for unsupported inputs,
            {'error': ...} on exception.
        """
        if df.empty or column1 not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含列 {column1}")
            return {}

        try:
            result = {}

            if test_type == 'ttest':
                if column2 and column2 in df.columns:
                    # Welch's two-sample t-test (unequal variances assumed).
                    t_stat, p_value = stats.ttest_ind(
                        df[column1].dropna(),
                        df[column2].dropna(),
                        equal_var=False
                    )
                    result = {
                        'test': 'Independent Samples t-test',
                        't_statistic': t_stat,
                        'p_value': p_value,
                        'significant': p_value < 0.05
                    }
                    self.logger.info(f"完成独立样本t检验: {column1} vs {column2}")
                else:
                    # One-sample t-test against a mean of 0.
                    t_stat, p_value = stats.ttest_1samp(df[column1].dropna(), 0)
                    result = {
                        'test': 'One Sample t-test',
                        't_statistic': t_stat,
                        'p_value': p_value,
                        'significant': p_value < 0.05
                    }
                    self.logger.info(f"完成单样本t检验: {column1}")

            elif test_type == 'chi2' and column2 and column2 in df.columns:
                # Chi-square test of independence on the cross-tabulation.
                contingency_table = pd.crosstab(df[column1], df[column2])
                chi2, p_value, dof, expected = stats.chi2_contingency(contingency_table)
                result = {
                    'test': 'Chi-square Test',
                    'chi2_statistic': chi2,
                    'p_value': p_value,
                    'degrees_of_freedom': dof,
                    'significant': p_value < 0.05
                }
                self.logger.info(f"完成卡方检验: {column1} vs {column2}")

            else:
                self.logger.warning(f"不支持的检验类型: {test_type}")

            return result

        except Exception as e:
            self.logger.error(f"进行假设检验时出错: {e}")
            return {'error': str(e)}
# 使用示例
def statistical_analysis_example():
    """Exercise every StatisticalAnalyzer capability on synthetic data."""
    np.random.seed(42)
    n_samples = 200

    # Synthetic features and targets. NOTE: the RNG call order below is
    # deliberate — changing it changes the generated data.
    X = np.random.randn(n_samples, 3)
    y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0
    y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5

    data = pd.DataFrame(X, columns=['feature_1', 'feature_2', 'feature_3'])
    data['target_class'] = y_class.astype(int)
    data['target_reg'] = y_reg
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)

    analyzer = StatisticalAnalyzer()

    # Descriptive statistics.
    desc_stats = analyzer.describe_data(data)
    print("描述性统计:")
    print(desc_stats)

    # Correlation matrix over numeric columns.
    corr_matrix = analyzer.correlation_analysis(data)
    print("\n相关性矩阵:")
    print(corr_matrix)

    # Frequency analysis of a feature column.
    category_freq = analyzer.frequency_analysis(data, 'feature_1', normalize=True)
    print("\n特征1频率分析:")
    print(category_freq)

    # Group-by aggregation.
    group_result = analyzer.group_analysis(
        data,
        'feature_1',
        {'target_class': ['mean', 'sum'], 'feature_2': 'mean', 'feature_3': 'mean'},
    )
    print("\n分组分析结果:")
    print(group_result)

    # Weekly resampling of the regression target.
    ts_result = analyzer.time_series_analysis(data, 'feature_1', 'target_reg', freq='W')
    print("\n时间序列分析结果(周均值):")
    print(ts_result.head())

    # One-sample t-test on feature_1.
    test_result = analyzer.hypothesis_testing(data, 'feature_1', test_type='ttest')
    print("\n假设检验结果:")
    print(test_result)

    return {
        'desc_stats': desc_stats,
        'corr_matrix': corr_matrix,
        'category_freq': category_freq,
        'group_result': group_result,
        'ts_result': ts_result,
        'test_result': test_result,
    }


if __name__ == "__main__":
    statistical_analysis_example()
复制代码 6.3 数据挖掘
数据挖掘是从大量数据中发现模式和关系的过程,包括聚类分析、分类和回归模型等:
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- import seaborn as sns
- from sklearn.cluster import KMeans, DBSCAN
- from sklearn.preprocessing import StandardScaler, MinMaxScaler
- from sklearn.decomposition import PCA
- from sklearn.model_selection import train_test_split, cross_val_score
- from sklearn.linear_model import LinearRegression, LogisticRegression
- from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
- from sklearn.metrics import (
- accuracy_score, precision_score, recall_score, f1_score,
- mean_squared_error, r2_score, silhouette_score
- )
- import logging
class DataMiner:
    """Data-mining helpers built on scikit-learn: preprocessing, PCA,
    clustering, and classification/regression model training.

    Trained models are kept in ``self.models`` keyed by a generated id
    so they can be reused for prediction and feature-importance queries.
    """

    def __init__(self):
        """Create the miner with an empty model registry."""
        self.logger = self._setup_logger()
        self.models = {}  # model_id -> {'model', 'feature_cols', 'target_col', 'metrics'}

    def _setup_logger(self):
        """Create (once) and return the class logger at INFO level."""
        logger = logging.getLogger('DataMiner')
        logger.setLevel(logging.INFO)

        # Avoid duplicate handlers if several miners are created.
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def preprocess_data(self, df, scale_method='standard', categorical_cols=None):
        """Drop NaN rows, one-hot encode categoricals, scale numerics.

        Args:
            df: input DataFrame.
            scale_method: 'standard' (z-score) or 'minmax' ([0, 1] range).
            categorical_cols: columns to one-hot encode via get_dummies.

        Returns:
            (processed DataFrame, fitted scaler) — scaler is None when no
            scaling was applied.
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df, None

        # Rows with any NaN are discarded before modeling.
        df_clean = df.dropna()
        if len(df_clean) < len(df):
            self.logger.info(f"删除了 {len(df) - len(df_clean)} 行含有缺失值的数据")

        # One-hot encode the requested categorical columns.
        if categorical_cols:
            df_encoded = pd.get_dummies(df_clean, columns=categorical_cols)
            self.logger.info(f"对 {len(categorical_cols)} 个分类变量进行了独热编码")
        else:
            df_encoded = df_clean

        numeric_cols = df_encoded.select_dtypes(include=['number']).columns

        if scale_method == 'standard':
            scaler = StandardScaler()
            self.logger.info("使用StandardScaler进行标准化")
        elif scale_method == 'minmax':
            scaler = MinMaxScaler()
            self.logger.info("使用MinMaxScaler进行归一化")
        else:
            # Unknown method: return the encoded frame unscaled.
            self.logger.warning(f"未知的缩放方法: {scale_method},不进行缩放")
            return df_encoded, None

        if len(numeric_cols) > 0:
            df_encoded[numeric_cols] = scaler.fit_transform(df_encoded[numeric_cols])
            self.logger.info(f"对 {len(numeric_cols)} 个数值变量进行了缩放")

        return df_encoded, scaler

    def reduce_dimensions(self, df, n_components=2, method='pca'):
        """Project the numeric columns down to *n_components* dimensions.

        Args:
            df: input DataFrame.
            n_components: target dimensionality.
            method: only 'pca' is supported.

        Returns:
            (reduced DataFrame with columns PC1..PCn, fitted reducer) —
            the original df and None on failure.
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df, None

        # PCA operates on numeric data only.
        numeric_df = df.select_dtypes(include=['number'])

        if numeric_df.empty:
            self.logger.warning("没有数值型列可进行降维")
            return df, None

        try:
            if method == 'pca':
                reducer = PCA(n_components=n_components)
                reduced_data = reducer.fit_transform(numeric_df)

                # Preserve the original index so results align with df.
                result_df = pd.DataFrame(
                    reduced_data,
                    columns=[f'PC{i+1}' for i in range(n_components)],
                    index=df.index
                )

                # Report how much variance the kept components explain.
                explained_variance = reducer.explained_variance_ratio_.sum()
                self.logger.info(f"PCA降维完成,保留了 {n_components} 个主成分,解释了 {explained_variance:.2%} 的方差")

                return result_df, reducer
            else:
                self.logger.warning(f"不支持的降维方法: {method}")
                return df, None
        except Exception as e:
            self.logger.error(f"降维过程中出错: {e}")
            return df, None

    def cluster_data(self, df, method='kmeans', n_clusters=3, eps=0.5, min_samples=5):
        """Cluster the numeric columns and attach a 'cluster' label column.

        Args:
            df: input DataFrame.
            method: 'kmeans' or 'dbscan'.
            n_clusters: number of clusters (KMeans only).
            eps: neighborhood radius (DBSCAN only).
            min_samples: minimum neighborhood size (DBSCAN only).

        Returns:
            (copy of df with a 'cluster' column, fitted clusterer) —
            DBSCAN labels noise points as -1; (df, None) on failure.
        """
        if df.empty:
            self.logger.warning("输入DataFrame为空")
            return df, None

        # Clustering uses numeric columns only.
        numeric_df = df.select_dtypes(include=['number'])

        if numeric_df.empty:
            self.logger.warning("没有数值型列可进行聚类")
            return df, None

        try:
            result_df = df.copy()

            if method == 'kmeans':
                clusterer = KMeans(n_clusters=n_clusters, random_state=42)
                labels = clusterer.fit_predict(numeric_df)

                # Silhouette needs at least 2 clusters and more rows than clusters.
                if n_clusters > 1 and len(numeric_df) > n_clusters:
                    silhouette = silhouette_score(numeric_df, labels)
                    self.logger.info(f"K-means聚类完成,轮廓系数: {silhouette:.4f}")
                else:
                    self.logger.info("K-means聚类完成,但无法计算轮廓系数(簇数过少或数据量不足)")

            elif method == 'dbscan':
                clusterer = DBSCAN(eps=eps, min_samples=min_samples)
                labels = clusterer.fit_predict(numeric_df)

                # Label -1 marks DBSCAN noise; exclude it from the cluster count.
                n_clusters_found = len(set(labels)) - (1 if -1 in labels else 0)
                n_noise = list(labels).count(-1)
                self.logger.info(f"DBSCAN聚类完成,发现 {n_clusters_found} 个簇,{n_noise} 个噪声点")

            else:
                self.logger.warning(f"不支持的聚类方法: {method}")
                return df, None

            # Attach the labels to the returned copy.
            result_df['cluster'] = labels

            return result_df, clusterer

        except Exception as e:
            self.logger.error(f"聚类过程中出错: {e}")
            return df, None

    def train_classifier(self, df, target_col, feature_cols=None, model_type='random_forest', test_size=0.2):
        """Train and register a classification model.

        Args:
            df: input DataFrame.
            target_col: label column.
            feature_cols: feature columns; defaults to all numeric columns
                except the target.
            model_type: 'logistic' or 'random_forest'.
            test_size: held-out fraction for the test split.

        Returns:
            (metrics dict, fitted model); ({}, None) on failure. The model
            is also stored under id "{model_type}_classifier_{target_col}".
        """
        if df.empty or target_col not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含目标列 {target_col}")
            return {}, None

        try:
            # Default features: every numeric column except the target.
            if feature_cols is None:
                feature_cols = df.select_dtypes(include=['number']).columns.tolist()
                if target_col in feature_cols:
                    feature_cols.remove(target_col)

            if not feature_cols:
                self.logger.warning("没有可用的特征列")
                return {}, None

            X = df[feature_cols]
            y = df[target_col]

            # Fixed random_state keeps the split reproducible.
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)

            if model_type == 'logistic':
                model = LogisticRegression(max_iter=1000, random_state=42)
                model_name = 'Logistic Regression'
            elif model_type == 'random_forest':
                model = RandomForestClassifier(n_estimators=100, random_state=42)
                model_name = 'Random Forest'
            else:
                self.logger.warning(f"不支持的分类模型类型: {model_type}")
                return {}, None

            model.fit(X_train, y_train)

            # Evaluate on the held-out split.
            y_pred = model.predict(X_test)

            # Weighted averages so multi-class targets are handled too.
            metrics = {
                'accuracy': accuracy_score(y_test, y_pred),
                'precision': precision_score(y_test, y_pred, average='weighted'),
                'recall': recall_score(y_test, y_pred, average='weighted'),
                'f1': f1_score(y_test, y_pred, average='weighted')
            }

            # 5-fold CV on the full data set for a stability estimate.
            cv_scores = cross_val_score(model, X, y, cv=5)
            metrics['cv_accuracy_mean'] = cv_scores.mean()
            metrics['cv_accuracy_std'] = cv_scores.std()

            self.logger.info(f"{model_name}分类模型训练完成,准确率: {metrics['accuracy']:.4f}")

            # Register the model for later predict()/importance queries.
            model_id = f"{model_type}_classifier_{target_col}"
            self.models[model_id] = {
                'model': model,
                'feature_cols': feature_cols,
                'target_col': target_col,
                'metrics': metrics
            }

            return metrics, model

        except Exception as e:
            self.logger.error(f"训练分类模型时出错: {e}")
            return {}, None

    def train_regressor(self, df, target_col, feature_cols=None, model_type='linear', test_size=0.2):
        """Train and register a regression model.

        Args:
            df: input DataFrame.
            target_col: target column.
            feature_cols: feature columns; defaults to all numeric columns
                except the target.
            model_type: 'linear' or 'random_forest'.
            test_size: held-out fraction for the test split.

        Returns:
            (metrics dict, fitted model); ({}, None) on failure. The model
            is also stored under id "{model_type}_regressor_{target_col}".
        """
        if df.empty or target_col not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含目标列 {target_col}")
            return {}, None

        try:
            # Default features: every numeric column except the target.
            if feature_cols is None:
                feature_cols = df.select_dtypes(include=['number']).columns.tolist()
                if target_col in feature_cols:
                    feature_cols.remove(target_col)

            if not feature_cols:
                self.logger.warning("没有可用的特征列")
                return {}, None

            X = df[feature_cols]
            y = df[target_col]

            # Fixed random_state keeps the split reproducible.
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)

            if model_type == 'linear':
                model = LinearRegression()
                model_name = 'Linear Regression'
            elif model_type == 'random_forest':
                model = RandomForestRegressor(n_estimators=100, random_state=42)
                model_name = 'Random Forest'
            else:
                self.logger.warning(f"不支持的回归模型类型: {model_type}")
                return {}, None

            model.fit(X_train, y_train)

            # Evaluate on the held-out split.
            y_pred = model.predict(X_test)

            metrics = {
                'mse': mean_squared_error(y_test, y_pred),
                'rmse': np.sqrt(mean_squared_error(y_test, y_pred)),
                'r2': r2_score(y_test, y_pred)
            }

            # 5-fold CV (R² scoring) on the full data set.
            cv_scores = cross_val_score(model, X, y, cv=5, scoring='r2')
            metrics['cv_r2_mean'] = cv_scores.mean()
            metrics['cv_r2_std'] = cv_scores.std()

            self.logger.info(f"{model_name}回归模型训练完成,R²: {metrics['r2']:.4f}")

            # Register the model for later predict()/importance queries.
            model_id = f"{model_type}_regressor_{target_col}"
            self.models[model_id] = {
                'model': model,
                'feature_cols': feature_cols,
                'target_col': target_col,
                'metrics': metrics
            }

            return metrics, model

        except Exception as e:
            self.logger.error(f"训练回归模型时出错: {e}")
            return {}, None

    def predict(self, model_id, new_data):
        """Predict with a previously trained and registered model.

        Args:
            model_id: id assigned at training time.
            new_data: DataFrame containing all the model's feature columns.
                NOTE(review): the caller must apply the same scaling that
                was used at training time — not re-applied here.

        Returns:
            Array of predictions, or None if the model or columns are missing.
        """
        if model_id not in self.models:
            self.logger.warning(f"模型ID {model_id} 不存在")
            return None

        model_info = self.models[model_id]
        model = model_info['model']
        feature_cols = model_info['feature_cols']

        # Refuse to predict when any training feature is absent.
        missing_cols = [col for col in feature_cols if col not in new_data.columns]
        if missing_cols:
            self.logger.warning(f"新数据缺少特征列: {missing_cols}")
            return None

        try:
            # Select features in the training order.
            X_new = new_data[feature_cols]

            predictions = model.predict(X_new)

            self.logger.info(f"使用模型 {model_id} 完成预测,预测样本数: {len(predictions)}")

            return predictions

        except Exception as e:
            self.logger.error(f"预测过程中出错: {e}")
            return None

    def get_feature_importance(self, model_id):
        """Return a feature-importance table for a registered model.

        Tree models use ``feature_importances_``; linear models fall back
        to the absolute coefficients (averaged over classes if 2-D).

        Args:
            model_id: id assigned at training time.

        Returns:
            DataFrame with 'feature' and 'importance' columns, sorted
            descending; empty DataFrame when unavailable.
        """
        if model_id not in self.models:
            self.logger.warning(f"模型ID {model_id} 不存在")
            return pd.DataFrame()

        model_info = self.models[model_id]
        model = model_info['model']
        feature_cols = model_info['feature_cols']

        # NOTE(review): this warning is logged even when the coef_ fallback
        # below succeeds for linear models — slightly misleading log output.
        if not hasattr(model, 'feature_importances_'):
            self.logger.warning(f"模型 {model_id} 不支持特征重要性分析")

            # Linear models: use |coefficients| as a proxy for importance.
            if hasattr(model, 'coef_'):
                importances = np.abs(model.coef_)
                if importances.ndim > 1:
                    importances = importances.mean(axis=0)
            else:
                return pd.DataFrame()
        else:
            importances = model.feature_importances_

        importance_df = pd.DataFrame({
            'feature': feature_cols,
            'importance': importances
        })

        # Most important features first.
        importance_df = importance_df.sort_values('importance', ascending=False)

        self.logger.info(f"获取模型 {model_id} 的特征重要性")

        return importance_df
# Usage example
def data_mining_example():
    """Demonstrate the full DataMiner workflow on synthetic data.

    Builds a 200-row dataset with 5 numeric features plus classification and
    regression targets, then runs preprocessing, dimensionality reduction,
    k-means clustering, classifier/regressor training and feature-importance
    extraction.

    Returns:
        dict with the intermediate DataFrames, model metrics and the
        feature-importance table.
    """
    # Create sample data (fixed seed for reproducibility)
    np.random.seed(42)
    n_samples = 200

    # Generate features
    X = np.random.randn(n_samples, 5)  # 5 features

    # Classification target: linear boundary plus small noise
    y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0

    # Regression target: non-linear with an interaction term
    y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5

    # Assemble the DataFrame
    data = pd.DataFrame(
        X,
        columns=[f'feature_{i+1}' for i in range(5)]
    )
    data['target_class'] = y_class.astype(int)
    data['target_reg'] = y_reg

    # Add some derived columns
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    # NOTE(review): np.random.poisson(10, ...) can yield 0, producing inf here — confirm acceptable
    data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)

    # Create the data miner
    miner = DataMiner()

    # Preprocessing (scaling + categorical encoding)
    print("数据预处理...")
    data_processed, scaler = miner.preprocess_data(
        data,
        scale_method='standard',
        categorical_cols=['month', 'day_of_week']
    )

    # Dimensionality reduction (targets excluded from the projection)
    print("\n降维分析...")
    data_reduced, pca = miner.reduce_dimensions(
        data_processed.drop(['target_class', 'target_reg'], axis=1),
        n_components=2
    )

    # Clustering on the reduced data
    print("\n聚类分析...")
    data_clustered, kmeans = miner.cluster_data(
        data_reduced,
        method='kmeans',
        n_clusters=3
    )

    # Classification model
    print("\n训练分类模型...")
    class_metrics, classifier = miner.train_classifier(
        data_processed,
        target_col='target_class',
        model_type='random_forest'
    )
    print(f"分类模型评估指标: {class_metrics}")

    # Regression model
    print("\n训练回归模型...")
    reg_metrics, regressor = miner.train_regressor(
        data_processed,
        target_col='target_reg',
        model_type='random_forest'
    )
    print(f"回归模型评估指标: {reg_metrics}")

    # Feature importance (model id format: <algorithm>_regressor_<target>)
    print("\n特征重要性分析...")
    importance = miner.get_feature_importance('random_forest_regressor_target_reg')
    print(importance)

    return {
        'data_processed': data_processed,
        'data_reduced': data_reduced,
        'data_clustered': data_clustered,
        'class_metrics': class_metrics,
        'reg_metrics': reg_metrics,
        'feature_importance': importance
    }
if __name__ == "__main__":
    data_mining_example()
复制代码 6.4 特征工程
特征工程是数据分析和机器学习中至关重要的一步,它可以显著提高模型性能:
- import pandas as pd
- import numpy as np
- from sklearn.preprocessing import PolynomialFeatures
- from sklearn.feature_selection import SelectKBest, f_regression, mutual_info_regression
- import logging
class FeatureEngineer:
    """Feature engineering helper for pandas DataFrames.

    Provides polynomial expansion, pairwise interaction terms, binning
    (discretization) and univariate feature selection. All methods return
    the input DataFrame unchanged on error (errors are logged, not raised).
    """

    def __init__(self):
        """Initialize the feature engineer (only sets up logging)."""
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Create a stream logger, adding the handler only once per process."""
        logger = logging.getLogger('FeatureEngineer')
        logger.setLevel(logging.INFO)

        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def create_polynomial_features(self, df, feature_cols, degree=2, include_bias=False):
        """Expand the given columns into polynomial features.

        Args:
            df: input DataFrame.
            feature_cols: list of column names to expand.
            degree: polynomial degree.
            include_bias: whether to include the constant (bias) column.

        Returns:
            DataFrame with the original non-feature columns plus the
            polynomial expansion; input df unchanged on error.
        """
        if df.empty or not feature_cols:
            self.logger.warning("输入DataFrame为空或未指定特征列")
            return df

        try:
            # Raw feature matrix
            X = df[feature_cols].values

            # Fit/transform the polynomial expansion
            poly = PolynomialFeatures(degree=degree, include_bias=include_bias)
            poly_features = poly.fit_transform(X)

            # Human-readable names for the generated terms
            feature_names = poly.get_feature_names_out(feature_cols)

            # Wrap the expansion in a DataFrame aligned on the original index
            poly_df = pd.DataFrame(poly_features, columns=feature_names, index=df.index)

            # Replace the source columns with their expanded set
            result_df = pd.concat([df.drop(feature_cols, axis=1), poly_df], axis=1)

            self.logger.info(f"创建了 {poly_features.shape[1]} 个多项式特征,次数: {degree}")

            return result_df

        except Exception as e:
            self.logger.error(f"创建多项式特征时出错: {e}")
            return df

    def create_interaction_features(self, df, feature_cols):
        """Add pairwise product (interaction) features.

        Args:
            df: input DataFrame.
            feature_cols: list of column names to combine pairwise.

        Returns:
            Copy of df with one `colA_x_colB` product column per unordered
            pair; input df unchanged on error.
        """
        if df.empty or len(feature_cols) < 2:
            self.logger.warning("输入DataFrame为空或特征列不足")
            return df

        try:
            result_df = df.copy()
            interaction_count = 0

            # Product of every unordered column pair
            for i in range(len(feature_cols)):
                for j in range(i+1, len(feature_cols)):
                    col1 = feature_cols[i]
                    col2 = feature_cols[j]

                    # Name encodes both source columns
                    interaction_name = f"{col1}_x_{col2}"
                    result_df[interaction_name] = df[col1] * df[col2]
                    interaction_count += 1

            self.logger.info(f"创建了 {interaction_count} 个交互特征")

            return result_df

        except Exception as e:
            self.logger.error(f"创建交互特征时出错: {e}")
            return df

    def create_binning_features(self, df, feature_col, bins=5, strategy='uniform'):
        """Add binned (discretized) versions of a numeric column.

        Args:
            df: input DataFrame.
            feature_col: name of the column to bin.
            bins: number of bins, or an explicit sequence of bin edges.
            strategy: 'uniform' (equal width) or 'quantile' (equal
                frequency); ignored when explicit edges are passed.

        Returns:
            Copy of df with an integer `<col>_bin` label column plus one-hot
            dummy columns per bin; input df unchanged on error.
        """
        if df.empty or feature_col not in df.columns:
            self.logger.warning(f"输入DataFrame为空或不包含列 {feature_col}")
            return df

        try:
            result_df = df.copy()

            # Work out the bin edges
            if isinstance(bins, int):
                if strategy == 'uniform':
                    # Equal-width bins over the observed range
                    bin_edges = np.linspace(
                        df[feature_col].min(),
                        df[feature_col].max(),
                        bins + 1
                    )
                elif strategy == 'quantile':
                    # Equal-frequency bins
                    # NOTE(review): heavily skewed data can yield duplicate
                    # percentile edges, which makes pd.cut raise; that error
                    # is swallowed by the except below — confirm intended.
                    bin_edges = np.percentile(
                        df[feature_col],
                        np.linspace(0, 100, bins + 1)
                    )
                else:
                    self.logger.warning(f"不支持的分箱策略: {strategy}")
                    return df
            else:
                # Caller supplied explicit edges
                bin_edges = bins

            # Integer (0-based) bin labels
            binned_feature = pd.cut(
                df[feature_col],
                bins=bin_edges,
                labels=False,
                include_lowest=True
            )

            # Add the label column
            result_df[f"{feature_col}_bin"] = binned_feature

            # One-hot encode the bin labels
            bin_dummies = pd.get_dummies(
                binned_feature,
                prefix=f"{feature_col}_bin",
                prefix_sep="_"
            )

            # Merge results
            result_df = pd.concat([result_df, bin_dummies], axis=1)

            self.logger.info(f"对特征 {feature_col} 创建了 {len(bin_edges)-1} 个分箱特征")

            return result_df

        except Exception as e:
            self.logger.error(f"创建分箱特征时出错: {e}")
            return df

    def select_best_features(self, df, feature_cols, target_col, k=5, method='f_regression'):
        """Keep the k best features by a univariate regression score.

        Args:
            df: input DataFrame.
            feature_cols: candidate feature column names.
            target_col: target variable column name.
            k: number of features to keep.
            method: 'f_regression' or 'mutual_info'.

        Returns:
            Tuple of (DataFrame with non-selected candidates dropped,
            dict of feature -> score); (df, {}) unchanged on error.
        """
        if df.empty or not feature_cols or target_col not in df.columns:
            self.logger.warning("输入DataFrame为空或未指定特征列或目标列")
            return df, {}

        try:
            # Split features and target
            X = df[feature_cols]
            y = df[target_col]

            # Pick the scoring function
            if method == 'f_regression':
                selector = SelectKBest(score_func=f_regression, k=k)
                method_name = "F回归"
            elif method == 'mutual_info':
                selector = SelectKBest(score_func=mutual_info_regression, k=k)
                method_name = "互信息"
            else:
                self.logger.warning(f"不支持的特征选择方法: {method}")
                return df, {}

            # Fit the selector
            # NOTE(review): k > len(feature_cols) raises inside sklearn and
            # falls into the except below — confirm intended.
            selector.fit(X, y)

            # Indices/names of the retained features
            selected_indices = selector.get_support(indices=True)
            selected_features = [feature_cols[i] for i in selected_indices]

            # Per-feature scores (all candidates, not just selected)
            feature_scores = dict(zip(feature_cols, selector.scores_))

            # Drop candidates that were not selected; non-candidate columns stay
            result_df = df.copy()
            dropped_features = [col for col in feature_cols if col not in selected_features]
            if dropped_features:
                result_df = result_df.drop(dropped_features, axis=1)

            self.logger.info(f"使用 {method_name} 方法选择了 {len(selected_features)} 个最佳特征")

            return result_df, feature_scores

        except Exception as e:
            self.logger.error(f"选择最佳特征时出错: {e}")
            return df, {}
# Usage example
def feature_engineering_example():
    """Demonstrate FeatureEngineer on a small synthetic regression dataset.

    Returns:
        dict with the original data, each engineered variant, and the
        univariate feature scores.
    """
    # Create sample data (fixed seed for reproducibility)
    np.random.seed(42)
    n_samples = 200

    # Generate features
    X = np.random.randn(n_samples, 3)  # 3 features

    # Regression target: non-linear with an interaction term
    y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5

    # Assemble the DataFrame
    data = pd.DataFrame(
        X,
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target'] = y

    # Create the feature engineer
    engineer = FeatureEngineer()

    # Polynomial features
    print("创建多项式特征...")
    poly_data = engineer.create_polynomial_features(
        data,
        ['feature_1', 'feature_2', 'feature_3'],
        degree=2
    )
    print(f"多项式特征后的列: {poly_data.columns.tolist()}")

    # Interaction features
    print("\n创建交互特征...")
    interaction_data = engineer.create_interaction_features(
        data,
        ['feature_1', 'feature_2', 'feature_3']
    )
    print(f"交互特征后的列: {interaction_data.columns.tolist()}")

    # Binning features
    print("\n创建分箱特征...")
    binned_data = engineer.create_binning_features(
        data,
        'feature_1',
        bins=5,
        strategy='quantile'
    )
    print(f"分箱特征后的列: {binned_data.columns.tolist()}")

    # Feature selection
    print("\n特征选择...")
    # First build a richer feature set to select from
    combined_data = engineer.create_polynomial_features(
        data,
        ['feature_1', 'feature_2', 'feature_3'],
        degree=2
    )

    # Keep the 5 best features
    selected_data, feature_scores = engineer.select_best_features(
        combined_data,
        [col for col in combined_data.columns if col != 'target'],
        'target',
        k=5,
        method='f_regression'
    )

    print("特征得分:")
    for feature, score in sorted(feature_scores.items(), key=lambda x: x[1], reverse=True):
        print(f"{feature}: {score:.4f}")

    print(f"\n选择的特征: {[col for col in selected_data.columns if col != 'target']}")

    return {
        'original_data': data,
        'poly_data': poly_data,
        'interaction_data': interaction_data,
        'binned_data': binned_data,
        'selected_data': selected_data,
        'feature_scores': feature_scores
    }
if __name__ == "__main__":
    feature_engineering_example()
复制代码 6.5 数据分析模块集成
以下是如何将数据清洗、统计分析、数据挖掘和特征工程组件集成到一个完整的数据分析流程中:
def complete_data_analysis_pipeline(data, config=None):
    """Run the full analysis pipeline.

    Stages, in order: cleaning (DataCleaner) -> statistics
    (StatisticalAnalyzer) -> feature engineering (FeatureEngineer) ->
    mining (DataMiner: preprocess / optional reduce / cluster / model).
    Stages whose config section is missing are skipped.

    Args:
        data: input DataFrame.
        config: optional dict with 'cleaning', 'feature_engineering' and
            'mining' sections.

    Returns:
        dict of every intermediate artefact keyed by stage name.
    """
    if config is None:
        config = {}

    results = {'original_data': data}

    # 1. Data cleaning
    print("1. 执行数据清洗...")
    cleaner = DataCleaner()
    clean_config = config.get('cleaning', {})
    cleaned_data = cleaner.clean_data(data, clean_config)
    results['cleaned_data'] = cleaned_data

    # 2. Statistical analysis
    print("\n2. 执行统计分析...")
    analyzer = StatisticalAnalyzer()

    # Descriptive statistics
    desc_stats = analyzer.describe_data(cleaned_data)
    results['descriptive_stats'] = desc_stats

    # Correlation analysis
    corr_matrix = analyzer.correlation_analysis(cleaned_data)
    results['correlation_matrix'] = corr_matrix

    # 3. Feature engineering
    print("\n3. 执行特征工程...")
    engineer = FeatureEngineer()
    feature_config = config.get('feature_engineering', {})

    engineered_data = cleaned_data.copy()

    # Polynomial features (optional)
    if 'polynomial' in feature_config:
        poly_config = feature_config['polynomial']
        engineered_data = engineer.create_polynomial_features(
            engineered_data,
            poly_config.get('features', []),
            degree=poly_config.get('degree', 2)
        )

    # Interaction features (optional)
    if 'interaction' in feature_config:
        interaction_config = feature_config['interaction']
        engineered_data = engineer.create_interaction_features(
            engineered_data,
            interaction_config.get('features', [])
        )

    # Binning features (optional, one entry per column)
    if 'binning' in feature_config:
        for bin_config in feature_config['binning']:
            engineered_data = engineer.create_binning_features(
                engineered_data,
                bin_config.get('feature'),
                bins=bin_config.get('bins', 5),
                strategy=bin_config.get('strategy', 'uniform')
            )

    results['engineered_data'] = engineered_data

    # 4. Data mining
    print("\n4. 执行数据挖掘...")
    miner = DataMiner()
    mining_config = config.get('mining', {})

    # Preprocessing (scaling + categorical encoding)
    processed_data, scaler = miner.preprocess_data(
        engineered_data,
        scale_method=mining_config.get('scale_method', 'standard'),
        categorical_cols=mining_config.get('categorical_cols', [])
    )
    results['processed_data'] = processed_data

    # Dimensionality reduction (optional)
    if 'dimensionality_reduction' in mining_config:
        dr_config = mining_config['dimensionality_reduction']
        reduced_data, reducer = miner.reduce_dimensions(
            processed_data,
            n_components=dr_config.get('n_components', 2),
            method=dr_config.get('method', 'pca')
        )
        results['reduced_data'] = reduced_data

    # Clustering (optional; prefers the reduced data when available)
    if 'clustering' in mining_config:
        cluster_config = mining_config['clustering']
        data_to_cluster = results.get('reduced_data', processed_data)
        clustered_data, clusterer = miner.cluster_data(
            data_to_cluster,
            method=cluster_config.get('method', 'kmeans'),
            n_clusters=cluster_config.get('n_clusters', 3)
        )
        results['clustered_data'] = clustered_data

    # Model training (optional, one entry per model)
    if 'models' in mining_config:
        models_results = {}

        for model_config in mining_config['models']:
            model_type = model_config.get('type')
            target = model_config.get('target')
            features = model_config.get('features')

            if model_type == 'classifier':
                metrics, model = miner.train_classifier(
                    processed_data,
                    target_col=target,
                    feature_cols=features,
                    model_type=model_config.get('algorithm', 'random_forest')
                )
                models_results[f'classifier_{target}'] = {
                    'metrics': metrics,
                    'model_id': f"{model_config.get('algorithm', 'random_forest')}_classifier_{target}"
                }

            elif model_type == 'regressor':
                metrics, model = miner.train_regressor(
                    processed_data,
                    target_col=target,
                    feature_cols=features,
                    model_type=model_config.get('algorithm', 'random_forest')
                )
                models_results[f'regressor_{target}'] = {
                    'metrics': metrics,
                    'model_id': f"{model_config.get('algorithm', 'random_forest')}_regressor_{target}"
                }

        results['models'] = models_results

    print("\n数据分析流程完成!")
    return results
# Usage example
def data_analysis_example():
    """Run complete_data_analysis_pipeline on a 500-row synthetic dataset.

    Returns:
        The pipeline's results dict.
    """
    # Create sample data (fixed seed for reproducibility)
    np.random.seed(42)
    n_samples = 500

    # Generate features
    X = np.random.randn(n_samples, 4)  # 4 features

    # Classification target: linear boundary plus small noise
    y_class = (X[:, 0] + X[:, 1] * 0.5 + np.random.randn(n_samples) * 0.1) > 0

    # Regression target: non-linear with an interaction term
    y_reg = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5

    # Assemble the DataFrame
    data = pd.DataFrame(
        X,
        columns=[f'feature_{i+1}' for i in range(4)]
    )
    data['category'] = np.random.choice(['A', 'B', 'C', 'D'], n_samples)
    data['target_class'] = y_class.astype(int)
    data['target_reg'] = y_reg

    # Add some derived columns
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    # NOTE(review): np.random.poisson(10, ...) can yield 0, producing inf here — confirm acceptable
    data['sales_per_customer'] = data['target_reg'] / np.random.poisson(10, n_samples)

    # Pipeline configuration (one section per stage)
    config = {
        'cleaning': {
            'missing_values': {'strategy': 'drop'},
            'remove_duplicates': True,
            'outliers': {
                'columns': ['feature_1', 'feature_2', 'feature_3', 'feature_4'],
                'method': 'zscore',
                'threshold': 3.0
            },
            'text_columns': [],
            'type_conversions': {},
            'date_columns': {}
        },
        'feature_engineering': {
            'polynomial': {
                'features': ['feature_1', 'feature_2'],
                'degree': 2
            },
            'interaction': {
                'features': ['feature_1', 'feature_2', 'feature_3']
            },
            'binning': [
                {
                    'feature': 'feature_4',
                    'bins': 5,
                    'strategy': 'quantile'
                }
            ]
        },
        'mining': {
            'scale_method': 'standard',
            'categorical_cols': ['category'],
            'dimensionality_reduction': {
                'n_components': 2,
                'method': 'pca'
            },
            'clustering': {
                'method': 'kmeans',
                'n_clusters': 3
            },
            'models': [
                {
                    'type': 'classifier',
                    'target': 'target_class',
                    'algorithm': 'random_forest'
                },
                {
                    'type': 'regressor',
                    'target': 'target_reg',
                    'algorithm': 'random_forest'
                }
            ]
        }
    }

    # Run the pipeline
    results = complete_data_analysis_pipeline(data, config)

    # Print a subset of the results
    print("\n描述性统计:")
    print(results['descriptive_stats'])

    print("\n模型性能:")
    for model_name, model_info in results['models'].items():
        print(f"{model_name}: {model_info['metrics']}")

    return results
if __name__ == "__main__":
    data_analysis_example()
复制代码 7. 数据可视化模块
数据可视化是将数据转化为图形表示的过程,通过视觉元素如图表、图形和地图,使复杂数据更容易理解和分析。
7.1 静态可视化
静态可视化是指生成不可交互的图表,主要使用Matplotlib和Seaborn库:
- import matplotlib.pyplot as plt
- import seaborn as sns
- import pandas as pd
- import numpy as np
- import matplotlib.ticker as ticker
- from matplotlib.colors import LinearSegmentedColormap
- import logging
- from pathlib import Path
class StaticVisualizer:
    """Static (non-interactive) chart generator built on Matplotlib/Seaborn.

    Each plot_* method draws one chart type, optionally saves it under
    ``output_dir`` and returns the Matplotlib figure (or None on error;
    errors are logged, not raised).
    """

    def __init__(self, output_dir='visualizations'):
        """Initialize the visualizer.

        Args:
            output_dir: directory where saved figures are written.
        """
        self.output_dir = output_dir
        self.logger = self._setup_logger()
        self._setup_style()
        self._ensure_output_dir()

    def _setup_logger(self):
        """Create a stream logger, adding the handler only once per process."""
        logger = logging.getLogger('StaticVisualizer')
        logger.setLevel(logging.INFO)

        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def _setup_style(self):
        """Apply global Seaborn/Matplotlib styling.

        Note: mutates process-wide rcParams, so it affects every figure
        created afterwards, not just this visualizer's.
        """
        # Seaborn base style
        sns.set(style="whitegrid")

        # Matplotlib defaults (size and font scaling)
        plt.rcParams['figure.figsize'] = (10, 6)
        plt.rcParams['font.size'] = 12
        plt.rcParams['axes.labelsize'] = 14
        plt.rcParams['axes.titlesize'] = 16
        plt.rcParams['xtick.labelsize'] = 12
        plt.rcParams['ytick.labelsize'] = 12
        plt.rcParams['legend.fontsize'] = 12
        plt.rcParams['figure.titlesize'] = 20

    def _ensure_output_dir(self):
        """Create the output directory (and parents) if missing."""
        Path(self.output_dir).mkdir(parents=True, exist_ok=True)
        self.logger.info(f"输出目录: {self.output_dir}")

    def save_figure(self, fig, filename, dpi=300):
        """Save a figure under the output directory.

        Args:
            fig: Matplotlib figure.
            filename: file name (extension decides the image format).
            dpi: output resolution.

        Returns:
            Path of the written file.
        """
        filepath = Path(self.output_dir) / filename
        fig.savefig(filepath, dpi=dpi, bbox_inches='tight')
        self.logger.info(f"图表已保存: {filepath}")

        return filepath

    def plot_bar_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
                       color='skyblue', figsize=(10, 6), save_as=None, **kwargs):
        """Draw a bar chart with value labels on each bar.

        Args:
            data: DataFrame.
            x: x-axis column name.
            y: y-axis column name.
            title: chart title.
            xlabel: x-axis label.
            ylabel: y-axis label.
            color: bar color.
            figsize: figure size.
            save_as: file name to save under output_dir (optional).
            **kwargs: forwarded to seaborn.barplot.

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Create the figure
            fig, ax = plt.subplots(figsize=figsize)

            # Draw the bars
            sns.barplot(x=x, y=y, data=data, color=color, ax=ax, **kwargs)

            # Title and axis labels
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)

            # Thousands-separated integer y tick labels
            ax.yaxis.set_major_formatter(ticker.StrMethodFormatter('{x:,.0f}'))

            # Value label centered above each bar
            for p in ax.patches:
                ax.annotate(f'{p.get_height():,.0f}',
                            (p.get_x() + p.get_width() / 2., p.get_height()),
                            ha='center', va='bottom', fontsize=10)

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制条形图时出错: {e}")
            return None

    def plot_line_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
                        color='royalblue', figsize=(12, 6), save_as=None, **kwargs):
        """Draw a line chart; y may be one column or a list of columns.

        Args:
            data: DataFrame.
            x: x-axis column name.
            y: y-axis column name, or list of names (one line each).
            title: chart title.
            xlabel: x-axis label.
            ylabel: y-axis label.
            color: line color, or list of colors matching y.
            figsize: figure size.
            save_as: file name to save (optional).
            **kwargs: forwarded to DataFrame.plot.

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Create the figure
            fig, ax = plt.subplots(figsize=figsize)

            # Multiple lines: one trace per y column, cycling colors if needed
            if isinstance(y, list):
                if not isinstance(color, list):
                    color = [plt.cm.tab10(i) for i in range(len(y))]

                for i, col in enumerate(y):
                    data.plot(x=x, y=col, ax=ax, label=col, color=color[i % len(color)], **kwargs)
            else:
                data.plot(x=x, y=y, ax=ax, color=color, **kwargs)

            # Title and axis labels
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)

            # Dashed grid for readability
            ax.grid(True, linestyle='--', alpha=0.7)

            # Legend only when comparing multiple series
            if isinstance(y, list) and len(y) > 1:
                ax.legend()

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制折线图时出错: {e}")
            return None

    def plot_pie_chart(self, data, values, names, title=None, figsize=(10, 10),
                       colors=None, autopct='%1.1f%%', save_as=None, **kwargs):
        """Draw a pie chart.

        Args:
            data: DataFrame, or None when values/names are raw sequences.
            values: value column name (or raw values when data is not a DataFrame).
            names: label column name (or raw labels).
            title: chart title.
            figsize: figure size.
            colors: list of slice colors.
            autopct: percentage label format.
            save_as: file name to save (optional).
            **kwargs: forwarded to Axes.pie.

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Accept either a DataFrame + column names, or raw sequences
            if isinstance(data, pd.DataFrame):
                values_data = data[values].values
                names_data = data[names].values
            else:
                values_data = values
                names_data = names

            # Create the figure
            fig, ax = plt.subplots(figsize=figsize)

            # Draw the pie
            wedges, texts, autotexts = ax.pie(
                values_data,
                labels=names_data,
                autopct=autopct,
                colors=colors,
                startangle=90,
                **kwargs
            )

            # Title
            if title:
                ax.set_title(title)

            # Keep the pie circular
            ax.axis('equal')

            # Emphasize the percentage labels
            plt.setp(autotexts, size=10, weight='bold')

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制饼图时出错: {e}")
            return None

    def plot_histogram(self, data, column, bins=30, title=None, xlabel=None, ylabel='频率',
                       color='skyblue', kde=True, figsize=(10, 6), save_as=None, **kwargs):
        """Draw a histogram, optionally with a KDE overlay.

        Args:
            data: DataFrame.
            column: column to histogram.
            bins: number of bins.
            title: chart title.
            xlabel: x-axis label.
            ylabel: y-axis label (defaults to the Chinese word for "frequency").
            color: bar color.
            kde: whether to overlay a kernel density estimate.
            figsize: figure size.
            save_as: file name to save (optional).
            **kwargs: forwarded to seaborn.histplot.

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Create the figure
            fig, ax = plt.subplots(figsize=figsize)

            # Draw the histogram (+ optional KDE)
            sns.histplot(data=data, x=column, bins=bins, kde=kde, color=color, ax=ax, **kwargs)

            # Title and axis labels
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制直方图时出错: {e}")
            return None

    def plot_scatter(self, data, x, y, title=None, xlabel=None, ylabel=None,
                     hue=None, palette='viridis', size=None, figsize=(10, 8), save_as=None, **kwargs):
        """Draw a scatter plot with optional hue/size encodings.

        Args:
            data: DataFrame.
            x: x-axis column name.
            y: y-axis column name.
            title: chart title.
            xlabel: x-axis label.
            ylabel: y-axis label.
            hue: grouping column for color encoding.
            palette: color palette.
            size: column for point-size encoding.
            figsize: figure size.
            save_as: file name to save (optional).
            **kwargs: forwarded to seaborn.scatterplot.

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Create the figure
            fig, ax = plt.subplots(figsize=figsize)

            # Draw the scatter plot
            scatter = sns.scatterplot(
                data=data,
                x=x,
                y=y,
                hue=hue,
                palette=palette,
                size=size,
                ax=ax,
                **kwargs
            )

            # Title and axis labels
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)

            # Dashed grid for readability
            ax.grid(True, linestyle='--', alpha=0.7)

            # With a hue grouping, move the legend outside the axes
            if hue:
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制散点图时出错: {e}")
            return None

    def plot_heatmap(self, data, title=None, cmap='viridis', annot=True, fmt='.2f',
                     figsize=(12, 10), save_as=None, **kwargs):
        """Draw a heatmap (e.g. of a correlation matrix).

        Args:
            data: DataFrame or 2-D array.
            title: chart title.
            cmap: colormap name.
            annot: whether to print cell values.
            fmt: cell value format string.
            figsize: figure size.
            save_as: file name to save (optional).
            **kwargs: forwarded to seaborn.heatmap.

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Create the figure
            fig, ax = plt.subplots(figsize=figsize)

            # Draw the heatmap with thin cell separators
            heatmap = sns.heatmap(
                data,
                cmap=cmap,
                annot=annot,
                fmt=fmt,
                linewidths=.5,
                ax=ax,
                **kwargs
            )

            # Title
            if title:
                ax.set_title(title)

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制热力图时出错: {e}")
            return None

    def plot_box(self, data, x=None, y=None, title=None, xlabel=None, ylabel=None,
                 hue=None, palette='Set3', figsize=(12, 8), save_as=None, **kwargs):
        """Draw a box plot.

        Args:
            data: DataFrame.
            x: x-axis column name.
            y: y-axis column name.
            title: chart title.
            xlabel: x-axis label.
            ylabel: y-axis label.
            hue: grouping column.
            palette: color palette.
            figsize: figure size.
            save_as: file name to save (optional).
            **kwargs: forwarded to seaborn.boxplot.

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Create the figure
            fig, ax = plt.subplots(figsize=figsize)

            # Draw the box plot
            sns.boxplot(
                data=data,
                x=x,
                y=y,
                hue=hue,
                palette=palette,
                ax=ax,
                **kwargs
            )

            # Title and axis labels
            if title:
                ax.set_title(title)
            if xlabel:
                ax.set_xlabel(xlabel)
            if ylabel:
                ax.set_ylabel(ylabel)

            # With a hue grouping, move the legend outside the axes
            if hue:
                plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制箱线图时出错: {e}")
            return None

    def plot_multiple_charts(self, data, chart_configs, title=None, figsize=(15, 10),
                             nrows=None, ncols=None, save_as=None):
        """Draw a grid of subplots, one per config entry.

        Args:
            data: DataFrame shared by all subplots.
            chart_configs: list of dicts, each with:
                - 'type': 'bar', 'line', 'scatter', 'hist', 'box' or 'pie'
                - 'x', 'y' (or 'values'/'names' for pie): data column names
                - 'title' / 'xlabel' / 'ylabel': optional per-subplot text
                - extra keys specific to the chart type ('hue', 'bins', ...)
            title: overall figure title.
            figsize: figure size.
            nrows: grid rows; auto-computed when None.
            ncols: grid columns; auto-computed when None (max 3 per row).
            save_as: file name to save (optional).

        Returns:
            Matplotlib figure, or None on error.
        """
        try:
            # Work out the subplot grid
            n_charts = len(chart_configs)

            if nrows is None and ncols is None:
                # Auto layout: up to 3 columns, rows as needed (ceil division)
                ncols = min(3, n_charts)
                nrows = (n_charts + ncols - 1) // ncols
            elif nrows is None:
                nrows = (n_charts + ncols - 1) // ncols
            elif ncols is None:
                ncols = (n_charts + nrows - 1) // nrows

            # Create the figure and axes grid
            fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)

            # Normalize axes to a 2-D array so axes[row, col] always works
            if nrows == 1 and ncols == 1:
                axes = np.array([[axes]])
            elif nrows == 1:
                axes = axes.reshape(1, -1)
            elif ncols == 1:
                axes = axes.reshape(-1, 1)

            # Draw each configured subplot
            for i, config in enumerate(chart_configs):
                if i >= nrows * ncols:
                    self.logger.warning(f"子图数量超过布局容量,跳过第{i+1}个子图")
                    break

                # Axes cell for this subplot
                row, col = i // ncols, i % ncols
                ax = axes[row, col]

                # Dispatch on chart type (default: bar)
                chart_type = config.get('type', 'bar').lower()

                if chart_type == 'bar':
                    sns.barplot(
                        data=data,
                        x=config.get('x'),
                        y=config.get('y'),
                        hue=config.get('hue'),
                        palette=config.get('palette', 'viridis'),
                        ax=ax
                    )
                elif chart_type == 'line':
                    if isinstance(config.get('y'), list):
                        for y_col in config.get('y'):
                            data.plot(
                                x=config.get('x'),
                                y=y_col,
                                ax=ax,
                                label=y_col
                            )
                    else:
                        data.plot(
                            x=config.get('x'),
                            y=config.get('y'),
                            ax=ax
                        )
                elif chart_type == 'scatter':
                    sns.scatterplot(
                        data=data,
                        x=config.get('x'),
                        y=config.get('y'),
                        hue=config.get('hue'),
                        palette=config.get('palette', 'viridis'),
                        ax=ax
                    )
                elif chart_type == 'hist':
                    sns.histplot(
                        data=data,
                        x=config.get('x'),
                        bins=config.get('bins', 30),
                        kde=config.get('kde', True),
                        ax=ax
                    )
                elif chart_type == 'box':
                    sns.boxplot(
                        data=data,
                        x=config.get('x'),
                        y=config.get('y'),
                        hue=config.get('hue'),
                        palette=config.get('palette', 'viridis'),
                        ax=ax
                    )
                elif chart_type == 'pie':
                    # Pie needs raw arrays instead of the seaborn interface
                    values = data[config.get('values')].values
                    names = data[config.get('names')].values
                    ax.pie(
                        values,
                        labels=names,
                        autopct='%1.1f%%',
                        startangle=90
                    )
                    ax.axis('equal')

                # Per-subplot title and labels
                if 'title' in config:
                    ax.set_title(config['title'])
                if 'xlabel' in config:
                    ax.set_xlabel(config['xlabel'])
                if 'ylabel' in config:
                    ax.set_ylabel(config['ylabel'])

            # Remove the unused cells in the grid
            for i in range(n_charts, nrows * ncols):
                row, col = i // ncols, i % ncols
                fig.delaxes(axes[row, col])

            # Overall title (leave headroom for it)
            if title:
                fig.suptitle(title, fontsize=16)
                plt.subplots_adjust(top=0.9)

            # Tidy layout
            plt.tight_layout()

            # Optionally save
            if save_as:
                self.save_figure(fig, save_as)

            return fig

        except Exception as e:
            self.logger.error(f"绘制多个子图时出错: {e}")
            return None
# Usage example
def static_visualization_example():
    """Demonstrate every StaticVisualizer chart type on synthetic data.

    Writes the generated PNG files to 'visualizations/static'.

    Returns:
        dict with the sample data and the visualizer instance.
    """
    # Create sample data (fixed seed for reproducibility)
    np.random.seed(42)
    n_samples = 200

    # Generate features
    X = np.random.randn(n_samples, 3)  # 3 features

    # Regression target: non-linear with an interaction term
    y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5

    # Assemble the DataFrame
    data = pd.DataFrame(
        X,
        columns=['feature_1', 'feature_2', 'feature_3']
    )
    data['target'] = y

    # Add some derived columns
    data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu'], n_samples)
    # NOTE(review): np.random.poisson(10, ...) can yield 0, producing inf here — confirm acceptable
    data['sales_per_customer'] = data['target'] / np.random.poisson(10, n_samples)

    # Create the visualizer
    visualizer = StaticVisualizer(output_dir='visualizations/static')

    # 1. Bar chart — total sales per month (months kept in calendar order)
    monthly_sales = data.groupby('month')['target'].sum().reset_index()
    monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
                                            categories=['Jan', 'Feb', 'Mar', 'Apr'],
                                            ordered=True)
    monthly_sales = monthly_sales.sort_values('month')

    visualizer.plot_bar_chart(
        data=monthly_sales,
        x='month',
        y='target',
        title='Monthly Sales',
        xlabel='Month',
        ylabel='Total Sales',
        color='skyblue',
        save_as='monthly_sales_bar.png'
    )

    # 2. Line chart — sales and profit trends
    visualizer.plot_line_chart(
        data=data,
        x='feature_1',
        y=['target', 'sales_per_customer'],
        title='Sales and Profit Trends',
        xlabel='Feature 1',
        ylabel='Amount',
        figsize=(14, 7),
        save_as='sales_profit_trend.png'
    )

    # 3. Pie chart — sales share per day-of-week group
    region_sales = data.groupby('day_of_week')['target'].sum().reset_index()

    visualizer.plot_pie_chart(
        data=region_sales,
        values='target',
        names='day_of_week',
        title='Sales Distribution by Region',
        save_as='region_sales_pie.png'
    )

    # 4. Histogram — sales-per-customer distribution
    visualizer.plot_histogram(
        data=data,
        column='sales_per_customer',
        bins=20,
        title='Distribution of Sales per Customer',
        xlabel='Sales per Customer',
        ylabel='Frequency',
        save_as='sales_per_customer_hist.png'
    )

    # 5. Scatter plot — feature_2 vs target, colored by month
    visualizer.plot_scatter(
        data=data,
        x='feature_2',
        y='target',
        title='Relationship between Number of Customers and Sales',
        xlabel='Number of Customers',
        ylabel='Sales',
        hue='month',
        save_as='customers_sales_scatter.png'
    )

    # 6. Heatmap — correlation matrix of the numeric columns
    correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'target']].corr()

    visualizer.plot_heatmap(
        data=correlation_matrix,
        title='Correlation Matrix',
        save_as='correlation_heatmap.png'
    )

    # 7. Box plot — sales distribution per day-of-week group
    visualizer.plot_box(
        data=data,
        x='day_of_week',
        y='target',
        title='Sales Distribution by Region',
        xlabel='Region',
        ylabel='Sales',
        save_as='region_sales_box.png'
    )

    # 8. Multi-panel dashboard combining four chart types
    chart_configs = [
        {
            'type': 'bar',
            'x': 'month',
            'y': 'target',
            'title': 'Sales by Month'
        },
        {
            'type': 'line',
            'x': 'feature_1',
            'y': 'target',
            'title': 'Sales Trend'
        },
        {
            'type': 'scatter',
            'x': 'feature_2',
            'y': 'target',
            'title': 'Sales vs Feature 2'
        },
        {
            'type': 'hist',
            'x': 'sales_per_customer',
            'title': 'Sales per Customer Distribution'
        }
    ]

    visualizer.plot_multiple_charts(
        data=data,
        chart_configs=chart_configs,
        title='Sales Dashboard',
        save_as='sales_dashboard.png'
    )

    print("静态可视化示例完成,图表已保存到 'visualizations/static' 目录")

    return {
        'sales_data': data,
        'visualizer': visualizer
    }
if __name__ == "__main__":
    static_visualization_example()
复制代码 7.2 交互式可视化
交互式可视化允许用户与图表进行交互,例如缩放、悬停查看详情、筛选数据等,主要使用Plotly和Bokeh库:
- import plotly.express as px
- import plotly.graph_objects as go
- from plotly.subplots import make_subplots
- import pandas as pd
- import numpy as np
- import logging
- from pathlib import Path
- import json
- import plotly.io as pio
- class InteractiveVisualizer:
- """交互式可视化类"""
-
def __init__(self, output_dir='visualizations/interactive'):
    """Initialize the interactive visualizer.

    Args:
        output_dir: directory where exported charts are written.
    """
    self.output_dir = output_dir
    self.logger = self._setup_logger()
    self._setup_style()
    self._ensure_output_dir()
-
def _setup_logger(self):
    """Create a stream logger, adding the handler only once per process."""
    logger = logging.getLogger('InteractiveVisualizer')
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
-
def _setup_style(self):
    """Set the process-wide default Plotly template."""
    # Affects every Plotly figure created afterwards, not just this instance's
    pio.templates.default = "plotly_white"
-
def _ensure_output_dir(self):
    """Create the output directory (and parents) if missing."""
    Path(self.output_dir).mkdir(parents=True, exist_ok=True)
    self.logger.info(f"输出目录: {self.output_dir}")
-
def save_figure(self, fig, filename, include_plotlyjs='cdn'):
    """Save a Plotly figure; the export format follows the file extension.

    Args:
        fig: Plotly figure.
        filename: '.html', '.json', or an image extension (requires a
            static-image backend such as kaleido for the latter).
        include_plotlyjs: how plotly.js is embedded in HTML output.

    Returns:
        Path of the written file.
    """
    filepath = Path(self.output_dir) / filename

    # HTML export
    if filename.endswith('.html'):
        fig.write_html(filepath, include_plotlyjs=include_plotlyjs)
    # JSON export
    elif filename.endswith('.json'):
        # NOTE(review): json.dump may fail on numpy values inside the
        # figure dict — confirm against actual figures.
        with open(filepath, 'w') as f:
            json.dump(fig.to_dict(), f)
    # Static image export
    else:
        fig.write_image(filepath)

    self.logger.info(f"图表已保存: {filepath}")

    return filepath
-
- def plot_bar_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
- color=None, barmode='group', figsize=(900, 600),
- save_as=None, **kwargs):
- """绘制交互式条形图
-
- 参数:
- data: DataFrame
- x: x轴列名
- y: y轴列名或列名列表
- title: 图表标题
- xlabel: x轴标签
- ylabel: y轴标签
- color: 分组变量
- barmode: 条形模式 ('group', 'stack', 'relative', 'overlay')
- figsize: 图表大小 (宽, 高)
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 处理y为列表的情况
- if isinstance(y, list):
- fig = go.Figure()
-
- for col in y:
- fig.add_trace(go.Bar(
- x=data[x],
- y=data[col],
- name=col
- ))
-
- fig.update_layout(barmode=barmode)
- else:
- # 使用Plotly Express创建条形图
- fig = px.bar(
- data,
- x=x,
- y=y,
- color=color,
- barmode=barmode,
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- xaxis_title=xlabel,
- yaxis_title=ylabel,
- width=figsize[0],
- height=figsize[1],
- hovermode='closest'
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式条形图时出错: {e}")
- return None
-
- def plot_line_chart(self, data, x, y, title=None, xlabel=None, ylabel=None,
- color=None, line_shape='linear', figsize=(900, 600),
- save_as=None, **kwargs):
- """绘制交互式折线图
-
- 参数:
- data: DataFrame
- x: x轴列名
- y: y轴列名或列名列表
- title: 图表标题
- xlabel: x轴标签
- ylabel: y轴标签
- color: 分组变量
- line_shape: 线条形状 ('linear', 'spline', 'hv', 'vh', 'hvh', 'vhv')
- figsize: 图表大小 (宽, 高)
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 处理y为列表的情况
- if isinstance(y, list):
- fig = go.Figure()
-
- for col in y:
- fig.add_trace(go.Scatter(
- x=data[x],
- y=data[col],
- mode='lines+markers',
- name=col,
- line=dict(shape=line_shape)
- ))
- else:
- # 使用Plotly Express创建折线图
- fig = px.line(
- data,
- x=x,
- y=y,
- color=color,
- line_shape=line_shape,
- **kwargs
- )
-
- # 添加标记点
- fig.update_traces(mode='lines+markers')
-
- # 更新布局
- fig.update_layout(
- title=title,
- xaxis_title=xlabel,
- yaxis_title=ylabel,
- width=figsize[0],
- height=figsize[1],
- hovermode='closest'
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式折线图时出错: {e}")
- return None
-
- def plot_pie_chart(self, data, values, names, title=None, figsize=(800, 800),
- hole=0, save_as=None, **kwargs):
- """绘制交互式饼图/环形图
-
- 参数:
- data: DataFrame
- values: 值列名
- names: 名称列名
- title: 图表标题
- figsize: 图表大小 (宽, 高)
- hole: 中心孔大小 (0-1),0为饼图,>0为环形图
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 使用Plotly Express创建饼图/环形图
- fig = px.pie(
- data,
- values=values,
- names=names,
- hole=hole,
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- width=figsize[0],
- height=figsize[1]
- )
-
- # 更新轨迹
- fig.update_traces(
- textposition='inside',
- textinfo='percent+label',
- hoverinfo='label+percent+value'
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式饼图时出错: {e}")
- return None
-
- def plot_histogram(self, data, column, bins=30, title=None, xlabel=None, ylabel='频率',
- color=None, figsize=(900, 600), save_as=None, **kwargs):
- """绘制交互式直方图
-
- 参数:
- data: DataFrame
- column: 列名
- bins: 分箱数量
- title: 图表标题
- xlabel: x轴标签
- ylabel: y轴标签
- color: 分组变量
- figsize: 图表大小 (宽, 高)
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 使用Plotly Express创建直方图
- fig = px.histogram(
- data,
- x=column,
- color=color,
- nbins=bins,
- marginal='rug', # 添加边缘分布
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- xaxis_title=xlabel,
- yaxis_title=ylabel,
- width=figsize[0],
- height=figsize[1],
- bargap=0.1 # 条形之间的间隙
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式直方图时出错: {e}")
- return None
-
- def plot_scatter(self, data, x, y, title=None, xlabel=None, ylabel=None,
- color=None, size=None, hover_name=None, figsize=(900, 600),
- save_as=None, **kwargs):
- """绘制交互式散点图
-
- 参数:
- data: DataFrame
- x: x轴列名
- y: y轴列名
- title: 图表标题
- xlabel: x轴标签
- ylabel: y轴标签
- color: 分组变量
- size: 点大小变量
- hover_name: 悬停显示的标识列
- figsize: 图表大小 (宽, 高)
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 使用Plotly Express创建散点图
- fig = px.scatter(
- data,
- x=x,
- y=y,
- color=color,
- size=size,
- hover_name=hover_name,
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- xaxis_title=xlabel,
- yaxis_title=ylabel,
- width=figsize[0],
- height=figsize[1],
- hovermode='closest'
- )
-
- # 添加趋势线
- if 'trendline' not in kwargs:
- fig.update_layout(
- shapes=[{
- 'type': 'line',
- 'x0': data[x].min(),
- 'y0': data[y].min(),
- 'x1': data[x].max(),
- 'y1': data[y].max(),
- 'line': {
- 'color': 'rgba(0,0,0,0.2)',
- 'width': 2,
- 'dash': 'dash'
- }
- }]
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式散点图时出错: {e}")
- return None
-
- def plot_heatmap(self, data, title=None, figsize=(900, 700),
- colorscale='Viridis', save_as=None, **kwargs):
- """绘制交互式热力图
-
- 参数:
- data: DataFrame或矩阵
- title: 图表标题
- figsize: 图表大小 (宽, 高)
- colorscale: 颜色映射
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 创建热力图
- fig = go.Figure(data=go.Heatmap(
- z=data.values,
- x=data.columns,
- y=data.index,
- colorscale=colorscale,
- **kwargs
- ))
-
- # 更新布局
- fig.update_layout(
- title=title,
- width=figsize[0],
- height=figsize[1]
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式热力图时出错: {e}")
- return None
-
- def plot_box(self, data, x=None, y=None, title=None, xlabel=None, ylabel=None,
- color=None, figsize=(900, 600), save_as=None, **kwargs):
- """绘制交互式箱线图
-
- 参数:
- data: DataFrame
- x: x轴列名
- y: y轴列名
- title: 图表标题
- xlabel: x轴标签
- ylabel: y轴标签
- color: 分组变量
- figsize: 图表大小 (宽, 高)
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 使用Plotly Express创建箱线图
- fig = px.box(
- data,
- x=x,
- y=y,
- color=color,
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- xaxis_title=xlabel,
- yaxis_title=ylabel,
- width=figsize[0],
- height=figsize[1]
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式箱线图时出错: {e}")
- return None
-
- def plot_bubble(self, data, x, y, size, title=None, xlabel=None, ylabel=None,
- color=None, hover_name=None, figsize=(900, 600), save_as=None, **kwargs):
- """绘制交互式气泡图
-
- 参数:
- data: DataFrame
- x: x轴列名
- y: y轴列名
- size: 气泡大小列名
- title: 图表标题
- xlabel: x轴标签
- ylabel: y轴标签
- color: 分组变量
- hover_name: 悬停显示的标识列
- figsize: 图表大小 (宽, 高)
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 使用Plotly Express创建气泡图
- fig = px.scatter(
- data,
- x=x,
- y=y,
- size=size,
- color=color,
- hover_name=hover_name,
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- xaxis_title=xlabel,
- yaxis_title=ylabel,
- width=figsize[0],
- height=figsize[1],
- hovermode='closest'
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式气泡图时出错: {e}")
- return None
-
- def plot_3d_scatter(self, data, x, y, z, title=None, xlabel=None, ylabel=None, zlabel=None,
- color=None, size=None, hover_name=None, figsize=(900, 700),
- save_as=None, **kwargs):
- """绘制交互式3D散点图
-
- 参数:
- data: DataFrame
- x: x轴列名
- y: y轴列名
- z: z轴列名
- title: 图表标题
- xlabel: x轴标签
- ylabel: y轴标签
- zlabel: z轴标签
- color: 分组变量
- size: 点大小变量
- hover_name: 悬停显示的标识列
- figsize: 图表大小 (宽, 高)
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 使用Plotly Express创建3D散点图
- fig = px.scatter_3d(
- data,
- x=x,
- y=y,
- z=z,
- color=color,
- size=size,
- hover_name=hover_name,
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- scene=dict(
- xaxis_title=xlabel,
- yaxis_title=ylabel,
- zaxis_title=zlabel
- ),
- width=figsize[0],
- height=figsize[1]
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式3D散点图时出错: {e}")
- return None
-
- def plot_choropleth_map(self, data, locations, color, title=None,
- location_mode='ISO-3', figsize=(900, 600),
- colorscale='Viridis', save_as=None, **kwargs):
- """绘制交互式地理热力图
-
- 参数:
- data: DataFrame
- locations: 地理位置列名
- color: 颜色值列名
- title: 图表标题
- location_mode: 地理位置模式 ('ISO-3', 'country names', 等)
- figsize: 图表大小 (宽, 高)
- colorscale: 颜色映射
- save_as: 保存文件名
- **kwargs: 其他参数
-
- 返回:
- plotly图表对象
- """
- try:
- # 使用Plotly Express创建地理热力图
- fig = px.choropleth(
- data,
- locations=locations,
- color=color,
- locationmode=location_mode,
- color_continuous_scale=colorscale,
- **kwargs
- )
-
- # 更新布局
- fig.update_layout(
- title=title,
- width=figsize[0],
- height=figsize[1],
- geo=dict(
- showframe=False,
- showcoastlines=True,
- projection_type='equirectangular'
- )
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制交互式地理热力图时出错: {e}")
- return None
- def plot_multiple_charts(self, chart_configs, title=None, figsize=(1000, 800),
- rows=None, cols=None, subplot_titles=None, save_as=None):
- """绘制多个子图
-
- 参数:
- chart_configs: 子图配置列表,每个配置是一个字典,包含:
- - 'data': 数据
- - 'type': 图表类型 ('bar', 'line', 'scatter', 'pie', 'box', 'heatmap', 等)
- - 'x', 'y': 数据列名
- - 'row', 'col': 子图位置
- - 其他特定图表类型的参数
- title: 总标题
- figsize: 图表大小 (宽, 高)
- rows: 行数
- cols: 列数
- subplot_titles: 子图标题列表
- save_as: 保存文件名
-
- 返回:
- plotly图表对象
- """
- try:
- # 确定子图布局
- if rows is None or cols is None:
- # 查找最大的row和col值
- max_row = max([config.get('row', 1) for config in chart_configs])
- max_col = max([config.get('col', 1) for config in chart_configs])
- rows = max(rows or 0, max_row)
- cols = max(cols or 0, max_col)
-
- # 创建子图
- fig = make_subplots(
- rows=rows,
- cols=cols,
- subplot_titles=subplot_titles,
- specs=[[{"type": "xy"} for _ in range(cols)] for _ in range(rows)]
- )
-
- # 添加每个子图
- for config in chart_configs:
- data = config.get('data')
- chart_type = config.get('type', 'scatter').lower()
- row = config.get('row', 1)
- col = config.get('col', 1)
-
- if chart_type == 'bar':
- trace = go.Bar(
- x=data[config.get('x')],
- y=data[config.get('y')],
- name=config.get('name', config.get('y')),
- marker_color=config.get('color')
- )
- elif chart_type == 'line':
- trace = go.Scatter(
- x=data[config.get('x')],
- y=data[config.get('y')],
- mode='lines+markers',
- name=config.get('name', config.get('y')),
- line=dict(color=config.get('color'))
- )
- elif chart_type == 'scatter':
- trace = go.Scatter(
- x=data[config.get('x')],
- y=data[config.get('y')],
- mode='markers',
- name=config.get('name', config.get('y')),
- marker=dict(
- color=config.get('color'),
- size=config.get('size', 10)
- )
- )
- elif chart_type == 'pie':
- trace = go.Pie(
- values=data[config.get('values')],
- labels=data[config.get('names')],
- name=config.get('name', '')
- )
- elif chart_type == 'box':
- trace = go.Box(
- x=data[config.get('x')] if 'x' in config else None,
- y=data[config.get('y')],
- name=config.get('name', config.get('y'))
- )
- elif chart_type == 'heatmap':
- # 热力图需要特殊处理
- if isinstance(data, pd.DataFrame):
- z_data = data.values
- x_data = data.columns
- y_data = data.index
- else:
- z_data = data
- x_data = config.get('x')
- y_data = config.get('y')
-
- trace = go.Heatmap(
- z=z_data,
- x=x_data,
- y=y_data,
- colorscale=config.get('colorscale', 'Viridis')
- )
- else:
- self.logger.warning(f"未知的图表类型: {chart_type}")
- continue
-
- fig.add_trace(trace, row=row, col=col)
-
- # 更新轴标签
- if 'xlabel' in config:
- fig.update_xaxes(title_text=config['xlabel'], row=row, col=col)
- if 'ylabel' in config:
- fig.update_yaxes(title_text=config['ylabel'], row=row, col=col)
-
- # 更新布局
- fig.update_layout(
- title=title,
- width=figsize[0],
- height=figsize[1],
- showlegend=True
- )
-
- # 保存图表
- if save_as:
- self.save_figure(fig, save_as)
-
- return fig
-
- except Exception as e:
- self.logger.error(f"绘制多个子图时出错: {e}")
- return None
- # 使用示例
- def interactive_visualization_example():
- """交互式可视化示例"""
- # 创建示例数据
- np.random.seed(42)
- n_samples = 200
-
- # 生成特征
- X = np.random.randn(n_samples, 3) # 3个特征
-
- # 生成目标变量(回归)
- y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
-
- # 创建DataFrame
- data = pd.DataFrame(
- X,
- columns=['feature_1', 'feature_2', 'feature_3']
- )
- data['target'] = y
-
- # 添加一些派生列
- data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
- data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
- data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
- data['sales'] = data['target'] * 100 + 500
- data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
- data['customers'] = np.random.poisson(50, n_samples)
- data['sales_per_customer'] = data['sales'] / data['customers']
-
- # 创建一些国家数据
- countries = ['USA', 'CAN', 'MEX', 'BRA', 'ARG', 'GBR', 'FRA', 'DEU', 'ITA', 'ESP',
- 'RUS', 'CHN', 'JPN', 'IND', 'AUS']
- country_codes = ['USA', 'CAN', 'MEX', 'BRA', 'ARG', 'GBR', 'FRA', 'DEU', 'ITA', 'ESP',
- 'RUS', 'CHN', 'JPN', 'IND', 'AUS']
- country_data = pd.DataFrame({
- 'country': countries,
- 'code': country_codes,
- 'gdp': np.random.uniform(100, 1000, len(countries)),
- 'population': np.random.uniform(10, 500, len(countries))
- })
-
- # 创建可视化器
- visualizer = InteractiveVisualizer(output_dir='visualizations/interactive')
-
- # 1. 绘制交互式条形图 - 按月份的销售额
- monthly_sales = data.groupby('month')['sales'].sum().reset_index()
- monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
- categories=['Jan', 'Feb', 'Mar', 'Apr'],
- ordered=True)
- monthly_sales = monthly_sales.sort_values('month')
-
- bar_fig = visualizer.plot_bar_chart(
- data=monthly_sales,
- x='month',
- y='sales',
- title='Monthly Sales',
- xlabel='Month',
- ylabel='Total Sales',
- save_as='monthly_sales_bar.html'
- )
-
- # 2. 绘制交互式折线图 - 销售额和利润趋势
- line_fig = visualizer.plot_line_chart(
- data=data.sort_values('feature_1').iloc[:50], # 使用部分数据
- x='feature_1',
- y=['sales', 'profit'],
- title='Sales and Profit Trends',
- xlabel='Feature 1',
- ylabel='Amount',
- save_as='sales_profit_trend.html'
- )
-
- # 3. 绘制交互式饼图 - 按区域的销售额分布
- region_sales = data.groupby('region')['sales'].sum().reset_index()
-
- pie_fig = visualizer.plot_pie_chart(
- data=region_sales,
- values='sales',
- names='region',
- title='Sales Distribution by Region',
- save_as='region_sales_pie.html'
- )
-
- # 4. 绘制交互式环形图 - 按星期几的销售额分布
- day_sales = data.groupby('day_of_week')['sales'].sum().reset_index()
-
- donut_fig = visualizer.plot_pie_chart(
- data=day_sales,
- values='sales',
- names='day_of_week',
- title='Sales Distribution by Day of Week',
- hole=0.4, # 环形图
- save_as='day_sales_donut.html'
- )
-
- # 5. 绘制交互式直方图 - 每位客户销售额分布
- hist_fig = visualizer.plot_histogram(
- data=data,
- column='sales_per_customer',
- bins=20,
- title='Distribution of Sales per Customer',
- xlabel='Sales per Customer',
- ylabel='Frequency',
- color='region', # 按区域分组
- save_as='sales_per_customer_hist.html'
- )
-
- # 6. 绘制交互式散点图 - 客户数量与销售额的关系
- scatter_fig = visualizer.plot_scatter(
- data=data,
- x='customers',
- y='sales',
- title='Relationship between Number of Customers and Sales',
- xlabel='Number of Customers',
- ylabel='Sales',
- color='region',
- size='profit', # 使用利润作为点大小
- hover_name='month', # 悬停显示月份
- save_as='customers_sales_scatter.html'
- )
-
- # 7. 绘制交互式热力图 - 相关性矩阵
- correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
-
- heatmap_fig = visualizer.plot_heatmap(
- data=correlation_matrix,
- title='Correlation Matrix',
- save_as='correlation_heatmap.html'
- )
-
- # 8. 绘制交互式箱线图 - 按区域的销售额分布
- box_fig = visualizer.plot_box(
- data=data,
- x='region',
- y='sales',
- title='Sales Distribution by Region',
- xlabel='Region',
- ylabel='Sales',
- color='region',
- save_as='region_sales_box.html'
- )
-
- # 9. 绘制交互式气泡图 - 特征与销售额和利润的关系
- bubble_fig = visualizer.plot_bubble(
- data=data,
- x='feature_1',
- y='feature_2',
- size='sales',
- color='region',
- title='Feature Relationships with Sales',
- xlabel='Feature 1',
- ylabel='Feature 2',
- hover_name='month',
- save_as='feature_sales_bubble.html'
- )
-
- # 10. 绘制交互式3D散点图 - 三个特征的关系
- scatter_3d_fig = visualizer.plot_3d_scatter(
- data=data,
- x='feature_1',
- y='feature_2',
- z='feature_3',
- color='sales',
- size='profit',
- title='3D Relationship between Features',
- xlabel='Feature 1',
- ylabel='Feature 2',
- zlabel='Feature 3',
- save_as='features_3d_scatter.html'
- )
-
- # 11. 绘制交互式地理热力图 - 国家GDP分布
- choropleth_fig = visualizer.plot_choropleth_map(
- data=country_data,
- locations='code',
- color='gdp',
- title='GDP by Country',
- location_mode='ISO-3',
- color_continuous_scale='Viridis',
- save_as='country_gdp_map.html'
- )
-
- # 12. 绘制多个子图 - 销售仪表盘
- chart_configs = [
- {
- 'data': monthly_sales,
- 'type': 'bar',
- 'x': 'month',
- 'y': 'sales',
- 'row': 1,
- 'col': 1,
- 'name': 'Monthly Sales',
- 'xlabel': 'Month',
- 'ylabel': 'Sales'
- },
- {
- 'data': data.sort_values('feature_1').iloc[:50],
- 'type': 'line',
- 'x': 'feature_1',
- 'y': 'sales',
- 'row': 1,
- 'col': 2,
- 'name': 'Sales Trend',
- 'xlabel': 'Feature 1',
- 'ylabel': 'Sales'
- },
- {
- 'data': data,
- 'type': 'scatter',
- 'x': 'customers',
- 'y': 'sales',
- 'row': 2,
- 'col': 1,
- 'name': 'Customers vs Sales',
- 'xlabel': 'Customers',
- 'ylabel': 'Sales'
- },
- {
- 'data': correlation_matrix,
- 'type': 'heatmap',
- 'row': 2,
- 'col': 2,
- 'name': 'Correlation'
- }
- ]
-
- subplot_titles = ['Monthly Sales', 'Sales Trend', 'Customers vs Sales', 'Correlation Matrix']
-
- dashboard_fig = visualizer.plot_multiple_charts(
- chart_configs=chart_configs,
- title='Sales Dashboard',
- rows=2,
- cols=2,
- subplot_titles=subplot_titles,
- save_as='sales_dashboard.html'
- )
-
- print("交互式可视化示例完成,图表已保存到 'visualizations/interactive' 目录")
-
- return {
- 'data': data,
- 'country_data': country_data,
- 'visualizer': visualizer,
- 'figures': {
- 'bar': bar_fig,
- 'line': line_fig,
- 'pie': pie_fig,
- 'donut': donut_fig,
- 'hist': hist_fig,
- 'scatter': scatter_fig,
- 'heatmap': heatmap_fig,
- 'box': box_fig,
- 'bubble': bubble_fig,
- 'scatter_3d': scatter_3d_fig,
- 'choropleth': choropleth_fig,
- 'dashboard': dashboard_fig
- }
- }
- if __name__ == "__main__":
- interactive_visualization_example()
复制代码 交互式仪表盘功能
- # 交互式仪表盘模块
- import dash
- from dash import dcc, html
- from dash.dependencies import Input, Output
- import plotly.express as px
- import plotly.graph_objects as go
- from plotly.subplots import make_subplots
- import pandas as pd
- import numpy as np
- import os
- from pathlib import Path
- import logging
- class DashboardBuilder:
- """交互式仪表盘构建器"""
-
- def __init__(self, title="数据分析仪表盘", theme="plotly_white"):
- """初始化仪表盘构建器
-
- 参数:
- title: 仪表盘标题
- theme: 仪表盘主题
- """
- self.title = title
- self.theme = theme
- self.app = dash.Dash(__name__, suppress_callback_exceptions=True)
- self.app.title = title
- self.visualizer = InteractiveVisualizer()
- self.logger = logging.getLogger(__name__)
-
- # 设置Plotly主题
- pio.templates.default = theme
-
- def create_layout(self, components):
- """创建仪表盘布局
-
- 参数:
- components: 组件列表,每个组件是一个字典,包含:
- - 'type': 组件类型 ('graph', 'table', 'control', 等)
- - 'id': 组件ID
- - 'title': 组件标题
- - 'width': 组件宽度 (1-12)
- - 其他特定组件类型的参数
-
- 返回:
- Dash应用布局
- """
- try:
- # 创建页面布局
- layout = html.Div([
- # 标题
- html.H1(self.title, style={'textAlign': 'center', 'marginBottom': 30}),
-
- # 内容容器
- html.Div([
- # 为每个组件创建一个Div
- html.Div([
- # 组件标题
- html.H3(component.get('title', f"Component {i+1}"),
- style={'marginBottom': 15}),
-
- # 根据组件类型创建不同的内容
- self._create_component(component)
- ], className=f"col-{component.get('width', 12)}",
- style={'padding': '10px'})
-
- for i, component in enumerate(components)
- ], className='row')
- ], className='container-fluid')
-
- self.app.layout = layout
- return layout
-
- except Exception as e:
- self.logger.error(f"创建仪表盘布局时出错: {e}")
- return html.Div(f"创建仪表盘布局时出错: {e}")
-
- def _create_component(self, component):
- """根据组件类型创建组件
-
- 参数:
- component: 组件配置字典
-
- 返回:
- Dash组件
- """
- try:
- component_type = component.get('type', '').lower()
- component_id = component.get('id', f"component-{id(component)}")
-
- if component_type == 'graph':
- # 创建图表组件
- return dcc.Graph(
- id=component_id,
- figure=component.get('figure', {}),
- style={'height': component.get('height', 400)}
- )
-
- elif component_type == 'table':
- # 创建表格组件
- data = component.get('data', pd.DataFrame())
- return html.Div([
- dash.dash_table.DataTable(
- id=component_id,
- columns=[{"name": i, "id": i} for i in data.columns],
- data=data.to_dict('records'),
- page_size=component.get('page_size', 10),
- style_table={'overflowX': 'auto'},
- style_cell={
- 'textAlign': 'left',
- 'padding': '10px',
- 'minWidth': '100px', 'width': '150px', 'maxWidth': '300px',
- 'whiteSpace': 'normal',
- 'height': 'auto'
- },
- style_header={
- 'backgroundColor': 'rgb(230, 230, 230)',
- 'fontWeight': 'bold'
- }
- )
- ])
-
- elif component_type == 'control':
- # 创建控制组件
- control_subtype = component.get('control_type', '').lower()
-
- if control_subtype == 'dropdown':
- return dcc.Dropdown(
- id=component_id,
- options=[{'label': str(opt), 'value': opt}
- for opt in component.get('options', [])],
- value=component.get('value'),
- multi=component.get('multi', False),
- placeholder=component.get('placeholder', 'Select an option')
- )
-
- elif control_subtype == 'slider':
- return dcc.Slider(
- id=component_id,
- min=component.get('min', 0),
- max=component.get('max', 100),
- step=component.get('step', 1),
- value=component.get('value', 50),
- marks={i: str(i) for i in range(
- component.get('min', 0),
- component.get('max', 100) + 1,
- component.get('mark_step', 10)
- )}
- )
-
- elif control_subtype == 'radio':
- return dcc.RadioItems(
- id=component_id,
- options=[{'label': str(opt), 'value': opt}
- for opt in component.get('options', [])],
- value=component.get('value'),
- inline=component.get('inline', True)
- )
-
- elif control_subtype == 'checklist':
- return dcc.Checklist(
- id=component_id,
- options=[{'label': str(opt), 'value': opt}
- for opt in component.get('options', [])],
- value=component.get('value', []),
- inline=component.get('inline', True)
- )
-
- elif control_subtype == 'date':
- return dcc.DatePickerSingle(
- id=component_id,
- date=component.get('date'),
- min_date_allowed=component.get('min_date'),
- max_date_allowed=component.get('max_date')
- )
-
- elif control_subtype == 'daterange':
- return dcc.DatePickerRange(
- id=component_id,
- start_date=component.get('start_date'),
- end_date=component.get('end_date'),
- min_date_allowed=component.get('min_date'),
- max_date_allowed=component.get('max_date')
- )
-
- else:
- return html.Div(f"未知的控制类型: {control_subtype}")
-
- elif component_type == 'text':
- # 创建文本组件
- return html.Div([
- html.P(component.get('text', ''),
- style={'fontSize': component.get('font_size', 16)})
- ])
-
- elif component_type == 'html':
- # 创建自定义HTML组件
- return html.Div([
- html.Div(component.get('html', ''),
- dangerously_set_inner_html=True)
- ])
-
- else:
- return html.Div(f"未知的组件类型: {component_type}")
-
- except Exception as e:
- self.logger.error(f"创建组件时出错: {e}")
- return html.Div(f"创建组件时出错: {e}")
-
- def add_callback(self, outputs, inputs, state=None):
- """添加回调函数
-
- 参数:
- outputs: 输出组件列表,每个元素是一个元组 (component_id, component_property)
- inputs: 输入组件列表,每个元素是一个元组 (component_id, component_property)
- state: 状态组件列表,每个元素是一个元组 (component_id, component_property)
-
- 返回:
- 装饰器函数
- """
- try:
- # 转换为Dash输出格式
- dash_outputs = [Output(component_id, component_property)
- for component_id, component_property in outputs]
-
- # 转换为Dash输入格式
- dash_inputs = [Input(component_id, component_property)
- for component_id, component_property in inputs]
-
- # 转换为Dash状态格式
- dash_state = []
- if state:
- dash_state = [dash.dependencies.State(component_id, component_property)
- for component_id, component_property in state]
-
- # 返回Dash回调装饰器
- return self.app.callback(dash_outputs, dash_inputs, dash_state)
-
- except Exception as e:
- self.logger.error(f"添加回调函数时出错: {e}")
- return None
-
- def run_server(self, debug=True, port=8050, host='0.0.0.0'):
- """运行仪表盘服务器
-
- 参数:
- debug: 是否启用调试模式
- port: 服务器端口
- host: 服务器主机
- """
- try:
- self.app.run_server(debug=debug, port=port, host=host)
- except Exception as e:
- self.logger.error(f"运行仪表盘服务器时出错: {e}")
- # 使用示例
- def interactive_dashboard_example():
- """交互式仪表盘示例"""
- # 创建示例数据
- np.random.seed(42)
- n_samples = 200
-
- # 生成特征
- X = np.random.randn(n_samples, 3) # 3个特征
-
- # 生成目标变量(回归)
- y = 2 * X[:, 0] + X[:, 1]**2 + 0.5 * X[:, 0] * X[:, 2] + np.random.randn(n_samples) * 0.5
-
- # 创建DataFrame
- data = pd.DataFrame(
- X,
- columns=['feature_1', 'feature_2', 'feature_3']
- )
- data['target'] = y
-
- # 添加一些派生列
- data['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
- data['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
- data['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
- data['sales'] = data['target'] * 100 + 500
- data['profit'] = data['sales'] * np.random.uniform(0.1, 0.3, n_samples)
- data['customers'] = np.random.poisson(50, n_samples)
- data['sales_per_customer'] = data['sales'] / data['customers']
-
- # 创建可视化器
- visualizer = InteractiveVisualizer()
-
- # 创建一些图表
- monthly_sales = data.groupby('month')['sales'].sum().reset_index()
- monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
- categories=['Jan', 'Feb', 'Mar', 'Apr'],
- ordered=True)
- monthly_sales = monthly_sales.sort_values('month')
-
- bar_fig = visualizer.plot_bar_chart(
- data=monthly_sales,
- x='month',
- y='sales',
- title='Monthly Sales',
- xlabel='Month',
- ylabel='Total Sales'
- )
-
- region_sales = data.groupby('region')['sales'].sum().reset_index()
- pie_fig = visualizer.plot_pie_chart(
- data=region_sales,
- values='sales',
- names='region',
- title='Sales Distribution by Region'
- )
-
- scatter_fig = visualizer.plot_scatter(
- data=data,
- x='customers',
- y='sales',
- title='Relationship between Number of Customers and Sales',
- xlabel='Number of Customers',
- ylabel='Sales',
- color='region',
- size='profit'
- )
-
- correlation_matrix = data[['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']].corr()
- heatmap_fig = visualizer.plot_heatmap(
- data=correlation_matrix,
- title='Correlation Matrix'
- )
-
- # 创建仪表盘构建器
- dashboard = DashboardBuilder(title="销售数据分析仪表盘")
-
- # 定义仪表盘组件
- components = [
- {
- 'type': 'control',
- 'id': 'region-filter',
- 'title': '区域筛选',
- 'control_type': 'dropdown',
- 'options': ['All'] + list(data['region'].unique()),
- 'value': 'All',
- 'width': 3
- },
- {
- 'type': 'control',
- 'id': 'month-filter',
- 'title': '月份筛选',
- 'control_type': 'checklist',
- 'options': list(data['month'].unique()),
- 'value': list(data['month'].unique()),
- 'width': 9
- },
- {
- 'type': 'graph',
- 'id': 'monthly-sales-chart',
- 'title': '月度销售额',
- 'figure': bar_fig,
- 'width': 6,
- 'height': 400
- },
- {
- 'type': 'graph',
- 'id': 'region-sales-chart',
- 'title': '区域销售额分布',
- 'figure': pie_fig,
- 'width': 6,
- 'height': 400
- },
- {
- 'type': 'graph',
- 'id': 'customer-sales-chart',
- 'title': '客户数量与销售额关系',
- 'figure': scatter_fig,
- 'width': 6,
- 'height': 400
- },
- {
- 'type': 'graph',
- 'id': 'correlation-matrix',
- 'title': '相关性矩阵',
- 'figure': heatmap_fig,
- 'width': 6,
- 'height': 400
- },
- {
- 'type': 'table',
- 'id': 'sales-table',
- 'title': '销售数据表',
- 'data': data[['month', 'region', 'sales', 'profit', 'customers']].head(10),
- 'width': 12,
- 'page_size': 10
- }
- ]
-
- # 创建仪表盘布局
- dashboard.create_layout(components)
-
- # 添加回调函数 - 区域筛选
- @dashboard.add_callback(
- outputs=[('sales-table', 'data')],
- inputs=[('region-filter', 'value'), ('month-filter', 'value')]
- )
- def update_table(region, months):
- filtered_data = data.copy()
-
- # 筛选区域
- if region != 'All':
- filtered_data = filtered_data[filtered_data['region'] == region]
-
- # 筛选月份
- if months:
- filtered_data = filtered_data[filtered_data['month'].isin(months)]
-
- return [filtered_data[['month', 'region', 'sales', 'profit', 'customers']].head(10).to_dict('records')]
-
- # 添加回调函数 - 更新图表
- @dashboard.add_callback(
- outputs=[
- ('monthly-sales-chart', 'figure'),
- ('region-sales-chart', 'figure'),
- ('customer-sales-chart', 'figure')
- ],
- inputs=[('region-filter', 'value'), ('month-filter', 'value')]
- )
- def update_charts(region, months):
- filtered_data = data.copy()
-
- # 筛选区域
- if region != 'All':
- filtered_data = filtered_data[filtered_data['region'] == region]
-
- # 筛选月份
- if months:
- filtered_data = filtered_data[filtered_data['month'].isin(months)]
-
- # 更新月度销售额图表
- monthly_sales = filtered_data.groupby('month')['sales'].sum().reset_index()
- monthly_sales['month'] = pd.Categorical(monthly_sales['month'],
- categories=['Jan', 'Feb', 'Mar', 'Apr'],
- ordered=True)
- monthly_sales = monthly_sales.sort_values('month')
-
- bar_fig = visualizer.plot_bar_chart(
- data=monthly_sales,
- x='month',
- y='sales',
- title='Monthly Sales',
- xlabel='Month',
- ylabel='Total Sales'
- )
-
- # 更新区域销售额分布图表
- region_sales = filtered_data.groupby('region')['sales'].sum().reset_index()
- pie_fig = visualizer.plot_pie_chart(
- data=region_sales,
- values='sales',
- names='region',
- title='Sales Distribution by Region'
- )
-
- # 更新客户数量与销售额关系图表
- scatter_fig = visualizer.plot_scatter(
- data=filtered_data,
- x='customers',
- y='sales',
- title='Relationship between Number of Customers and Sales',
- xlabel='Number of Customers',
- ylabel='Sales',
- color='region',
- size='profit'
- )
-
- return [bar_fig, pie_fig, scatter_fig]
-
- # 运行仪表盘
- print("启动交互式仪表盘,请访问 http://127.0.0.1:8050/")
- dashboard.run_server(debug=True)
- if __name__ == "__main__":
- interactive_dashboard_example()
# Visualization module integration
class VisualizationManager:
    """Facade unifying static (Matplotlib) and interactive (Plotly) charting.

    Delegates chart creation to a ``StaticVisualizer`` and an
    ``InteractiveVisualizer`` (defined elsewhere in this file), assembles
    Dash dashboards via ``DashboardBuilder``, and exports rendered figures
    to disk.
    """

    def __init__(self, output_dir='visualizations'):
        """Initialize the manager.

        Args:
            output_dir: Root directory for all generated artifacts; static
                and interactive outputs go to separate subdirectories.
        """
        self.static_visualizer = StaticVisualizer(output_dir=os.path.join(output_dir, 'static'))
        self.interactive_visualizer = InteractiveVisualizer(output_dir=os.path.join(output_dir, 'interactive'))
        self.output_dir = output_dir
        self.logger = logging.getLogger(__name__)

        # Make sure the root output directory exists up front.
        os.makedirs(output_dir, exist_ok=True)

    def create_visualization(self, data, chart_type, static=True, interactive=True, **kwargs):
        """Create a chart with one or both rendering back-ends.

        Args:
            data: Input data (typically a pandas DataFrame).
            chart_type: Chart kind ('bar', 'line', 'scatter', 'pie', 'box',
                'heatmap', ...); resolved to a ``plot_<chart_type>`` method
                on each visualizer.
            static: Whether to render a static (Matplotlib) figure.
            interactive: Whether to render an interactive (Plotly) figure.
            **kwargs: Forwarded unchanged to the underlying plot method.

        Returns:
            Dict with optional 'static' and 'interactive' figure entries.
            A back-end that does not implement ``plot_<chart_type>`` is
            silently skipped; on any exception an empty dict is returned
            (the error is logged).
        """
        try:
            result = {}

            # Dispatch by naming convention: chart type -> plot_<type> method.
            method_name = f"plot_{chart_type}"

            if static and hasattr(self.static_visualizer, method_name):
                static_method = getattr(self.static_visualizer, method_name)
                result['static'] = static_method(data=data, **kwargs)

            if interactive and hasattr(self.interactive_visualizer, method_name):
                interactive_method = getattr(self.interactive_visualizer, method_name)
                result['interactive'] = interactive_method(data=data, **kwargs)

            return result

        except Exception as e:
            self.logger.error(f"创建可视化图表时出错: {e}")
            return {}

    def create_dashboard(self, data, config, title="数据可视化仪表盘"):
        """Create an interactive dashboard from a component configuration.

        Args:
            data: Default input data for graph components that do not carry
                their own 'data' entry.
            config: List of component config dicts; 'graph' entries are
                rendered via :meth:`create_visualization`, while 'control',
                'table', 'text' and 'html' entries are passed through as-is.
            title: Dashboard title.

        Returns:
            A ``DashboardBuilder`` with its layout created, or ``None`` on
            error (the error is logged).
        """
        try:
            dashboard = DashboardBuilder(title=title)
            components = []

            for component_config in config:
                component_type = component_config.get('type')

                if component_type == 'graph':
                    chart_type = component_config.get('chart_type')
                    chart_params = component_config.get('params', {})

                    # Render only the interactive variant; Dash graphs need
                    # Plotly figures, not Matplotlib ones.
                    chart_result = self.create_visualization(
                        data=component_config.get('data', data),
                        chart_type=chart_type,
                        static=False,
                        interactive=True,
                        **chart_params
                    )

                    # A chart that failed to render is dropped from the layout.
                    if 'interactive' in chart_result:
                        components.append({
                            'type': 'graph',
                            'id': component_config.get('id', f"graph-{len(components)}"),
                            'title': component_config.get('title', f"{chart_type.capitalize()} Chart"),
                            'figure': chart_result['interactive'],
                            'width': component_config.get('width', 6),
                            'height': component_config.get('height', 400)
                        })

                elif component_type in ['control', 'table', 'text', 'html']:
                    # Non-graph components need no rendering; pass through.
                    components.append(component_config)

            dashboard.create_layout(components)
            return dashboard

        except Exception as e:
            self.logger.error(f"创建仪表盘时出错: {e}")
            return None

    def export_visualizations(self, visualizations, format='html'):
        """Export rendered figures to ``<output_dir>/exports/...``.

        Args:
            visualizations: Mapping of name -> dict with optional 'static'
                (Matplotlib figure) and 'interactive' (Plotly figure) entries,
                as produced by :meth:`create_visualization`.
            format: Target format for static figures ('html', 'png', 'pdf',
                ...). Interactive figures are always written as HTML.

        Returns:
            List of paths written; empty list on error (the error is logged).
        """
        try:
            export_paths = []

            for name, viz_dict in visualizations.items():
                # Static (Matplotlib) figure.
                if 'static' in viz_dict and viz_dict['static'] is not None:
                    static_path = os.path.join(self.output_dir, 'exports', 'static', f"{name}.{format}")
                    os.makedirs(os.path.dirname(static_path), exist_ok=True)

                    if format == 'html':
                        # Matplotlib cannot emit HTML directly: save a PNG,
                        # then wrap it in a minimal HTML page referencing it
                        # by basename (both files live in the same directory).
                        temp_path = os.path.join(self.output_dir, 'exports', 'static', f"{name}.png")
                        viz_dict['static'].savefig(temp_path)

                        # BUGFIX: explicit UTF-8 — the platform default
                        # encoding can corrupt or reject non-ASCII names.
                        with open(static_path, 'w', encoding='utf-8') as f:
                            f.write(f"""
                            <html>
                            <head><title>{name} - Static Visualization</title></head>
                            <body>
                            <h1>{name}</h1>
                            <img src="{os.path.basename(temp_path)}" alt="{name}">
                            </body>
                            </html>
                            """)
                    else:
                        viz_dict['static'].savefig(static_path)

                    export_paths.append(static_path)

                # Interactive (Plotly) figure — always exported as HTML.
                if 'interactive' in viz_dict and viz_dict['interactive'] is not None:
                    interactive_path = os.path.join(self.output_dir, 'exports', 'interactive', f"{name}.html")
                    os.makedirs(os.path.dirname(interactive_path), exist_ok=True)

                    viz_dict['interactive'].write_html(interactive_path)
                    export_paths.append(interactive_path)

            return export_paths

        except Exception as e:
            self.logger.error(f"导出可视化图表时出错: {e}")
            return []
# Usage example
def visualization_manager_example():
    """End-to-end demo of ``VisualizationManager``.

    Builds a synthetic sales dataset, renders bar/scatter/heatmap charts
    in both static and interactive form, exports them to disk, and
    assembles an interactive dashboard.

    Returns:
        Dict with the generated data, the chart objects, the manager,
        and the dashboard (or None if dashboard creation failed).
    """
    # Reproducible synthetic data: 3 features plus a noisy nonlinear target.
    np.random.seed(42)
    n_samples = 200
    features = np.random.randn(n_samples, 3)
    target = (
        2 * features[:, 0]
        + features[:, 1] ** 2
        + 0.5 * features[:, 0] * features[:, 2]
        + np.random.randn(n_samples) * 0.5
    )

    frame = pd.DataFrame(features, columns=['feature_1', 'feature_2', 'feature_3'])
    frame['target'] = target

    # Derived categorical/business columns. NOTE: the random draws are kept
    # in this exact order so the seeded sequence stays reproducible.
    frame['month'] = np.random.choice(['Jan', 'Feb', 'Mar', 'Apr'], n_samples)
    frame['day_of_week'] = np.random.choice(['Mon', 'Tue', 'Wed', 'Thu', 'Fri'], n_samples)
    frame['region'] = np.random.choice(['North', 'South', 'East', 'West'], n_samples)
    frame['sales'] = frame['target'] * 100 + 500
    frame['profit'] = frame['sales'] * np.random.uniform(0.1, 0.3, n_samples)
    frame['customers'] = np.random.poisson(50, n_samples)
    frame['sales_per_customer'] = frame['sales'] / frame['customers']

    viz_manager = VisualizationManager(output_dir='visualizations')
    visualizations = {}

    # 1. Bar chart: total sales per month, in calendar order.
    monthly_sales = frame.groupby('month')['sales'].sum().reset_index()
    monthly_sales['month'] = pd.Categorical(
        monthly_sales['month'],
        categories=['Jan', 'Feb', 'Mar', 'Apr'],
        ordered=True,
    )
    monthly_sales = monthly_sales.sort_values('month')
    visualizations['monthly_sales'] = viz_manager.create_visualization(
        data=monthly_sales,
        chart_type='bar_chart',
        x='month',
        y='sales',
        title='Monthly Sales',
        xlabel='Month',
        ylabel='Total Sales',
        save_as='monthly_sales',
    )

    # 2. Scatter chart: customers vs. sales, colored by region, sized by profit.
    visualizations['customers_sales'] = viz_manager.create_visualization(
        data=frame,
        chart_type='scatter',
        x='customers',
        y='sales',
        title='Relationship between Number of Customers and Sales',
        xlabel='Number of Customers',
        ylabel='Sales',
        color='region',
        size='profit',
        save_as='customers_sales',
    )

    # 3. Heatmap of pairwise correlations between the numeric columns.
    correlation_matrix = frame[
        ['feature_1', 'feature_2', 'feature_3', 'sales', 'profit', 'customers']
    ].corr()
    visualizations['correlation_matrix'] = viz_manager.create_visualization(
        data=correlation_matrix,
        chart_type='heatmap',
        title='Correlation Matrix',
        save_as='correlation_matrix',
    )

    # Export everything that was rendered.
    export_paths = viz_manager.export_visualizations(visualizations)
    print(f"导出的可视化图表: {export_paths}")

    # Dashboard layout: one region filter plus the three charts above.
    dashboard_config = [
        {
            'type': 'control',
            'id': 'region-filter',
            'title': '区域筛选',
            'control_type': 'dropdown',
            'options': ['All'] + list(frame['region'].unique()),
            'value': 'All',
            'width': 3,
        },
        {
            'type': 'graph',
            'id': 'monthly-sales-chart',
            'title': '月度销售额',
            'chart_type': 'bar_chart',
            'data': monthly_sales,
            'params': {
                'x': 'month',
                'y': 'sales',
                'title': 'Monthly Sales',
                'xlabel': 'Month',
                'ylabel': 'Total Sales',
            },
            'width': 6,
        },
        {
            'type': 'graph',
            'id': 'customer-sales-chart',
            'title': '客户数量与销售额关系',
            'chart_type': 'scatter',
            'params': {
                'x': 'customers',
                'y': 'sales',
                'title': 'Relationship between Number of Customers and Sales',
                'xlabel': 'Number of Customers',
                'ylabel': 'Sales',
                'color': 'region',
                'size': 'profit',
            },
            'width': 6,
        },
        {
            'type': 'graph',
            'id': 'correlation-matrix',
            'title': '相关性矩阵',
            'chart_type': 'heatmap',
            'data': correlation_matrix,
            'params': {'title': 'Correlation Matrix'},
            'width': 12,
        },
    ]

    dashboard = viz_manager.create_dashboard(frame, dashboard_config, title="销售数据分析仪表盘")
    if dashboard:
        print("创建仪表盘成功,运行 dashboard.run_server() 启动仪表盘")

    return {
        'data': frame,
        'visualizations': visualizations,
        'viz_manager': viz_manager,
        'dashboard': dashboard,
    }
# Script entry point: run the visualization-manager walkthrough.
if __name__ == "__main__":
    visualization_manager_example()
复制代码

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。