Code Snippets Repository

Home

Meta SAM — Segment Anything

Wednesday, 7 February 2024

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from tkinter import Tk
from tkinter.filedialog import askopenfilename

# Ensure the 'segment_anything' library is in your Python path
sys.path.append("..")  # Adjust this path if necessary to locate the library
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

def show_anns(anns):
    """Overlay SAM mask annotations on the current matplotlib axes.

    Each annotation dict is expected to carry a boolean 'segmentation'
    array and an 'area' value. Masks are painted largest-first so that
    smaller masks drawn later remain visible on top.
    """
    if not anns:
        return
    print("Displaying annotations...")
    axes = plt.gca()
    axes.set_autoscale_on(False)

    ordered = sorted(anns, key=lambda a: a['area'], reverse=True)
    height, width = ordered[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[..., 3] = 0  # fully transparent until a mask is painted
    for annotation in ordered:
        mask = annotation['segmentation']
        # Random color with some transparency (alpha 0.35)
        overlay[mask] = np.concatenate([np.random.random(3), [0.35]])
    axes.imshow(overlay)

# Suppress the empty Tkinter root window; we only need the file dialog
Tk().withdraw()

# Let the user pick an image through the native "Open" dialog
print("Please select an image file:")
selected_path = askopenfilename()

# Guard clause: dialog cancelled / closed without a selection
if not selected_path:
    print("No file selected. Exiting...")
    sys.exit()

# Guard clause: path chosen but OpenCV could not decode it as an image
image = cv2.imread(selected_path)
if image is None:
    print(f"Failed to load image from '{selected_path}'. Exiting...")
    sys.exit()

# Downscale oversized images so the longer side is at most 1024 pixels
# (speeds up mask generation considerably on large photos).
longest_side = 1024
height, width = image.shape[:2]
shrink = longest_side / max(height, width)  # same value as min(1024/h, 1024/w)
if shrink < 1:  # only ever downscale, never enlarge
    new_size = (int(width * shrink), int(height * shrink))
    image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
    print(f"Resized image to {new_size} for faster processing.")

# OpenCV decodes to BGR; matplotlib and SAM expect RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Preview the selected image before running segmentation
plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()
plt.close()  # Close the plot to allow the script to continue

# --- SAM configuration: model variant, checkpoint location, compute device ---
MODEL_TYPE = "vit_b"  # Adjust this to the specific model type you're using
CHECKPOINT = r"C:\Users\john\Documents\meta sam\sam_vit_b_01ec64.pth"  # Ensure this path is correct
compute_device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading SAM model...")
sam_model = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT)
sam_model.to(device=compute_device)
print("SAM model loaded.")

# Automatic (prompt-free) mask generation over the whole image
generator = SamAutomaticMaskGenerator(sam_model)

print("Generating masks...")
annotations = generator.generate(image)
print(f"Generated {len(annotations)} masks.")


# Render the image with every generated mask overlaid
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(annotations)
plt.axis('off')
plt.show()