# Meta SAM (Segment Anything): automatic mask generation demo
# Wednesday, 7 February 2024
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from tkinter import Tk
from tkinter.filedialog import askopenfilename
# Ensure the 'segment_anything' library is in your Python path
sys.path.append("..") # Adjust this path if necessary to locate the library
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# Function to display annotations/masks
def show_anns(anns, *, ax=None, alpha=0.35):
    """Overlay segmentation masks on a Matplotlib axes in random colours.

    Masks are painted largest-first, so smaller masks drawn later remain
    visible on top of the larger masks they overlap.

    Args:
        anns: List of annotation dicts (e.g. the output of
            ``SamAutomaticMaskGenerator.generate``). Each dict must provide
            ``'area'`` (sort key) and ``'segmentation'``, which is used to
            index an (H, W, 4) array — presumably a boolean (H, W) mask.
        ax: Axes to draw on. Defaults to the current pyplot axes
            (``plt.gca()``), matching the original behaviour.
        alpha: Opacity given to every coloured mask (0 = invisible,
            1 = opaque). Defaults to the original 0.35.

    Returns:
        None. Does nothing at all when ``anns`` is empty.
    """
    if not anns:
        return
    print("Displaying annotations...")
    if ax is None:
        ax = plt.gca()
    # Largest first: later (smaller) masks overwrite pixels of earlier ones.
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    # Keep the axes limits set by the underlying image; don't rescale to the overlay.
    ax.set_autoscale_on(False)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    # RGBA canvas that stays fully transparent until a mask colours a pixel.
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0
    for ann in sorted_anns:
        mask = ann['segmentation']
        # Random RGB colour plus the requested transparency.
        overlay[mask] = np.concatenate([np.random.random(3), [alpha]])
    ax.imshow(overlay)
# A hidden Tk root is needed to host the native "Open" dialog. Keep a handle
# to it so it can be destroyed afterwards instead of lingering alongside
# Matplotlib's own GUI windows for the rest of the run.
_tk_root = Tk()
_tk_root.withdraw()
print("Please select an image file:")
try:
    # Returns an empty value when the user cancels the dialog.
    image_path = askopenfilename(
        filetypes=[
            ("Image files", "*.png *.jpg *.jpeg *.bmp *.tif *.tiff *.webp"),
            ("All files", "*"),
        ],
    )
finally:
    _tk_root.destroy()
if not image_path:
    print("No file selected. Exiting...")
    sys.exit()  # user cancelled: a deliberate, non-error exit
# cv2.imread signals failure by returning None rather than raising.
image = cv2.imread(image_path)
if image is None:
    print(f"Failed to load image from '{image_path}'. Exiting...")
    sys.exit(1)  # non-zero status so callers/shells see the failure
# Downscale large images so the longer side is at most max_size pixels,
# for faster processing.
max_size = 1024
scale_ratio = min(max_size / image.shape[0], max_size / image.shape[1])
if scale_ratio < 1:  # only downscale; never upsample small images
    # cv2.resize takes (width, height). Clamp each side to >= 1 px so an
    # extreme aspect ratio cannot truncate a dimension to 0 and crash resize.
    new_size = (
        max(1, int(image.shape[1] * scale_ratio)),
        max(1, int(image.shape[0] * scale_ratio)),
    )
    image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
    print(f"Resized image to {new_size} for faster processing.")
# OpenCV loads images as BGR; convert to RGB for display and for SAM.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Preview the input image; plt.show() blocks until the window is closed.
plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()
plt.close()  # release the preview figure before the model work starts
import os  # local import: used only for the checkpoint override/check below

# Define model type, checkpoint path, and device.
# model_type must match the checkpoint's architecture (the default file name
# indicates a ViT-B checkpoint).
model_type = "vit_b"  # Adjust this to the specific model type you're using
# The SAM_CHECKPOINT environment variable, when set, overrides this
# machine-specific default so the script runs elsewhere without editing.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT",
    r"C:\Users\john\Documents\meta sam\sam_vit_b_01ec64.pth",
)
# Fail early with a clear message instead of an opaque error from the loader.
if not os.path.isfile(sam_checkpoint):
    print(
        f"SAM checkpoint not found at '{sam_checkpoint}'. "
        "Set SAM_CHECKPOINT or edit sam_checkpoint. Exiting..."
    )
    sys.exit(1)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading SAM model...")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
print("SAM model loaded.")
# Wrap the loaded model in SAM's automatic mask generator and run it over
# the whole RGB image prepared above.
mask_generator = SamAutomaticMaskGenerator(sam)
print("Generating masks...")
masks = mask_generator.generate(image)
num_masks = len(masks)
print(f"Generated {num_masks} masks.")

# Draw the image, then overlay every generated mask on the same axes.
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()