By fanbingbing, 31 May, 2025
# Batch box-plot generator for phenotype "area" measurements.
#
# For every Excel workbook in `file_path`, draw one box + strip plot per sheet
# (x = `group`, taken from the forward-filled `category` column; y = `area`),
# annotate each pair of groups with a two-sided Mann-Whitney U test, and save
# the grid of subplots as `<workbook name>.pdf` in `fig_path`.
import os
from itertools import combinations

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from scipy.stats import mannwhitneyu

sns.set_style("whitegrid")
sns.set_palette("Set2")
# Fallback font list so CJK sheet names render; ASCII minus is kept because
# CJK fonts often lack the Unicode minus glyph.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False

file_path = r"E:\表型图片\5.16"
fig_path = r"E:\表型图片\改进"

MAX_COLS = 3  # at most three subplots per row


def _stars_label(p_value):
    """Return (marker, text) for the 'stars' style: ***, **, *, or ns plus a p-value string."""
    if p_value < 0.001:
        return '***', 'p < 0.001'
    if p_value < 0.01:
        marker = '**'
    elif p_value < 0.05:
        marker = '*'
    else:
        marker = 'ns'  # not significant, but the p-value is still shown
    return marker, f'p = {p_value:.3f}'


def _pvalue_text(p_value):
    """Return the 'pvalue' style text: scientific notation below 1e-4, else 4 decimals."""
    if p_value < 0.0001:
        return f"p = {p_value:.2e}"
    return f"p = {p_value:.4f}"


def add_stat_annotation(ax, data, x_col, y_col, *, order=None, style='stars'):
    """Draw pairwise two-sided Mann-Whitney U brackets above a categorical plot.

    Parameters
    ----------
    ax : matplotlib Axes on which the box/strip plot was drawn.
    data : DataFrame holding the plotted values.
    x_col, y_col : column names of the group (x) and value (y) axes.
    order : group labels in the left-to-right order they were drawn.
        Defaults to the non-null groups in order of first appearance, which
        matches seaborn's default only for non-numeric groups. Pass the same
        ``order`` given to seaborn so brackets line up with the boxes.
    style : ``'stars'`` (default) puts a bold ***/**/*/ns marker above each
        bracket and a 3-decimal p-value below it. ``'pvalue'`` puts only the
        p-value below the bracket (scientific notation under 1e-4).

    Pairs where either group has no non-NaN values, or where scipy rejects
    the input, are skipped. The y-axis top is raised as needed so brackets
    stay visible.

    Raises
    ------
    ValueError
        If ``style`` is not ``'stars'`` or ``'pvalue'``.
    """
    if style not in ('stars', 'pvalue'):
        raise ValueError(f"unknown style: {style!r}")
    if order is None:
        order = data[x_col].dropna().unique()
    groups = list(order)
    if len(groups) < 2:
        return
    # Category k is drawn at x = k by seaborn when `order` is given.
    positions = {group: pos for pos, group in enumerate(groups)}

    y_min, y_max = ax.get_ylim()
    y_range = y_max - y_min
    h = 0.02 * y_range  # height of the bracket's downward ticks
    font_size = 9
    step = 0.15 if style == 'stars' else 0.12  # vertical spacing between brackets

    for i, (group1, group2) in enumerate(combinations(groups, 2)):
        data1 = data.loc[data[x_col] == group1, y_col].dropna()
        data2 = data.loc[data[x_col] == group2, y_col].dropna()
        if data1.empty or data2.empty:
            continue
        try:
            _, p_value = mannwhitneyu(data1, data2, alternative='two-sided')
        except (ValueError, TypeError):
            continue

        x1, x2 = positions[group1], positions[group2]
        x_mid = (x1 + x2) * 0.5
        y_pos = y_max + (0.05 + step * i) * y_range
        ax.plot([x1, x1, x2, x2], [y_pos - h, y_pos, y_pos, y_pos - h],
                lw=1.5, color='black')

        if style == 'stars':
            marker, p_text = _stars_label(p_value)
            ax.text(x_mid, y_pos + h * 0.5, marker, ha='center', va='bottom',
                    color='black', fontsize=font_size, fontweight='bold')
            ax.text(x_mid, y_pos - h * 2, p_text, ha='center', va='top',
                    color='black', fontsize=font_size - 1,
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1))
            needed_top, new_top = y_pos + h * 3, y_pos + h * 4
        else:
            ax.text(x_mid, y_pos - h * 1.5, _pvalue_text(p_value), ha='center',
                    va='top', color='black', fontsize=font_size)
            needed_top, new_top = y_pos + h, y_pos + h * 2

        if needed_top > ax.get_ylim()[1]:
            ax.set_ylim(top=new_top)


def _group_order(series):
    """Return non-null groups in the order seaborn draws them (sorted if numeric, else first appearance)."""
    order = series.dropna().unique().tolist()
    if pd.api.types.is_numeric_dtype(series):
        order.sort()
    return order


def _load_sheet(excel_file, sheet):
    """Read one sheet, drop rows without a 'name', and forward-fill 'category' into 'group'."""
    data = excel_file.parse(sheet_name=sheet)
    # copy() so the new column is set on an independent frame, not a view.
    data = data.dropna(subset=['name']).copy()
    # Only the first row of each group carries a category label; ffill propagates it.
    data['group'] = data['category'].ffill()
    return data


def _plot_sheet(ax, data, title):
    """Draw the box + strip plot for one sheet and annotate pairwise tests."""
    order = _group_order(data['group'])
    sns.boxplot(data=data, x='group', y='area', order=order, ax=ax,
                width=0.6, linewidth=1.5, fliersize=4)
    sns.stripplot(data=data, x='group', y='area', order=order, ax=ax,
                  color='black', alpha=0.5, size=4, jitter=0.2)
    add_stat_annotation(ax, data, 'group', 'area', order=order)

    ax.set_title(f"{title}", fontsize=12, pad=10, fontweight='bold')
    ax.set_xlabel("Group", fontsize=10, labelpad=8)
    ax.set_ylabel("Area", fontsize=10, labelpad=8)
    ax.tick_params(axis='both', which='major', labelsize=9)

    # Leave headroom of at least 30% above the largest value.
    current_ymin, current_ymax = ax.get_ylim()
    new_ymax = max(current_ymax, data['area'].max() * 1.3)
    ax.set_ylim(current_ymin, new_ymax)


def _add_legend(ax):
    """Add the explanatory legend (points, median line, IQR box)."""
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label='Data Points',
               markerfacecolor='black', markersize=6, alpha=0.5),
        Line2D([0], [0], color='black', lw=2, label='Median'),
        # The box itself spans Q1-Q3, so an outlined patch is the matching glyph.
        Patch(facecolor='none', edgecolor='black', label='IQR'),
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=8)


def _plot_workbook(ipath, name, output_path):
    """Plot every sheet of one workbook into a grid and save it to output_path."""
    with pd.ExcelFile(ipath) as excel_file:
        sheet_names = excel_file.sheet_names
        num_sheets = len(sheet_names)
        if num_sheets == 0:
            return
        cols = min(MAX_COLS, num_sheets)
        rows = (num_sheets + cols - 1) // cols  # ceiling division
        # squeeze=False always yields a 2-D array, so flatten() works for any grid.
        fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows), squeeze=False)
        axes = axes.flatten()
        try:
            for idx, sheet in enumerate(sheet_names):
                ax = axes[idx]
                try:
                    data = _load_sheet(excel_file, sheet)
                    _plot_sheet(ax, data, sheet)
                    if idx == 0:
                        _add_legend(ax)
                except Exception as e:  # best effort: report and hide this subplot
                    print(f"Error processing sheet {sheet} in file {name}: {str(e)}")
                    ax.axis('off')
            for ax in axes[num_sheets:]:  # unused grid cells
                ax.axis('off')
            fig.tight_layout(pad=3.0, h_pad=2.0, w_pad=2.0)
            fig.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)


def main():
    """Render one PDF per Excel workbook found in file_path into fig_path."""
    os.makedirs(fig_path, exist_ok=True)
    # Skip Excel's "~$" lock files, which appear while a workbook is open.
    files = [f for f in os.listdir(file_path)
             if f.endswith(('.xlsx', '.xls')) and not f.startswith('~$')]
    for name in files:
        ipath = os.path.join(file_path, name)
        output_path = os.path.join(fig_path, f"{os.path.splitext(name)[0]}.pdf")
        try:
            _plot_workbook(ipath, name, output_path)
        except Exception as e:  # one unreadable workbook should not stop the batch
            print(f"Error processing file {name}: {e}")
    print(f"所有图表已保存至: {fig_path}")


if __name__ == "__main__":
    main()