训练模型所需的图像数据集规模往往较大,但受各种现实因素影响,我们可用于标注的图像可能不够。因此,可以对一张图片进行旋转、平移、调整亮度等处理,生成多张新图片来扩充数据集。
下面的代码会为每张图片随机挑选几种图像处理方式组合使用,以达到增强数据集的效果;如果只想使用特定的几种处理方法,可以修改 ImageAugmentor 中的 augmentations 列表,只保留需要的操作。
代码:
import os
import cv2
import numpy as np
import random
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
from PIL import Image, ImageTk, ImageEnhance, ImageFilter, ImageOps
import matplotlib.pyplot as plt
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import threading
import sys
import traceback
class ImageAugmentorApp:
    """Tkinter GUI that drives ImageAugmentor over a directory of images.

    The user picks an input directory (searched recursively), an output
    directory, the number of augmented copies per image, an optional target
    size and the number of worker processes. The preview button shows four
    random augmentations of the first image found; the start button processes
    every image in a background thread while the main thread keeps the UI
    responsive.
    """

    def __init__(self, root):
        """Build the window on *root* and pre-fill the default paths.

        Args:
            root: the Tk root window hosting the application.
        """
        self.root = root
        self.root.title("图像增强工具")
        self.root.geometry("900x700")
        self.root.configure(bg='#f0f0f0')
        # Default settings.
        self.input_path = "F:/桌面/daibiaozhu"  # default input directory
        self.output_path = "F:/桌面/zengqiang"  # default output directory
        self.num_augmentations = tk.IntVar(value=5)
        self.target_width = tk.IntVar(value=0)
        self.target_height = tk.IntVar(value=0)
        self.preserve_ratio = tk.BooleanVar(value=True)
        self.max_workers = tk.IntVar(value=cpu_count())
        # True while a batch is running; prevents starting a second one.
        self._running = False
        # Worker-thread -> UI event queue, created per batch in start_augmentation.
        self._events = None
        # Tk only displays PhotoImages that are still referenced from Python.
        self._preview_refs = []
        self.create_widgets()
        # Built from the current settings right before previewing/processing.
        # Not constructed here: ImageAugmentor() creates its output directory
        # immediately, which used to leave a stray folder in the CWD.
        self.augmentor = None
        # Pre-fill the path entries.
        self.input_entry.delete(0, tk.END)
        self.input_entry.insert(0, self.input_path)
        self.output_entry.delete(0, tk.END)
        self.output_entry.insert(0, self.output_path)

    def create_widgets(self):
        """Create the settings form, buttons, preview canvas, progress area and status bar."""
        main_frame = tk.Frame(self.root, bg='#f0f0f0')
        main_frame.pack(fill=tk.BOTH, expand=True, padx=20, pady=20)

        # Input/output settings.
        settings_frame = tk.LabelFrame(main_frame, text="设置", bg='#e0e0e0', padx=10, pady=10)
        settings_frame.pack(fill=tk.X, pady=10)

        # Input directory.
        tk.Label(settings_frame, text="输入路径:", bg='#e0e0e0').grid(
            row=0, column=0, padx=5, pady=5, sticky='e')
        self.input_entry = tk.Entry(settings_frame, width=50)
        self.input_entry.grid(row=0, column=1, padx=5, pady=5, sticky='we')
        tk.Button(settings_frame, text="浏览...", command=self.select_input_path).grid(
            row=0, column=2, padx=5, pady=5)

        # Output directory.
        tk.Label(settings_frame, text="输出路径:", bg='#e0e0e0').grid(
            row=1, column=0, padx=5, pady=5, sticky='e')
        self.output_entry = tk.Entry(settings_frame, width=50)
        self.output_entry.grid(row=1, column=1, padx=5, pady=5, sticky='we')
        tk.Button(settings_frame, text="浏览...", command=self.select_output_path).grid(
            row=1, column=2, padx=5, pady=5)

        # Augmented copies per image.
        tk.Label(settings_frame, text="每张图增强数量:", bg='#e0e0e0').grid(
            row=2, column=0, padx=5, pady=5, sticky='e')
        tk.Spinbox(settings_frame, from_=1, to=50, textvariable=self.num_augmentations, width=5).grid(
            row=2, column=1, padx=5, pady=5, sticky='w')

        # Target size (0 x 0 keeps the original size). The width/height widgets
        # live in their own frame so each label sits next to its entry.
        tk.Label(settings_frame, text="目标尺寸:", bg='#e0e0e0').grid(
            row=3, column=0, padx=5, pady=5, sticky='e')
        size_frame = tk.Frame(settings_frame, bg='#e0e0e0')
        size_frame.grid(row=3, column=1, padx=5, pady=5, sticky='w')
        tk.Label(size_frame, text="宽:", bg='#e0e0e0').pack(side=tk.LEFT)
        tk.Entry(size_frame, textvariable=self.target_width, width=5).pack(side=tk.LEFT, padx=(0, 10))
        tk.Label(size_frame, text="高:", bg='#e0e0e0').pack(side=tk.LEFT)
        tk.Entry(size_frame, textvariable=self.target_height, width=5).pack(side=tk.LEFT)
        tk.Checkbutton(settings_frame, text="保持宽高比", variable=self.preserve_ratio, bg='#e0e0e0').grid(
            row=3, column=2, padx=5, pady=5)

        # Number of worker processes.
        tk.Label(settings_frame, text="并行进程数:", bg='#e0e0e0').grid(
            row=4, column=0, padx=5, pady=5, sticky='e')
        tk.Spinbox(settings_frame, from_=1, to=cpu_count(), textvariable=self.max_workers, width=5).grid(
            row=4, column=1, padx=5, pady=5, sticky='w')

        # Action buttons.
        button_frame = tk.Frame(main_frame, bg='#f0f0f0')
        button_frame.pack(fill=tk.X, pady=10)
        tk.Button(button_frame, text="预览增强效果", command=self.preview_augmentations,
                  bg="#2196F3", fg="white", font=("Arial", 10, "bold")).pack(side=tk.LEFT, padx=5)
        # Kept as an attribute so it can be disabled while a batch runs.
        self.start_button = tk.Button(button_frame, text="开始增强", command=self.start_augmentation,
                                      bg="#4CAF50", fg="white", font=("Arial", 10, "bold"))
        self.start_button.pack(side=tk.LEFT, padx=5)
        tk.Button(button_frame, text="打开输出目录", command=self.open_output_dir,
                  bg="#FF9800", fg="white", font=("Arial", 10, "bold")).pack(side=tk.LEFT, padx=5)
        tk.Button(button_frame, text="退出", command=self.root.quit,
                  bg="#F44336", fg="white", font=("Arial", 10, "bold")).pack(side=tk.RIGHT, padx=5)

        # Preview area.
        preview_frame = tk.LabelFrame(main_frame, text="预览", bg='white')
        preview_frame.pack(fill=tk.BOTH, expand=True, pady=10)
        self.canvas = tk.Canvas(preview_frame, bg='white')
        self.canvas.pack(fill=tk.BOTH, expand=True)

        # Progress label; the bar itself is only packed while a batch runs.
        self.progress_frame = tk.Frame(main_frame, bg='#f0f0f0')
        self.progress_frame.pack(fill=tk.X, pady=5)
        self.progress_label = tk.Label(self.progress_frame, text="准备就绪", bg='#f0f0f0')
        self.progress_label.pack(fill=tk.X)
        self.progress_bar = ttk.Progressbar(self.progress_frame, orient=tk.HORIZONTAL, length=500,
                                            mode='determinate')

        # Status bar.
        self.status_var = tk.StringVar(value="就绪")
        status_bar = tk.Label(self.root, textvariable=self.status_var, bd=1, relief=tk.SUNKEN, anchor=tk.W,
                              bg='#e0e0e0')
        status_bar.pack(side=tk.BOTTOM, fill=tk.X)
        self.status_var.set(f"输入目录: {self.input_path} | 输出目录: {self.output_path}")

    def select_input_path(self):
        """Ask for the input directory and mirror the choice into the entry and status bar."""
        path = filedialog.askdirectory(title="选择输入目录",
                                       initialdir=self.input_entry.get().strip() or self.input_path)
        if not path:
            return  # dialog cancelled
        self.input_path = path
        self.input_entry.delete(0, tk.END)
        self.input_entry.insert(0, path)
        self.status_var.set(f"输入目录: {path}")

    def select_output_path(self):
        """Ask for the output directory and mirror the choice into the entry and status bar."""
        path = filedialog.askdirectory(title="选择输出目录",
                                       initialdir=self.output_entry.get().strip() or self.output_path)
        if not path:
            return  # dialog cancelled
        self.output_path = path
        self.output_entry.delete(0, tk.END)
        self.output_entry.insert(0, path)
        self.status_var.set(f"输出目录: {path}")

    def _sync_paths_from_entries(self):
        """Copy the entry texts into input_path/output_path.

        Without this, a path typed by hand (rather than chosen through the
        browse dialog) was silently ignored.
        """
        self.input_path = self.input_entry.get().strip()
        self.output_path = self.output_entry.get().strip()

    def open_output_dir(self):
        """Open the output directory in the platform's file manager."""
        self._sync_paths_from_entries()
        if not os.path.isdir(self.output_path):
            messagebox.showinfo("信息", "输出目录尚未创建")
            return
        try:
            if hasattr(os, "startfile"):  # os.startfile only exists on Windows
                os.startfile(self.output_path)
            else:
                import subprocess
                opener = "open" if sys.platform == "darwin" else "xdg-open"
                subprocess.Popen([opener, self.output_path])
        except OSError as e:
            messagebox.showerror("错误", f"无法打开目录: {e}")

    def _collect_input_images(self):
        """Return the image files under input_path, or None after warning the user."""
        if not os.path.isdir(self.input_path):
            messagebox.showwarning("警告", "输入目录不存在")
            return None
        image_files = self.get_image_files(self.input_path)
        if not image_files:
            messagebox.showwarning("警告", "输入目录中没有找到图片")
            return None
        return image_files

    def _build_augmentor(self):
        """Create an ImageAugmentor from the current form values.

        Returns:
            The augmentor, or None (after telling the user why) when a numeric
            field is invalid, only one of width/height is set, or the output
            directory cannot be created.
        """
        try:
            width = self.target_width.get()
            height = self.target_height.get()
            num_augmentations = self.num_augmentations.get()
            max_workers = self.max_workers.get()
        except tk.TclError:
            # IntVar.get() raises TclError when the field holds non-integer text.
            messagebox.showwarning("警告", "请输入有效的整数参数")
            return None
        if width > 0 and height > 0:
            target_size = (width, height)
        elif width <= 0 and height <= 0:
            target_size = None
        else:
            # (width, 0) used to be passed through and made every resize fail.
            messagebox.showwarning("警告", "目标尺寸的宽和高需要同时设置(都为0表示保持原尺寸)")
            return None
        try:
            return ImageAugmentor(
                output_dir=self.output_path,
                num_augmentations=max(1, num_augmentations),
                target_size=target_size,
                preserve_ratio=self.preserve_ratio.get(),
                max_workers=max(1, max_workers),
            )
        except OSError as e:
            # ImageAugmentor creates output_dir eagerly (e.g. a missing drive lands here).
            messagebox.showerror("错误", f"无法创建输出目录: {e}")
            return None

    def preview_augmentations(self):
        """Show the first input image next to four random augmentations of it."""
        self._sync_paths_from_entries()
        image_files = self._collect_input_images()
        if image_files is None:
            return
        augmentor = self._build_augmentor()
        if augmentor is None:
            return
        self.augmentor = augmentor
        self.show_preview(image_files[0])

    def show_preview(self, image_path):
        """Draw *image_path* on the left half of the canvas and four augmentations on the right.

        The image is loaded through self.augmentor.preprocess_image, so the
        preview reflects EXIF orientation and the target size just like a batch
        run. The augmented samples are laid out in a 2x2 grid so nothing overlaps.
        """
        self.canvas.delete("all")
        self._preview_refs = []
        try:
            img = self.augmentor.preprocess_image(image_path)
        except Exception as e:
            messagebox.showerror("错误", f"无法加载图片: {str(e)}")
            return
        canvas_width = self.canvas.winfo_width()
        canvas_height = self.canvas.winfo_height()
        if canvas_width < 10 or canvas_height < 10:
            # The canvas has not been laid out yet; use a nominal size.
            canvas_width = 800
            canvas_height = 400
        half_width = canvas_width // 2
        cell_width = canvas_width // 4
        cell_height = canvas_height // 2

        # Left half: the original, leaving room for the caption above it.
        img.thumbnail((max(1, half_width - 10), max(1, canvas_height - 50)), Image.LANCZOS)
        self.preview_image = ImageTk.PhotoImage(img)
        self.canvas.create_image(half_width // 2, canvas_height // 2 + 15, image=self.preview_image)
        self.canvas.create_text(
            half_width // 2, 20,
            text="原始图片",
            fill="blue", font=("Arial", 10, "bold")
        )

        # Right half: 2x2 grid of augmented samples.
        for i in range(4):
            try:
                augmented_img = self.augmentor.apply_random_augmentations(img.copy())
                augmented_img.thumbnail((max(1, cell_width - 10), max(1, cell_height - 30)), Image.LANCZOS)
                preview_img = ImageTk.PhotoImage(augmented_img)
                self._preview_refs.append(preview_img)
                x_pos = half_width + cell_width // 2 + (i % 2) * cell_width
                cell_top = (i // 2) * cell_height
                self.canvas.create_image(x_pos, cell_top + cell_height // 2 + 10, image=preview_img)
                self.canvas.create_text(
                    x_pos, cell_top + 10,
                    text=f"增强示例 {i + 1}",
                    fill="red", font=("Arial", 10, "bold")
                )
            except Exception as e:
                print(f"生成预览时出错: {str(e)}")

    def start_augmentation(self):
        """Validate the settings and start the batch in a background thread."""
        if self._running:
            return  # a batch is already in progress
        self._sync_paths_from_entries()
        image_files = self._collect_input_images()
        if image_files is None:
            return
        augmentor = self._build_augmentor()  # also creates the output directory
        if augmentor is None:
            return
        self.augmentor = augmentor

        import queue
        self._events = queue.Queue()
        self._running = True
        self.start_button.config(state=tk.DISABLED)
        self.progress_bar['maximum'] = len(image_files)
        self.progress_bar['value'] = 0
        self.progress_bar.pack(fill=tk.X, pady=5)
        self.progress_label.config(text="正在增强图片...")
        self.status_var.set(f"开始增强图片: 共 {len(image_files)} 张")

        threading.Thread(target=self.run_augmentation, args=(image_files,), daemon=True).start()
        self.root.after(100, self._poll_events)

    def _iter_results(self, augmentor, image_files):
        """Yield augmentor.augment_image()'s success flag for each file.

        Uses a process pool (results arrive in completion order) when more than
        one worker is configured and there is more than one file; otherwise the
        files are processed in the calling thread.
        """
        workers = min(augmentor.max_workers, len(image_files))
        if workers > 1:
            with Pool(processes=workers) as pool:
                yield from pool.imap_unordered(augmentor.augment_image, image_files)
        else:
            for image_path in image_files:
                yield augmentor.augment_image(image_path)

    def run_augmentation(self, image_files):
        """Worker-thread body: augment *image_files* and post events to self._events.

        Tk widgets are never touched here (Tkinter is not thread-safe); the
        main thread applies the events in _poll_events.
        """
        augmentor = self.augmentor  # pin it: a preview may replace self.augmentor mid-run
        total_files = len(image_files)
        succeeded = 0
        try:
            results = tqdm(self._iter_results(augmentor, image_files), total=total_files,
                           desc="图像增强", file=sys.stdout, disable=sys.stdout is None)
            for processed, ok in enumerate(results, start=1):
                if ok:
                    succeeded += 1
                self._events.put(("progress", processed, total_files))
        except Exception as e:
            traceback.print_exc()
            self._events.put(("error", str(e)))
        else:
            self._events.put(("done", succeeded, total_files, augmentor.output_dir))

    def _poll_events(self):
        """Main-thread pump: apply queued worker events to the UI, then reschedule."""
        while not self._events.empty():
            event = self._events.get_nowait()
            kind = event[0]
            if kind == "progress":
                _, processed, total_files = event
                self.progress_bar['value'] = processed
                self.status_var.set(f"处理中: {processed}/{total_files} ({processed / total_files * 100:.1f}%)")
            elif kind == "done":
                _, succeeded, total_files, output_dir = event
                self._finish_run()
                failed = total_files - succeeded
                summary = f"完成! 共增强 {succeeded} 张图片"
                if failed:
                    summary += f", {failed} 张失败"
                self.progress_label.config(text=summary)
                self.status_var.set(f"{summary},输出到 {output_dir}")
                messagebox.showinfo(
                    "完成",
                    f"图片增强完成!\n共处理 {total_files} 张图片,成功 {succeeded} 张,失败 {failed} 张\n"
                    f"输出到: {output_dir}")
                return
            elif kind == "error":
                _, message = event
                self._finish_run()
                self.status_var.set(f"错误: {message}")
                messagebox.showerror("错误", f"处理过程中出错: {message}")
                return
        self.root.after(100, self._poll_events)

    def _finish_run(self):
        """Hide the progress bar and re-enable the start button after a batch ends."""
        self.progress_bar.pack_forget()
        self._running = False
        self.start_button.config(state=tk.NORMAL)

    def get_image_files(self, directory):
        """Return every supported image path under *directory* (recursive), sorted."""
        image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.gif')
        return sorted(
            os.path.join(dirpath, name)
            for dirpath, _, names in os.walk(directory)
            for name in names
            if name.lower().endswith(image_extensions)
        )
class ImageAugmentor:
    """Generate randomly augmented copies of images and save them to disk.

    ``self.augmentations`` is the pool of operations. Each entry is a
    ``(method, kwargs)`` pair. The kwargs (usually a ``(low, high)`` range)
    are forwarded unchanged, and the method draws its own random value from
    them. To use only a fixed subset of operations, trim that list.
    """

    def __init__(self, output_dir='augmented_images', num_augmentations=5,
                 target_size=None, preserve_ratio=True, max_workers=None):
        """Store the settings and create *output_dir* if it does not exist.

        Args:
            output_dir: directory the augmented files are written to.
            num_augmentations: augmented copies written per input image.
            target_size: optional ``(width, height)`` applied when loading.
            preserve_ratio: if true, *target_size* is a bounding box
                (``thumbnail``: aspect ratio kept, images only shrink);
                otherwise images are resized to exactly *target_size*.
            max_workers: processes used by augment_images; a falsy value
                means ``cpu_count() - 1`` (at least 1).
        """
        self.output_dir = output_dir
        self.num_augmentations = num_augmentations
        self.target_size = target_size
        self.preserve_ratio = preserve_ratio
        self.max_workers = max_workers if max_workers else max(1, cpu_count() - 1)
        os.makedirs(output_dir, exist_ok=True)
        # (operation, keyword arguments). The arguments are passed through
        # as-is; each operation samples its own random value from its range.
        self.augmentations = [
            (self.random_rotation, {'angle_range': (-45, 45)}),
            (self.random_flip, {}),
            (self.random_brightness, {'factor_range': (0.6, 1.4)}),
            (self.random_contrast, {'factor_range': (0.7, 1.3)}),
            (self.random_color, {'factor_range': (0.8, 1.2)}),
            (self.random_saturation, {'factor_range': (0.7, 1.3)}),
            (self.random_blur, {'max_radius': 2}),
            (self.random_sharpness, {'factor_range': (0.5, 2.0)}),
            (self.random_crop, {'crop_ratio_range': (0.7, 1.0)}),
            (self.random_zoom, {'zoom_range': (0.8, 1.2)}),
            (self.random_scale, {'scale_range': (0.8, 1.2)}),
            (self.random_translation, {'max_translation': 0.2}),
            (self.random_gaussian_noise, {'intensity_range': (5, 30)}),
            (self.random_salt_pepper_noise, {'amount_range': (0.001, 0.01)}),
        ]

    def _to_int_coords(self, *coords):
        """Round each coordinate to the nearest int so PIL receives integer boxes/positions."""
        return tuple(int(round(c)) for c in coords)

    def apply_random_augmentations(self, image):
        """Apply 1-4 randomly chosen, distinct operations to *image* in sequence.

        An operation that raises is reported and skipped; the remaining chosen
        operations still run on the last successfully produced image.

        Returns:
            The augmented image (the input itself if nothing succeeded or the
            operation pool is empty).
        """
        if not self.augmentations:
            return image
        # Never ask random.sample for more items than the pool holds.
        count = random.randint(1, min(4, len(self.augmentations)))
        for aug_func, params in random.sample(self.augmentations, count):
            try:
                # Forward the configured ranges untouched: every operation
                # samples its own value. Pre-sampling them (as this used to do)
                # turned e.g. angle_range into a scalar, and the operation then
                # raised TypeError on angle_range[0].
                image = aug_func(image, **params)
            except Exception as e:
                print(f"应用增强操作 {aug_func.__name__} 时出错: {str(e)}")
        return image

    def preprocess_image(self, image_path):
        """Load *image_path* as an upright RGB image, resized per target_size.

        EXIF orientation (all eight values, including mirrored ones) is applied
        with ImageOps.exif_transpose. Malformed EXIF is ignored rather than
        making the image unusable.

        Returns:
            A new PIL image that no longer depends on the source file.
        """
        with Image.open(image_path) as src:
            try:
                img = ImageOps.exif_transpose(src)
            except Exception:
                # Best effort: broken EXIF data must not reject the image.
                img = src.copy()
            # Convert RGBA/P/L/... to RGB; the operations assume 3 channels.
            if img.mode != 'RGB':
                img = img.convert('RGB')
        if self.target_size:
            if self.preserve_ratio:
                img.thumbnail(self.target_size, Image.LANCZOS)
            else:
                img = img.resize(self.target_size, Image.LANCZOS)
        return img

    def save_image(self, image, original_path, suffix):
        """Save *image* as ``<output_dir>/<name>_<suffix><ext>`` and return that path.

        The format follows the original file's extension. Only the base name is
        used, so same-named files from different input sub-directories will
        overwrite each other's outputs.
        """
        base_name = os.path.basename(original_path)
        name, ext = os.path.splitext(base_name)
        output_path = os.path.join(self.output_dir, f"{name}_{suffix}{ext}")
        image.save(output_path)
        return output_path

    def augment_image(self, image_path):
        """Write num_augmentations augmented copies of *image_path*.

        Returns:
            True on success, False if any step raised (the error is printed).
        """
        try:
            original_img = self.preprocess_image(image_path)
            for i in range(self.num_augmentations):
                augmented_img = self.apply_random_augmentations(original_img.copy())
                self.save_image(augmented_img, image_path, f"aug_{i + 1}")
            return True
        except Exception as e:
            print(f"处理图片 {image_path} 时出错: {str(e)}")
            traceback.print_exc()
            return False

    def augment_images(self, image_paths):
        """Augment every path in *image_paths* and print a success/failure summary.

        Uses a process pool of max_workers when there is more than one image and
        more than one worker; otherwise processes sequentially.
        """
        if len(image_paths) > 1 and self.max_workers > 1:
            with Pool(processes=self.max_workers) as pool:
                results = []
                with tqdm(total=len(image_paths), desc="批量增强") as pbar:
                    for result in pool.imap_unordered(self.augment_image, image_paths):
                        results.append(result)
                        pbar.update()
            success_count = sum(1 for r in results if r)
            failure_count = len(image_paths) - success_count
            print(f"处理完成: {success_count} 成功, {failure_count} 失败")
        else:
            success_count = 0
            failure_count = 0
            for path in tqdm(image_paths, desc="增强图像"):
                if self.augment_image(path):
                    success_count += 1
                else:
                    failure_count += 1
            print(f"处理完成: {success_count} 成功, {failure_count} 失败")

    # --------------------- augmentation operations ---------------------
    def random_rotation(self, image, angle_range=(-30, 30)):
        """Rotate by a uniform random angle; the canvas grows to fit and the corners are filled grey."""
        angle = random.uniform(angle_range[0], angle_range[1])
        return image.rotate(angle, expand=True, fillcolor=(128, 128, 128))

    def random_flip(self, image):
        """Randomly leave the image as-is, flip it horizontally, or flip it vertically."""
        flip_options = [
            None,  # no flip
            Image.FLIP_LEFT_RIGHT,  # horizontal flip
            Image.FLIP_TOP_BOTTOM,  # vertical flip
        ]
        flip_type = random.choice(flip_options)
        # Compare against None explicitly: FLIP_LEFT_RIGHT is 0, so a plain
        # truthiness test silently skipped every horizontal flip.
        if flip_type is not None:
            return image.transpose(flip_type)
        return image

    def random_brightness(self, image, factor_range=(0.7, 1.3)):
        """Scale brightness by a uniform random factor from *factor_range*."""
        factor = random.uniform(factor_range[0], factor_range[1])
        return ImageEnhance.Brightness(image).enhance(factor)

    def random_contrast(self, image, factor_range=(0.7, 1.3)):
        """Scale contrast by a uniform random factor from *factor_range*."""
        factor = random.uniform(factor_range[0], factor_range[1])
        return ImageEnhance.Contrast(image).enhance(factor)

    def random_color(self, image, factor_range=(0.8, 1.2)):
        """Scale colour balance (PIL ImageEnhance.Color) by a random factor."""
        factor = random.uniform(factor_range[0], factor_range[1])
        return ImageEnhance.Color(image).enhance(factor)

    def random_saturation(self, image, factor_range=(0.7, 1.3)):
        """Scale the HSV saturation channel via OpenCV; expects an RGB image."""
        img_arr = np.array(image)
        hsv = cv2.cvtColor(img_arr, cv2.COLOR_RGB2HSV).astype(np.float32)
        factor = random.uniform(factor_range[0], factor_range[1])
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] * factor, 0, 255)
        rgb = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
        return Image.fromarray(rgb)

    def random_blur(self, image, max_radius=2):
        """Gaussian blur with a radius drawn uniformly from [0.5, max_radius]."""
        radius = random.uniform(0.5, max_radius)
        return image.filter(ImageFilter.GaussianBlur(radius))

    def random_sharpness(self, image, factor_range=(0.5, 2.0)):
        """Scale sharpness by a uniform random factor from *factor_range*."""
        factor = random.uniform(factor_range[0], factor_range[1])
        return ImageEnhance.Sharpness(image).enhance(factor)

    def random_crop(self, image, crop_ratio_range=(0.7, 1.0)):
        """Crop a random window of a random ratio and resize it back to the original size."""
        crop_ratio = random.uniform(crop_ratio_range[0], crop_ratio_range[1])
        width, height = image.size
        # At least 1 px so tiny images never yield an empty crop.
        new_width = max(1, int(width * crop_ratio))
        new_height = max(1, int(height * crop_ratio))
        left = random.randint(0, width - new_width)
        top = random.randint(0, height - new_height)
        crop_box = self._to_int_coords(left, top, left + new_width, top + new_height)
        cropped = image.crop(crop_box)
        if crop_ratio < 1.0:
            return cropped.resize((width, height), Image.LANCZOS)
        return cropped

    def random_zoom(self, image, zoom_range=(0.8, 1.2)):
        """Zoom in or out about the centre while keeping the original canvas size.

        Zooming out pads with grey; zooming in crops the centre.
        """
        zoom_factor = random.uniform(zoom_range[0], zoom_range[1])
        width, height = image.size
        new_width = max(1, int(width * zoom_factor))
        new_height = max(1, int(height * zoom_factor))
        resized = image.resize((new_width, new_height), Image.LANCZOS)
        if zoom_factor < 1.0:
            new_img = Image.new('RGB', (width, height), (128, 128, 128))
            paste_pos = self._to_int_coords((width - new_width) // 2, (height - new_height) // 2)
            new_img.paste(resized, paste_pos)
            return new_img
        left = (new_width - width) // 2
        top = (new_height - height) // 2
        crop_box = self._to_int_coords(left, top, left + width, top + height)
        return resized.crop(crop_box)

    def random_scale(self, image, scale_range=(0.8, 1.2)):
        """Scale uniformly (aspect kept) and centre the result on a grey canvas of the original size."""
        scale_factor = random.uniform(scale_range[0], scale_range[1])
        width, height = image.size
        new_width = max(1, int(width * scale_factor))
        new_height = max(1, int(height * scale_factor))
        scaled = image.resize((new_width, new_height), Image.LANCZOS)
        new_img = Image.new('RGB', (width, height), (128, 128, 128))
        # Negative offsets (when enlarged) make paste clip to the canvas.
        paste_pos = self._to_int_coords((width - new_width) // 2, (height - new_height) // 2)
        new_img.paste(scaled, paste_pos)
        return new_img

    def random_translation(self, image, max_translation=0.2):
        """Shift by up to *max_translation* of each dimension, filling the exposed area with grey."""
        width, height = image.size
        max_x = int(width * max_translation)
        max_y = int(height * max_translation)
        dx = random.randint(-max_x, max_x)
        dy = random.randint(-max_y, max_y)
        new_img = Image.new('RGB', (width, height), (128, 128, 128))
        paste_x = max(0, dx)
        paste_y = max(0, dy)
        crop_x = max(0, -dx)
        crop_y = max(0, -dy)
        crop_width = width - abs(dx)
        crop_height = height - abs(dy)
        crop_box = self._to_int_coords(crop_x, crop_y, crop_x + crop_width, crop_y + crop_height)
        paste_pos = self._to_int_coords(paste_x, paste_y)
        new_img.paste(image.crop(crop_box), paste_pos)
        return new_img

    def random_gaussian_noise(self, image, intensity_range=(5, 30)):
        """Add zero-mean Gaussian noise whose std-dev is drawn from *intensity_range*."""
        intensity = random.uniform(intensity_range[0], intensity_range[1])
        img_arr = np.array(image)
        noise = np.random.normal(0, intensity, img_arr.shape).astype(np.int16)
        noisy_arr = np.clip(img_arr.astype(np.int16) + noise, 0, 255).astype(np.uint8)
        return Image.fromarray(noisy_arr)

    def random_salt_pepper_noise(self, image, amount_range=(0.001, 0.01)):
        """Set roughly amount/2 of the pixels to white and amount/2 to black."""
        amount = random.uniform(amount_range[0], amount_range[1])
        img_arr = np.array(image)
        noisy_arr = img_arr.copy()
        salt_coords = np.random.rand(*img_arr.shape[:2]) < amount / 2
        noisy_arr[salt_coords] = 255
        pepper_coords = np.random.rand(*img_arr.shape[:2]) < amount / 2
        noisy_arr[pepper_coords] = 0
        return Image.fromarray(noisy_arr)
if __name__ == "__main__":
    # The app may start a multiprocessing.Pool; freeze_support() is required
    # for frozen Windows executables and is a no-op everywhere else.
    from multiprocessing import freeze_support
    freeze_support()
    # Create the main window.
    root = tk.Tk()
    # The window icon is optional; Tk raises TclError when the file is missing
    # or unusable, and any other error should still surface.
    try:
        root.iconbitmap("icon.ico")
    except tk.TclError:
        pass
    app = ImageAugmentorApp(root)
    root.mainloop()
展示: