Image Segmentation
PyTorch
sam2
English
computer-vision
segmentation
few-shot-learning
zero-shot-learning
clip
Instructions to use ParallelLLC/Segmentation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sam2
How to use ParallelLLC/Segmentation with sam2:
```python
# Use SAM2 with images
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("ParallelLLC/Segmentation")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(<your_image>)
    masks, _, _ = predictor.predict(<input_prompts>)
```

```python
# Use SAM2 with videos
import torch
from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("ParallelLLC/Segmentation")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(<your_video>)

    # add new prompts and instantly get the output on the same frame
    frame_idx, object_ids, masks = predictor.add_new_points(state, <your_prompts>)

    # propagate the prompts to get masklets throughout the video
    for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
        ...
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Data Loader Utilities | |
| This module provides data loading utilities for different domains | |
| (satellite, fashion, robotics) with support for few-shot and zero-shot learning. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
| import json | |
| from typing import List, Dict, Tuple, Optional | |
| import random | |
| from torch.utils.data import Dataset, DataLoader | |
| import torchvision.transforms as transforms | |
| from torchvision.transforms import functional as F | |
| import cv2 | |
class BaseDataLoader:
    """Base class for domain-specific data loaders.

    Subclasses build their own sample index and override
    :meth:`get_random_sample` and :meth:`get_class_examples`, both of which
    raise ``NotImplementedError`` here.

    Args:
        data_dir: Root directory of the dataset.
        image_size: Target size handed to torchvision's ``Resize`` for both
            images and masks.
    """

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        self.data_dir = data_dir
        self.image_size = image_size

        # Images: resize, convert to a float tensor, then normalise with the
        # standard ImageNet per-channel mean/std.
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Masks: nearest-neighbour resizing so label values are never blended.
        # ToTensor rescales 8-bit masks into [0, 1] (pixel value / 255), so
        # callers that need integer labels must undo that scaling.
        self.mask_transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor()
        ])

    def load_image(self, image_path: str) -> torch.Tensor:
        """Load an image as RGB and apply ``self.transform``."""
        # ``convert`` returns a new, fully loaded image, so it remains usable
        # after the context manager closes the underlying file handle.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return self.transform(image)

    def load_mask(self, mask_path: str) -> torch.Tensor:
        """Load a mask as single-channel ('L') and apply ``self.mask_transform``."""
        with Image.open(mask_path) as img:
            mask = img.convert('L')
        return self.mask_transform(mask)

    def get_random_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Get a random sample from the dataset (implemented by subclasses)."""
        raise NotImplementedError

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Get ``(image, mask)`` examples for *class_name* (implemented by subclasses)."""
        raise NotImplementedError
class SatelliteDataLoader(BaseDataLoader):
    """Data loader for satellite imagery segmentation.

    Expects ``<data_dir>/images`` and ``<data_dir>/masks``. If either is
    missing, a synthetic dataset is generated there instead.

    Masks are single-channel images in which ``self.classes[i]`` is labelled
    ``i + 1`` and 0 is background. Dummy masks are saved as
    ``label * MASK_SCALE`` for visibility; :meth:`_decode_class_ids`
    reverses that scaling.
    """

    # Factor applied to label ids when dummy masks are written to disk.
    MASK_SCALE = 85

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)
        # Satellite-specific classes; the mask label for classes[i] is i + 1.
        self.classes = ["building", "road", "vegetation", "water"]
        # NOTE(review): these ids are 0-based, whereas mask labels are 1-based.
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}
        # Load dataset structure (may generate dummy data on disk).
        self.load_dataset_structure()

    def _decode_class_ids(self, raw: np.ndarray) -> np.ndarray:
        """Convert raw mask pixel values into integer label ids.

        Masks whose largest value exceeds the number of classes are treated
        as scaled by ``MASK_SCALE`` (as the dummy masks are) and divided back
        down. Anything else is taken to already hold label ids.
        """
        raw = np.asarray(raw)
        if raw.size and raw.max() > len(self.classes):
            return np.rint(raw / self.MASK_SCALE).astype(np.int64)
        return raw.astype(np.int64)

    @staticmethod
    def _find_mask_path(masks_dir: str, filename: str) -> Optional[str]:
        """Return the existing mask path paired with image *filename*, or None."""
        stem = os.path.splitext(filename)[0]
        candidates = (
            # Preferred convention for every extension: <stem>_mask.png
            os.path.join(masks_dir, stem + '_mask.png'),
            # Legacy rule: only '.jpg' was rewritten, so other image types
            # were paired with an identically named file in masks/.
            os.path.join(masks_dir, filename.replace('.jpg', '_mask.png')),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    def load_dataset_structure(self):
        """Index image/mask pairs under data_dir and bucket them by class."""
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            # NOTE(review): dummy files are written into data_dir even when
            # only one of images/ or masks/ is missing.
            self.create_dummy_data()
            return

        # Sorted so the sample order is deterministic across file systems.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.lower().endswith(('.jpg', '.png', '.tif')):
                continue
            mask_path = self._find_mask_path(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def create_dummy_data(self):
        """Create 100 synthetic satellite images with matching label masks.

        Buildings (label 1) and roads (label 2) are drawn into the image and
        its mask from the same random geometry, so the mask describes the
        image it is paired with.
        """
        print("Creating dummy satellite data...")
        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        size = 512
        for i in range(100):
            # Noisy background that loosely resembles satellite imagery.
            image = np.random.randint(50, 200, (size, size, 3), dtype=np.uint8)
            mask = np.zeros((size, size), dtype=np.uint8)

            # Buildings: filled rectangles.
            for _ in range(5):
                x, y = np.random.randint(0, 400), np.random.randint(0, 400)
                w, h = np.random.randint(20, 80), np.random.randint(20, 80)
                image[y:y+h, x:x+w] = np.random.randint(100, 150, 3)
                mask[y:y+h, x:x+w] = 1  # building label

            # Roads: one-pixel-wide straight lines; drawn last, so they
            # override buildings consistently in both image and mask.
            for _ in range(3):
                x, y = np.random.randint(0, size), np.random.randint(0, size)
                length = np.random.randint(50, 150)
                angle = np.random.uniform(0, 2*np.pi)
                for j in range(length):
                    px = int(x + j * np.cos(angle))
                    py = int(y + j * np.sin(angle))
                    if 0 <= px < size and 0 <= py < size:
                        image[py, px] = [80, 80, 80]
                        mask[py, px] = 2  # road label

            image_path = os.path.join(images_dir, f"satellite_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask_path = os.path.join(masks_dir, f"satellite_{i:03d}_mask.png")
            Image.fromarray(mask * self.MASK_SCALE).save(mask_path)  # Scale for visibility

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def categorize_sample(self, image_path: str, mask_path: str):
        """File the sample under the class with the most labelled pixels.

        Background-only masks are indexed in self.images but not assigned
        to any class.
        """
        with Image.open(mask_path) as mask_img:
            labels = self._decode_class_ids(np.array(mask_img))

        # Label i + 1 corresponds to classes[i]; 0 is background.
        class_counts = {
            class_name: int(np.sum(labels == i + 1))
            for i, class_name in enumerate(self.classes)
        }
        dominant_class, count = max(class_counts.items(), key=lambda kv: kv[1])
        if count == 0:
            return
        self.class_samples[dominant_class].append((image_path, mask_path))

    def get_random_query(self, class_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a random (image, mask) for *class_name*.

        Falls back to any sample when the class has no samples.

        Raises:
            ValueError: if the dataset contains no samples at all.
        """
        candidates = self.class_samples.get(class_name)
        if candidates:
            image_path, mask_path = random.choice(candidates)
        else:
            if not self.images:
                raise ValueError("SatelliteDataLoader has no samples to draw from")
            idx = random.randint(0, len(self.images) - 1)
            image_path, mask_path = self.images[idx], self.masks[idx]
        image = self.load_image(image_path)
        mask = self.load_mask(mask_path)
        return image, mask

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Get up to *num_examples* random (image, mask) pairs for *class_name*.

        Returns an empty list for unknown classes.
        """
        available_samples = self.class_samples.get(class_name, [])
        count = max(0, min(num_examples, len(available_samples)))
        selected_samples = random.sample(available_samples, count)
        return [
            (self.load_image(image_path), self.load_mask(mask_path))
            for image_path, mask_path in selected_samples
        ]
class FashionDataLoader(BaseDataLoader):
    """Data loader for fashion segmentation.

    Expects ``<data_dir>/images`` and ``<data_dir>/masks``. If either is
    missing, a synthetic dataset is generated there instead.

    Masks are single-channel images in which ``self.classes[i]`` is labelled
    ``i + 1`` and 0 is background. Dummy masks are saved as
    ``label * MASK_SCALE`` for visibility; :meth:`_decode_class_ids`
    reverses that scaling.
    """

    # Factor applied to label ids when dummy masks are written to disk.
    MASK_SCALE = 51

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)
        # Fashion-specific classes; the mask label for classes[i] is i + 1.
        self.classes = ["shirt", "pants", "dress", "shoes"]
        # NOTE(review): these ids are 0-based, whereas mask labels are 1-based.
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}
        # Load dataset structure (may generate dummy data on disk).
        self.load_dataset_structure()

    def _decode_class_ids(self, raw: np.ndarray) -> np.ndarray:
        """Convert raw mask pixel values into integer label ids.

        Masks whose largest value exceeds the number of classes are treated
        as scaled by ``MASK_SCALE`` (as the dummy masks are) and divided back
        down. Anything else is taken to already hold label ids.
        """
        raw = np.asarray(raw)
        if raw.size and raw.max() > len(self.classes):
            return np.rint(raw / self.MASK_SCALE).astype(np.int64)
        return raw.astype(np.int64)

    @staticmethod
    def _find_mask_path(masks_dir: str, filename: str) -> Optional[str]:
        """Return the existing mask path paired with image *filename*, or None."""
        stem = os.path.splitext(filename)[0]
        candidates = (
            # Preferred convention for every extension: <stem>_mask.png
            os.path.join(masks_dir, stem + '_mask.png'),
            # Legacy rule: only '.jpg' was rewritten, so '.png' images were
            # paired with an identically named file in masks/.
            os.path.join(masks_dir, filename.replace('.jpg', '_mask.png')),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    def load_dataset_structure(self):
        """Index image/mask pairs under data_dir and bucket them by class."""
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            # NOTE(review): dummy files are written into data_dir even when
            # only one of images/ or masks/ is missing.
            self.create_dummy_data()
            return

        # Sorted so the sample order is deterministic across file systems.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.lower().endswith(('.jpg', '.png')):
                continue
            mask_path = self._find_mask_path(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def create_dummy_data(self):
        """Create 100 synthetic fashion images, cycling through the classes.

        Each image contains one rectangle for class ``i % len(classes)``. The
        same rectangle is written into the mask with label ``class_id + 1``.
        """
        print("Creating dummy fashion data...")
        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        # class_id -> ((center_x, center_y), (width, height), RGB colour)
        shapes = [
            ((256, 256), (150, 200), [100, 150, 200]),  # shirt
            ((256, 300), (120, 180), [50, 100, 150]),   # pants
            ((256, 250), (140, 220), [200, 100, 150]),  # dress
            ((256, 400), (100, 60), [80, 80, 80]),      # shoes
        ]

        for i in range(100):
            # Light background, then the item for this sample's class.
            image = np.random.randint(200, 255, (512, 512, 3), dtype=np.uint8)
            class_id = i % len(self.classes)
            (center_x, center_y), (width, height), colour = shapes[class_id]
            rows = slice(center_y - height // 2, center_y + height // 2)
            cols = slice(center_x - width // 2, center_x + width // 2)
            image[rows, cols] = colour

            image_path = os.path.join(images_dir, f"fashion_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask = np.zeros((512, 512), dtype=np.uint8)
            mask[rows, cols] = class_id + 1  # 0 is background

            mask_path = os.path.join(masks_dir, f"fashion_{i:03d}_mask.png")
            Image.fromarray(mask * self.MASK_SCALE).save(mask_path)  # Scale for visibility

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def categorize_sample(self, image_path: str, mask_path: str):
        """File the sample under the class with the most labelled pixels.

        Background-only masks are indexed in self.images but not assigned
        to any class.
        """
        with Image.open(mask_path) as mask_img:
            labels = self._decode_class_ids(np.array(mask_img))

        # Label i + 1 corresponds to classes[i]; 0 is background.
        class_counts = {
            class_name: int(np.sum(labels == i + 1))
            for i, class_name in enumerate(self.classes)
        }
        dominant_class, count = max(class_counts.items(), key=lambda kv: kv[1])
        if count == 0:
            return
        self.class_samples[dominant_class].append((image_path, mask_path))

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Get up to *num_examples* random (image, mask) pairs for *class_name*.

        Returns an empty list for unknown classes.
        """
        available_samples = self.class_samples.get(class_name, [])
        count = max(0, min(num_examples, len(available_samples)))
        selected_samples = random.sample(available_samples, count)
        return [
            (self.load_image(image_path), self.load_mask(mask_path))
            for image_path, mask_path in selected_samples
        ]

    def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Get a random image plus one binary float mask per class.

        Each ground-truth mask has the shape returned by ``load_mask`` and is
        1.0 where the pixel carries that class's label.

        Raises:
            ValueError: if the dataset contains no samples.
        """
        if not self.images:
            raise ValueError("FashionDataLoader has no samples to draw from")
        idx = random.randint(0, len(self.images) - 1)
        image = self.load_image(self.images[idx])
        mask = self.load_mask(self.masks[idx])

        # load_mask's ToTensor divides 8-bit pixel values by 255; recover the
        # stored values before decoding them into label ids.
        raw = torch.round(mask * 255).to(torch.int64).numpy()
        labels = self._decode_class_ids(raw)

        ground_truth = {}
        for i, class_name in enumerate(self.classes):
            # +1 because 0 is background
            ground_truth[class_name] = torch.from_numpy((labels == i + 1).astype(np.float32))
        return image, ground_truth
class RoboticsDataLoader(BaseDataLoader):
    """Data loader for robotics segmentation.

    Expects ``<data_dir>/images`` and ``<data_dir>/masks``. If either is
    missing, a synthetic dataset is generated there instead.

    Masks are single-channel images in which ``self.classes[i]`` is labelled
    ``i + 1`` and 0 is background. Dummy masks are saved as
    ``label * MASK_SCALE`` for visibility; :meth:`_decode_class_ids`
    reverses that scaling.
    """

    # Factor applied to label ids when dummy masks are written to disk.
    MASK_SCALE = 85

    def __init__(self, data_dir: str, image_size: Tuple[int, int] = (512, 512)):
        super().__init__(data_dir, image_size)
        # Robotics-specific classes; the mask label for classes[i] is i + 1.
        self.classes = ["robot", "tool", "safety"]
        # NOTE(review): these ids are 0-based, whereas mask labels are 1-based.
        self.class_to_id = {cls: i for i, cls in enumerate(self.classes)}
        # Load dataset structure (may generate dummy data on disk).
        self.load_dataset_structure()

    def _decode_class_ids(self, raw: np.ndarray) -> np.ndarray:
        """Convert raw mask pixel values into integer label ids.

        Masks whose largest value exceeds the number of classes are treated
        as scaled by ``MASK_SCALE`` (as the dummy masks are) and divided back
        down. Anything else is taken to already hold label ids.
        """
        raw = np.asarray(raw)
        if raw.size and raw.max() > len(self.classes):
            return np.rint(raw / self.MASK_SCALE).astype(np.int64)
        return raw.astype(np.int64)

    @staticmethod
    def _find_mask_path(masks_dir: str, filename: str) -> Optional[str]:
        """Return the existing mask path paired with image *filename*, or None."""
        stem = os.path.splitext(filename)[0]
        candidates = (
            # Preferred convention for every extension: <stem>_mask.png
            os.path.join(masks_dir, stem + '_mask.png'),
            # Legacy rule: only '.jpg' was rewritten, so '.png' images were
            # paired with an identically named file in masks/.
            os.path.join(masks_dir, filename.replace('.jpg', '_mask.png')),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return None

    def load_dataset_structure(self):
        """Index image/mask pairs under data_dir and bucket them by class."""
        self.images = []
        self.masks = []
        self.class_samples = {cls: [] for cls in self.classes}

        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")

        if not os.path.exists(images_dir) or not os.path.exists(masks_dir):
            # NOTE(review): dummy files are written into data_dir even when
            # only one of images/ or masks/ is missing.
            self.create_dummy_data()
            return

        # Sorted so the sample order is deterministic across file systems.
        for filename in sorted(os.listdir(images_dir)):
            if not filename.lower().endswith(('.jpg', '.png')):
                continue
            mask_path = self._find_mask_path(masks_dir, filename)
            if mask_path is None:
                continue
            image_path = os.path.join(images_dir, filename)
            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def create_dummy_data(self):
        """Create 100 synthetic robotics images, cycling through the classes.

        Each image contains one centred rectangle for class
        ``i % len(classes)``. The same rectangle is written into the mask
        with label ``class_id + 1``.
        """
        print("Creating dummy robotics data...")
        images_dir = os.path.join(self.data_dir, "images")
        masks_dir = os.path.join(self.data_dir, "masks")
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

        # class_id -> ((center_x, center_y), (width, height), RGB colour)
        shapes = [
            ((256, 256), (120, 160), [100, 100, 100]),  # robot
            ((256, 256), (80, 120), [150, 100, 50]),    # tool
            ((256, 256), (100, 100), [200, 200, 50]),   # safety equipment
        ]

        for i in range(100):
            # Mid-grey noisy background, then the element for this class.
            image = np.random.randint(50, 150, (512, 512, 3), dtype=np.uint8)
            class_id = i % len(self.classes)
            (center_x, center_y), (width, height), colour = shapes[class_id]
            rows = slice(center_y - height // 2, center_y + height // 2)
            cols = slice(center_x - width // 2, center_x + width // 2)
            image[rows, cols] = colour

            image_path = os.path.join(images_dir, f"robotics_{i:03d}.jpg")
            Image.fromarray(image).save(image_path)

            mask = np.zeros((512, 512), dtype=np.uint8)
            mask[rows, cols] = class_id + 1  # 0 is background

            mask_path = os.path.join(masks_dir, f"robotics_{i:03d}_mask.png")
            Image.fromarray(mask * self.MASK_SCALE).save(mask_path)  # Scale for visibility

            self.images.append(image_path)
            self.masks.append(mask_path)
            self.categorize_sample(image_path, mask_path)

    def categorize_sample(self, image_path: str, mask_path: str):
        """File the sample under the class with the most labelled pixels.

        Background-only masks are indexed in self.images but not assigned
        to any class.
        """
        with Image.open(mask_path) as mask_img:
            labels = self._decode_class_ids(np.array(mask_img))

        # Label i + 1 corresponds to classes[i]; 0 is background.
        class_counts = {
            class_name: int(np.sum(labels == i + 1))
            for i, class_name in enumerate(self.classes)
        }
        dominant_class, count = max(class_counts.items(), key=lambda kv: kv[1])
        if count == 0:
            return
        self.class_samples[dominant_class].append((image_path, mask_path))

    def get_class_examples(self, class_name: str, num_examples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Get up to *num_examples* random (image, mask) pairs for *class_name*.

        Returns an empty list for unknown classes.
        """
        available_samples = self.class_samples.get(class_name, [])
        count = max(0, min(num_examples, len(available_samples)))
        selected_samples = random.sample(available_samples, count)
        return [
            (self.load_image(image_path), self.load_mask(mask_path))
            for image_path, mask_path in selected_samples
        ]

    def get_test_sample(self) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Get a random image plus one binary float mask per class.

        Each ground-truth mask has the shape returned by ``load_mask`` and is
        1.0 where the pixel carries that class's label.

        Raises:
            ValueError: if the dataset contains no samples.
        """
        if not self.images:
            raise ValueError("RoboticsDataLoader has no samples to draw from")
        idx = random.randint(0, len(self.images) - 1)
        image = self.load_image(self.images[idx])
        mask = self.load_mask(self.masks[idx])

        # load_mask's ToTensor divides 8-bit pixel values by 255; recover the
        # stored values before decoding them into label ids.
        raw = torch.round(mask * 255).to(torch.int64).numpy()
        labels = self._decode_class_ids(raw)

        ground_truth = {}
        for i, class_name in enumerate(self.classes):
            # +1 because 0 is background
            ground_truth[class_name] = torch.from_numpy((labels == i + 1).astype(np.float32))
        return image, ground_truth