# Forums
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import mannwhitneyu
from itertools import combinations
import matplotlib
# Global plot styling: seaborn "whitegrid" theme with the "Set2" colour palette.
sns.set_style("whitegrid")
sns.set_palette("Set2")

# Sans-serif font fallback chain; presumably chosen so that non-ASCII (CJK)
# text such as sheet names used as titles can be rendered -- TODO confirm the
# fonts are installed on the target machine.
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'Microsoft YaHei']
# Use the ASCII hyphen for negative tick labels instead of the Unicode minus.
plt.rcParams['axes.unicode_minus'] = False

# Directory that is scanned for Excel workbooks.
file_path = r"E:\表型图片\5.16"
# Directory that receives one PDF per workbook.
fig_path = r"E:\表型图片\改进"

# exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
# if the directory appeared between the check and the creation.
os.makedirs(fig_path, exist_ok=True)
def _significance_label(p_value):
    """Return ``(symbol, text)`` for a p-value.

    The symbol is ``'***'`` / ``'**'`` / ``'*'`` for p below 0.001 / 0.01 /
    0.05 and ``'ns'`` otherwise.  The text is ``'p < 0.001'`` for the smallest
    tier and the p-value rounded to three decimals for every other tier, so a
    non-significant result still shows its p-value.
    """
    if p_value < 0.001:
        return '***', 'p < 0.001'
    p_text = f'p = {p_value:.3f}'
    if p_value < 0.01:
        return '**', p_text
    if p_value < 0.05:
        return '*', p_text
    return 'ns', p_text


def add_stat_annotation(ax, data, x_col, y_col):
    """Annotate *ax* with pairwise two-sided Mann-Whitney U test results.

    For every pair of groups in ``data[x_col]`` a bracket is drawn above the
    current y-range, with a significance symbol over it and the p-value text
    under it.  The upper y-limit of *ax* is raised when a bracket would not
    fit.

    Args:
        ax: Axes-like object; ``get_ylim``, ``set_ylim``, ``plot`` and ``text``
            are the only methods used.
        data: DataFrame holding the grouping and value columns.
        x_col: Name of the grouping column.
        y_col: Name of the numeric value column.

    Returns:
        None.  Nothing is drawn when fewer than two groups are present.

    Note:
        Bracket x positions are the groups' first-appearance indices.  This
        assumes the plot on *ax* places the categories in that same order --
        verify when the grouping column is numeric or categorical.
    """
    # Missing group labels are dropped up front: a NaN entry would never match
    # any row, yet it would still shift the x index of every later group.
    groups = data[x_col].dropna().unique().tolist()
    if len(groups) < 2:
        return

    # One sample per group, without NaN measurements so that a single missing
    # value cannot turn the p-value into NaN.
    samples = [data.loc[data[x_col] == g, y_col].dropna() for g in groups]

    y_min, y_max = ax.get_ylim()
    y_range = y_max - y_min
    h = 0.02 * y_range  # height of the bracket ticks
    font_size = 9

    level = 0  # vertical slot; only advanced for brackets that are drawn
    for x1, x2 in combinations(range(len(groups)), 2):
        sample1, sample2 = samples[x1], samples[x2]
        if sample1.empty or sample2.empty:
            continue
        try:
            _, p_value = mannwhitneyu(sample1, sample2, alternative='two-sided')
        except (ValueError, TypeError):
            # Best effort: a pair that cannot be tested is left unannotated.
            continue

        sig_symbol, p_text = _significance_label(p_value)

        y_pos = y_max + (0.05 + 0.15 * level) * y_range
        level += 1

        # Bracket connecting the two group positions.
        ax.plot([x1, x1, x2, x2], [y_pos - h, y_pos, y_pos, y_pos - h],
                lw=1.5, color='black')

        x_mid = (x1 + x2) * 0.5
        # Significance symbol above the bracket.
        ax.text(x_mid, y_pos + h * 0.5, sig_symbol,
                ha='center', va='bottom',
                color='black', fontsize=font_size, fontweight='bold')
        # p-value below the bracket, on a translucent white box.
        ax.text(x_mid, y_pos - h * 2, p_text,
                ha='center', va='top',
                color='black', fontsize=font_size - 1,
                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=1))

        # Grow the axis so the annotation stays inside the plotting area.
        if y_pos + h * 3 > ax.get_ylim()[1]:
            ax.set_ylim(top=y_pos + h * 4)
# ---------------------------------------------------------------------------
# Batch driver: for every Excel workbook in ``file_path`` draw one box plot
# per sheet (at most three per row) and save the figure as a PDF named after
# the workbook in ``fig_path``.
# Each sheet is read for the columns 'name', 'category' and 'area'.
# ---------------------------------------------------------------------------
from matplotlib.lines import Line2D

files = [f for f in os.listdir(file_path) if f.endswith(('.xlsx', '.xls'))]
for file_name in files:
    workbook_path = os.path.join(file_path, file_name)

    # Open the workbook once and parse every sheet from the same handle; the
    # context manager closes the file even if plotting fails.
    with pd.ExcelFile(workbook_path) as excel_file:
        sheet_names = excel_file.sheet_names
        num_sheets = len(sheet_names)
        cols = min(3, num_sheets)               # at most 3 subplots per row
        rows = (num_sheets + cols - 1) // cols  # ceiling division

        # squeeze=False always yields a 2-D array, so the single-sheet case
        # needs no special handling.
        fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows),
                                 squeeze=False)
        axes = axes.flatten()

        for idx, sheet in enumerate(sheet_names):
            ax = axes[idx]
            try:
                data = excel_file.parse(sheet_name=sheet)
                # Drop rows without a name, then forward-fill the category so
                # each remaining row carries the label of the last row that
                # had one.
                data = data.dropna(subset=['name'])
                data = data.assign(group=data['category'].ffill())

                sns.boxplot(data=data, x='group', y='area', ax=ax, width=0.6,
                            linewidth=1.5, fliersize=4)
                sns.stripplot(data=data, x='group', y='area', ax=ax,
                              color='black', alpha=0.5, size=4, jitter=0.2)
                add_stat_annotation(ax, data, 'group', 'area')

                ax.set_title(f"{sheet}", fontsize=12, pad=10, fontweight='bold')
                ax.set_xlabel("Group", fontsize=10, labelpad=8)
                ax.set_ylabel("Area", fontsize=10, labelpad=8)
                ax.tick_params(axis='both', which='major', labelsize=9)

                # Keep at least 30% headroom above the largest observation.
                current_ymin, current_ymax = ax.get_ylim()
                data_max = data['area'].max()
                new_ymax = max(current_ymax, data_max * 1.3)
                ax.set_ylim(current_ymin, new_ymax)

                # The legend is only drawn on the first subplot of a figure.
                if idx == 0:
                    legend_elements = [
                        Line2D([0], [0], marker='o', color='w', label='Data Points',
                               markerfacecolor='black', markersize=6, alpha=0.5),
                        Line2D([0], [0], color='black', lw=2, label='Median'),
                        Line2D([0], [0], color='black', lw=1, linestyle='--', label='IQR'),
                    ]
                    ax.legend(handles=legend_elements, loc='upper right', fontsize=8)
            except Exception as e:
                # Best effort per sheet: report the problem, blank the subplot
                # and carry on with the remaining sheets.
                print(f"Error processing sheet {sheet} in file {file_name}: {e}")
                ax.axis('off')

    # Hide the unused cells of the subplot grid.
    for ax in axes[num_sheets:]:
        ax.axis('off')

    fig.tight_layout(pad=3.0, h_pad=2.0, w_pad=2.0)
    output_path = os.path.join(fig_path, f"{os.path.splitext(file_name)[0]}.pdf")
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    # Close this specific figure so figures do not accumulate across files.
    plt.close(fig)
def add_stat_annotation(ax, data, x_col, y_col):
    """Annotate *ax* with pairwise Mann-Whitney U p-values.

    For every pair of groups in ``data[x_col]`` a bracket is drawn above the
    current y-range with the p-value of a two-sided Mann-Whitney U test
    printed underneath it.  p-values below 0.0001 are shown in scientific
    notation with two decimals, all others with four decimals.  The upper
    y-limit of *ax* is raised when a bracket would not fit.

    Args:
        ax: Axes-like object; ``get_ylim``, ``set_ylim``, ``plot`` and ``text``
            are the only methods used.
        data: DataFrame holding the grouping and value columns.
        x_col: Name of the grouping column.
        y_col: Name of the numeric value column.

    Returns:
        None.  Nothing is drawn when fewer than two groups are present.

    Note:
        Bracket x positions are the groups' first-appearance indices, which
        assumes the plot on *ax* places the categories in that same order.

    NOTE(review): this definition reuses the name of an earlier function in
    this file and appears after the loop that calls it -- confirm which
    variant is meant to be used.
    """
    # Missing group labels are dropped up front: a NaN entry would never match
    # any row, yet it would still shift the x index of every later group.
    groups = data[x_col].dropna().unique().tolist()
    if len(groups) < 2:
        return

    # One sample per group, without NaN measurements so that a single missing
    # value cannot turn the p-value into NaN.
    samples = [data.loc[data[x_col] == g, y_col].dropna() for g in groups]

    # The y-range at entry determines where the annotations are stacked.
    y_min, y_max = ax.get_ylim()
    y_range = y_max - y_min
    h = 0.02 * y_range  # height of the bracket ticks
    font_size = 9

    level = 0  # vertical slot; only advanced for brackets that are drawn
    # Every possible pairwise comparison, identified by x position.
    for x1, x2 in combinations(range(len(groups)), 2):
        sample1, sample2 = samples[x1], samples[x2]
        if sample1.empty or sample2.empty:
            continue
        try:
            _, p_value = mannwhitneyu(sample1, sample2, alternative='two-sided')
        except (ValueError, TypeError):
            # Best effort: a pair that cannot be tested is left unannotated.
            continue

        y_pos = y_max + (0.05 + 0.12 * level) * y_range
        level += 1

        # Bracket connecting the two group positions.
        ax.plot([x1, x1, x2, x2], [y_pos - h, y_pos, y_pos, y_pos - h],
                lw=1.5, color='black')

        # Very small p-values are shown in scientific notation.
        if p_value < 0.0001:
            p_text = f"p = {p_value:.2e}"
        else:
            p_text = f"p = {p_value:.4f}"

        # p-value text sits just below the bracket.
        ax.text((x1 + x2) * 0.5, y_pos - h * 1.5, p_text,
                ha='center', va='top',
                color='black', fontsize=font_size)

        # Grow the axis so the annotation stays inside the plotting area.
        if y_pos + h > ax.get_ylim()[1]:
            ax.set_ylim(top=y_pos + h * 2)
# Final status message (Chinese for "All figures have been saved to: <fig_path>").
# It is printed unconditionally and only names the output directory.
print(f"所有图表已保存至: {fig_path}")