155 lines
5.6 KiB
Python
155 lines
5.6 KiB
Python
|
import os
|
|||
|
import xml.etree.ElementTree as ET
|
|||
|
from xml.dom import minidom
|
|||
|
from PIL import Image
|
|||
|
|
|||
|
# Default mapping from YOLO class index to the label name written into <name>.
_DEFAULT_ID_TO_LABEL = {
    0: "light",
    1: "red",
    2: "yellow",
    3: "green",
}


def _add_text_element(parent, tag, text):
    """Append ``<tag>text</tag>`` under *parent* and return the new element."""
    elem = ET.SubElement(parent, tag)
    elem.text = str(text)
    return elem


def _image_size_and_depth(image_path):
    """Return ``(width, height, depth)`` for the image at *image_path*.

    Depth is 1 for Pillow mode "L" and 3 for every other mode ("RGB" and
    the fallback for anything else).
    """
    with Image.open(image_path) as img:
        width, height = img.size
        depth = 1 if img.mode == "L" else 3
    return width, height, depth


def _parse_yolo_line(line, id_to_label):
    """Parse one ``class_id x_center y_center width height`` label line.

    Returns ``(label, x_center, y_center, width, height)`` with the four
    coordinates still normalized, or ``None`` when the line is blank or
    must be skipped (a message is printed for malformed / unknown lines).
    """
    text = line.strip()
    if not text:
        # Blank lines (e.g. a trailing newline) are not errors; skip quietly.
        return None
    parts = text.split()
    if len(parts) != 5:
        print(f"跳过格式不正确的行: {text}")
        return None
    try:
        class_id = int(parts[0])
        coords = [float(p) for p in parts[1:]]
    except ValueError as e:
        print(f"转换数据出错 {text}: {e}")
        return None
    if class_id not in id_to_label:
        print(f"未知标签编号 {class_id},跳过该行")
        return None
    return (id_to_label[class_id], *coords)


def _yolo_to_voc_box(x_center_norm, y_center_norm, width_norm, height_norm,
                     img_width, img_height):
    """Convert a normalized center/size box to integer pixel corners.

    Returns ``(xmin, ymin, xmax, ymax)`` with x values clamped to
    ``[0, img_width]`` and y values clamped to ``[0, img_height]``.
    """
    x_center = x_center_norm * img_width
    y_center = y_center_norm * img_height
    half_w = width_norm * img_width / 2
    half_h = height_norm * img_height / 2

    def to_pixel(value, limit):
        # round() rather than int(): int() truncates toward zero, so a value
        # like 99.9999 (float error from normalized coordinates) became 99.
        # Clamping keeps boxes that touch the border inside the image.
        return min(max(round(value), 0), limit)

    return (
        to_pixel(x_center - half_w, img_width),
        to_pixel(y_center - half_h, img_height),
        to_pixel(x_center + half_w, img_width),
        to_pixel(y_center + half_h, img_height),
    )


def convert_txt_to_xml(txt_path, image_path, xml_path, id_to_label=None):
    """Convert one YOLO-style TXT label file into a Pascal VOC-style XML file.

    Args:
        txt_path: Text file with one ``class_id x_center y_center width
            height`` line per object; the four coordinates are treated as
            fractions of the image width/height.
        image_path: Image the labels belong to; opened to read its size and
            mode, and recorded in the folder/filename/path fields.
        xml_path: Destination of the XML annotation (overwritten).
        id_to_label: Optional mapping from class index to label name.
            Defaults to ``_DEFAULT_ID_TO_LABEL``.

    Blank, malformed and unknown-class lines are skipped, as are boxes with
    zero width or height after clamping to the image. Errors opening the
    image or TXT file propagate to the caller.
    """
    if id_to_label is None:
        id_to_label = _DEFAULT_ID_TO_LABEL

    img_width, img_height, depth = _image_size_and_depth(image_path)

    # utf-8-sig drops a leading BOM (e.g. files saved by Notepad); otherwise
    # the BOM sticks to the first class id and that line fails to parse.
    with open(txt_path, "r", encoding="utf-8-sig") as f:
        lines = f.readlines()

    annotation = ET.Element("annotation", verified="yes")
    _add_text_element(annotation, "folder",
                      os.path.basename(os.path.dirname(image_path)))
    _add_text_element(annotation, "filename", os.path.basename(image_path))
    _add_text_element(annotation, "path", os.path.abspath(image_path))

    source_elem = ET.SubElement(annotation, "source")
    _add_text_element(source_elem, "database", "Unknown")

    size_elem = ET.SubElement(annotation, "size")
    _add_text_element(size_elem, "width", img_width)
    _add_text_element(size_elem, "height", img_height)
    _add_text_element(size_elem, "depth", depth)

    _add_text_element(annotation, "segmented", "0")

    for line in lines:
        parsed = _parse_yolo_line(line, id_to_label)
        if parsed is None:
            continue
        label, *norm_box = parsed
        xmin, ymin, xmax, ymax = _yolo_to_voc_box(*norm_box, img_width, img_height)
        if xmax <= xmin or ymax <= ymin:
            # Fully outside the image or sub-pixel: not a usable box.
            print(f"跳过无效的边界框(宽或高为0): {line.strip()}")
            continue

        object_elem = ET.SubElement(annotation, "object")
        _add_text_element(object_elem, "name", label)
        _add_text_element(object_elem, "pose", "Unspecified")
        _add_text_element(object_elem, "truncated", "0")
        _add_text_element(object_elem, "difficult", "0")

        bndbox_elem = ET.SubElement(object_elem, "bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"),
                              (xmin, ymin, xmax, ymax)):
            _add_text_element(bndbox_elem, tag, value)

    # Round-trip through minidom to get indented, human-readable output.
    pretty_xml_str = minidom.parseString(
        ET.tostring(annotation, encoding="utf-8")
    ).toprettyxml(indent=" ")

    with open(xml_path, "w", encoding="utf-8") as f:
        f.write(pretty_xml_str)
|
|||
|
|
|||
|
def batch_convert_txt_to_xml(txt_dir, image_dir, output_dir, image_ext=".jpg"):
    """Convert every ``*.txt`` label file in *txt_dir* to an XML file.

    Args:
        txt_dir: Directory containing the TXT label files.
        image_dir: Directory containing the matching images; each image has
            the same base name as its TXT file plus *image_ext*.
        output_dir: Directory the XML files are written to; created
            (including parent directories) if missing.
        image_ext: Image file extension, ".jpg" by default. A missing
            leading dot is added, so "jpg" also works.

    TXT files whose image is missing are skipped with a message; an error
    while converting one file is printed and the batch continues.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    # Without the dot, base_name + "jpg" would look for "namejpg" and every
    # file would be skipped as "image not found".
    if image_ext and not image_ext.startswith("."):
        image_ext = "." + image_ext

    # os.listdir order is arbitrary; sort for a deterministic run and log.
    for file_name in sorted(os.listdir(txt_dir)):
        if not file_name.lower().endswith(".txt"):
            continue

        txt_path = os.path.join(txt_dir, file_name)
        base_name = os.path.splitext(file_name)[0]
        image_path = os.path.join(image_dir, base_name + image_ext)
        if not os.path.exists(image_path):
            print(f"图像文件不存在: {image_path},跳过 {txt_path}")
            continue

        xml_path = os.path.join(output_dir, base_name + ".xml")
        try:
            convert_txt_to_xml(txt_path, image_path, xml_path)
        except Exception as e:
            # Per-file boundary: report the failure and keep going.
            print(f"处理 {txt_path} 时出错: {e}")
        else:
            print(f"成功转换: {txt_path} --> {xml_path}")
|
|||
|
|
|||
|
if __name__ == "__main__":
    import argparse

    # The defaults are placeholders: edit them here or override them on the
    # command line, e.g. --txt-dir labels --image-dir images --output-dir xml
    parser = argparse.ArgumentParser(
        description="Convert YOLO TXT labels to Pascal VOC-style XML annotations."
    )
    parser.add_argument("--txt-dir", default="path_to_txt_directory",
                        help="directory containing the TXT label files")
    parser.add_argument("--image-dir", default="path_to_image_directory",
                        help="directory containing the matching images")
    parser.add_argument("--output-dir", default="path_to_output_xml_directory",
                        help="directory the XML files are written to")
    parser.add_argument("--image-ext", default=".jpg",
                        help="image file extension (default: .jpg)")
    args = parser.parse_args()

    # Validate before batch_convert_txt_to_xml creates the output directory,
    # so unedited placeholder paths fail cleanly instead of leaving a stray
    # folder behind and crashing with a traceback.
    for input_dir in (args.txt_dir, args.image_dir):
        if not os.path.isdir(input_dir):
            parser.error(f"directory not found: {input_dir}")

    batch_convert_txt_to_xml(args.txt_dir, args.image_dir, args.output_dir,
                             image_ext=args.image_ext)
|