83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
|
import os
|
||
|
import xml.etree.ElementTree as ET
|
||
|
|
||
|
def convert_xml_to_txt(xml_path, txt_path, label_mapping=None):
    """Convert one XML bounding-box annotation file into a normalized TXT label file.

    The XML root must contain a ``size`` element (with integer ``width`` and
    ``height``) and may contain any number of ``object`` elements, each with
    a ``name`` and a ``bndbox`` holding ``xmin``/``ymin``/``xmax``/``ymax``.
    One line is written per recognised object::

        class_id x_center y_center width height

    where the four box values are divided by the image size and printed with
    six decimals (the layout commonly used for YOLO-style labels).

    Args:
        xml_path: Path of the XML annotation file to read.
        txt_path: Path of the TXT file to write (overwritten if it exists).
        label_mapping: Optional dict mapping *lower-case* label names to
            integer class ids. Defaults to
            ``{"light": 0, "red": 1, "yellow": 2, "green": 3}``.

    Raises:
        xml.etree.ElementTree.ParseError: If the XML cannot be parsed.
        AttributeError: If an expected element (``size``, ``width``,
            ``name``, ``bndbox``, ...) is missing.
        ValueError: If a size or coordinate value is not numeric.
        ZeroDivisionError: If the recorded image width or height is 0 and at
            least one recognised object is present.
    """
    if label_mapping is None:
        # Default label-name -> class-id table.
        label_mapping = {
            "light": 0,
            "red": 1,
            "yellow": 2,
            "green": 3,
        }

    # Parse the XML file.
    root = ET.parse(xml_path).getroot()

    # Read the image dimensions used for normalization.
    size = root.find("size")
    img_width = int(size.find("width").text)
    img_height = int(size.find("height").text)

    lines = []
    for obj in root.findall("object"):
        # Label names are matched case-insensitively and without padding.
        class_name = obj.find("name").text.strip().lower()
        if class_name not in label_mapping:
            print(f"未知标签 {class_name},文件 {xml_path} 中跳过该对象")
            continue
        class_id = label_mapping[class_name]

        # Read the bounding-box corners.
        bndbox = obj.find("bndbox")
        xmin = float(bndbox.find("xmin").text)
        ymin = float(bndbox.find("ymin").text)
        xmax = float(bndbox.find("xmax").text)
        ymax = float(bndbox.find("ymax").text)

        # Put the corners in ascending order and clamp them to the image so
        # the normalized values always stay inside [0, 1]; annotation tools
        # occasionally emit boxes that spill past the image border.
        x_low, x_high = sorted((xmin, xmax))
        y_low, y_high = sorted((ymin, ymax))
        x_low = min(max(x_low, 0.0), img_width)
        x_high = min(max(x_high, 0.0), img_width)
        y_low = min(max(y_low, 0.0), img_height)
        y_high = min(max(y_high, 0.0), img_height)

        # Box centre and extent, normalized by the image size.
        x_center_norm = (x_low + x_high) / 2.0 / img_width
        y_center_norm = (y_low + y_high) / 2.0 / img_height
        width_norm = (x_high - x_low) / img_width
        height_norm = (y_high - y_low) / img_height

        # Line format: class_id x_center y_center width height
        lines.append(
            f"{class_id} {x_center_norm:.6f} {y_center_norm:.6f} "
            f"{width_norm:.6f} {height_norm:.6f}"
        )

    # Write the converted labels; a file with no recognised objects is
    # written empty.
    with open(txt_path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)
|
||
|
|
||
|
def batch_convert_xml_to_txt(input_dir, output_dir):
    """Convert every ``.xml`` file directly inside *input_dir* into a ``.txt`` file.

    Each output file keeps the base name of its source and is placed in
    *output_dir*, which is created first when it does not exist. Files are
    converted one by one with ``convert_xml_to_txt``; a failure on one file
    is reported on stdout and does not stop the remaining files.

    Args:
        input_dir: Directory to scan (non-recursively) for XML files.
        output_dir: Directory that receives the generated TXT files.
    """
    # Create the output directory if it is missing.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    for entry in os.listdir(input_dir):
        # Only names ending in ".xml" (any letter case) are processed.
        if not entry.lower().endswith(".xml"):
            continue

        source = os.path.join(input_dir, entry)
        stem, _ext = os.path.splitext(entry)
        target = os.path.join(output_dir, stem + ".txt")
        try:
            convert_xml_to_txt(source, target)
            print(f"成功转换: {source} --> {target}")
        except Exception as err:
            # Best-effort batch: report the problem and move on.
            print(f"处理 {source} 时出错: {err}")
|
||
|
|
||
|
if __name__ == "__main__":
    import sys

    # Default locations. Replace them with real paths, or override them on
    # the command line:  python <script> [input_dir [output_dir]]
    input_directory = r"C:\Users\10561\Downloads\Compressed\xml"  # directory containing the XML files
    output_directory = r"C:\Users\10561\Downloads\Compressed\txt"  # directory where the converted TXT files are saved

    # Optional positional arguments take precedence over the defaults above,
    # so the script can be reused without editing the source.
    cli_args = sys.argv[1:]
    if len(cli_args) >= 1:
        input_directory = cli_args[0]
    if len(cli_args) >= 2:
        output_directory = cli_args[1]

    batch_convert_xml_to_txt(input_directory, output_directory)
|