import os
import pandas as pd
import xml.etree.ElementTree as ET
from PIL import Image
class CustomDataset:
    """Object-detection dataset pairing image paths with box annotations.

    Every file in ``<data_dir>/<annotations_dir>`` whose name ends with
    ``.<annotation_format>`` is parsed when the dataset is constructed and
    paired with ``<data_dir>/<image_dir>/<stem><image_extension>``.  Image
    files are neither opened nor checked for existence until a sample is
    requested by index.

    Supported annotation formats:

    * ``'xml'`` -- ``<object>`` elements holding a ``<name>`` and a
      ``<bndbox>`` with ``xmin``/``ymin``/``xmax``/``ymax`` children.
    * ``'csv'`` -- columns ``class``, ``xmin``, ``ymin``, ``xmax``, ``ymax``.
    * ``'txt'`` -- one whitespace-separated ``name xmin ymin xmax ymax``
      record per line.

    Each parsed annotation is a list of
    ``{'name': ..., 'bbox': (xmin, ymin, xmax, ymax)}`` dictionaries.
    """

    # Child tags of <bndbox>, in the order they appear in the bbox tuple.
    _BBOX_TAGS = ('xmin', 'ymin', 'xmax', 'ymax')

    def __init__(self, data_dir, annotations_dir, image_dir, annotation_format,
                 image_extension='.jpg'):
        """Index the annotation directory and parse every matching file.

        Args:
            data_dir: Root directory of the dataset.
            annotations_dir: Annotation sub-directory, relative to ``data_dir``.
            image_dir: Image sub-directory, relative to ``data_dir``.
            annotation_format: ``'xml'``, ``'csv'`` or ``'txt'``; also used as
                the annotation file extension (without the dot).
            image_extension: Extension appended to each annotation stem to
                build the image path.  A missing leading dot is added.

        Raises:
            ValueError: If a matching annotation file exists but
                ``annotation_format`` is not one of the supported formats.
            OSError: If the annotation directory cannot be listed.
        """
        self.data_dir = data_dir
        self.annotations_dir = annotations_dir
        self.image_dir = image_dir
        self.annotation_format = annotation_format
        if image_extension and not image_extension.startswith('.'):
            image_extension = '.' + image_extension
        self.image_extension = image_extension
        self.image_paths = []
        self.annotations = []
        self._load_data()

    def _load_data(self):
        """Populate ``image_paths`` and ``annotations`` from the annotation directory."""
        annotations_root = os.path.join(self.data_dir, self.annotations_dir)
        images_root = os.path.join(self.data_dir, self.image_dir)
        suffix = f'.{self.annotation_format}'

        image_paths = []
        annotations = []
        # os.listdir() yields entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for filename in sorted(os.listdir(annotations_root)):
            if not filename.endswith(suffix):
                continue
            stem = os.path.splitext(filename)[0]
            annotations.append(
                self._parse_annotation(os.path.join(annotations_root, filename))
            )
            image_paths.append(
                os.path.join(images_root, stem + self.image_extension)
            )

        # Assign both lists together so they always stay the same length,
        # and so calling this method again does not duplicate samples.
        self.image_paths = image_paths
        self.annotations = annotations

    def _parse_annotation(self, annotation_path):
        """Parse one annotation file with the parser for ``annotation_format``.

        Raises:
            ValueError: If ``annotation_format`` is not supported.
        """
        if self.annotation_format == 'xml':
            return self._parse_xml(annotation_path)
        if self.annotation_format == 'csv':
            return self._parse_csv(annotation_path)
        if self.annotation_format == 'txt':
            return self._parse_txt(annotation_path)
        raise ValueError("Unsupported annotation format")

    @staticmethod
    def _to_int(text):
        """Convert coordinate text to ``int``, accepting integral floats.

        ``'12'`` and ``'12.0'`` both give ``12``.  A non-integral value such
        as ``'0.5'`` raises ``ValueError`` instead of being truncated, so
        fractional or normalized coordinates are never silently corrupted.
        """
        try:
            return int(text)
        except ValueError:
            number = float(text)  # still raises ValueError for non-numeric text
            if not number.is_integer():
                raise ValueError(
                    f"expected an integer pixel coordinate, got {text!r}"
                ) from None
            return int(number)

    def _parse_xml(self, xml_path):
        """Return the bounding boxes of every ``<object>`` in an XML file."""
        root = ET.parse(xml_path).getroot()
        annotations = []
        for object_elem in root.findall('object'):
            name = object_elem.find('name').text
            bbox_elem = object_elem.find('bndbox')
            bbox = tuple(
                self._to_int(bbox_elem.find(tag).text) for tag in self._BBOX_TAGS
            )
            annotations.append({'name': name, 'bbox': bbox})
        return annotations

    def _parse_csv(self, csv_path):
        """Return one bounding box per row of a CSV file.

        Values are read column-wise with ``Series.tolist()`` so the result
        holds native Python numbers rather than numpy scalars, matching the
        other parsers.
        """
        df = pd.read_csv(csv_path)
        if df.empty:
            return []
        names = df['class'].tolist()
        boxes = zip(*(df[column].tolist() for column in self._BBOX_TAGS))
        return [{'name': name, 'bbox': bbox} for name, bbox in zip(names, boxes)]

    def _parse_txt(self, txt_path):
        """Return one bounding box per non-blank line of a text file."""
        annotations = []
        with open(txt_path, 'r') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank lines (e.g. a trailing newline) carry no record.
                    continue
                bbox = (
                    self._to_int(parts[1]),
                    self._to_int(parts[2]),
                    self._to_int(parts[3]),
                    self._to_int(parts[4]),
                )
                annotations.append({'name': parts[0], 'bbox': bbox})
        return annotations

    def __len__(self):
        """Return the number of samples (annotation files found)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{'image': <PIL image>, 'annotation': <list of boxes>}``.

        The pixel data is loaded before returning and the underlying file is
        closed, so iterating over a large dataset does not accumulate open
        file handles.
        """
        image_path = self.image_paths[idx]
        annotation = self.annotations[idx]
        # Image.open() is lazy and keeps the file open; load() forces the
        # decode, and leaving the with-block releases the file handle while
        # the loaded image stays usable.
        # NOTE(review): for multi-frame files only the current frame remains
        # accessible after the file is closed.
        with Image.open(image_path) as image:
            image.load()
        return {'image': image, 'annotation': annotation}
# Example usage.  Guarded so that importing this module does not try to scan
# the placeholder dataset path below; running the file as a script behaves
# as before.
if __name__ == '__main__':
    data_dir = 'path/to/dataset'
    annotations_dir = 'annotations'
    image_dir = 'images'
    annotation_format = 'xml'  # Change this based on the actual format
    dataset = CustomDataset(data_dir, annotations_dir, image_dir, annotation_format)

    # Access the first sample: a dict with the image and its annotations.
    sample_data = dataset[0]
    sample_image = sample_data['image']
    sample_annotation = sample_data['annotation']