Better, Stronger, Faster: Tackling the Trilemma in MLLM-based Segmentation with Simultaneous Textual Mask Prediction
Paper: arXiv:2512.00395
STAMP is a Multimodal Large Language Model (MLLM) that performs dialogue and segmentation simultaneously: a single generation call returns both the text response and the predicted masks. By resolving the conflict between text generation and mask prediction, it achieves both high segmentation performance and fast inference.
Note: This model relies on the codebase and custom architecture defined in the GitHub repository. You must clone the repository to run inference.
Clone the repository and install the required dependencies:
```bash
git clone https://github.com/HKUST-LongGroup/STAMP.git
cd STAMP

# Create environment (recommended)
conda create -n STAMP python=3.10
conda activate STAMP

# Install dependencies
pip install -r requirements.txt
pip install flash-attn --no-build-isolation

# Download the SAM-H checkpoint to YOUR_SAM_PATH
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```
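With the environment set up, a minimal sketch shows what "simultaneous" means in practice: one `generate_with_segmentation` call returns both the predicted masks and the dialogue text. This assumes you run it from the STAMP directory so the repository's `GenerativeSegmenter` wrapper is importable; the full SAM-refined pipeline appears further below.

```python
from PIL import Image
from segment_predictor_cache import GenerativeSegmenter  # from the STAMP repository

# Load STAMP (arguments mirror the full script below)
segmenter = GenerativeSegmenter(
    "JiaZL/STAMP-2B-uni",
    device_map="cuda",
    min_pixels=1024 * 28 * 28,
    max_pixels=1280 * 28 * 28,
)

image = Image.open("images/horses.png").convert("RGB")

# A single call yields both outputs: a list of coarse mask tensors and the text response
masks, text = segmenter.generate_with_segmentation(
    image, "Please segment the white horse in the image."
)
print(text)
print(f"{len(masks)} mask(s) predicted")
```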
Run referring segmentation from the command line, pointing `--sam_path` at your SAM-H checkpoint:

```bash
# Make sure you are in the STAMP directory
python run_seg_ref.py \
    --model-path "JiaZL/STAMP-2B-uni" \
    --image-file "images/horses.png" \
    --sam_path "HCMUE-Research/SAM-vit-h/sam_vit_h_4b8939.pth" \
    --query "Please segment the white horse in the image."
```
Alternatively, run the full pipeline programmatically. The script below generates a coarse mask with STAMP and then refines it with SAM:

```python
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Local modules from the STAMP repository
from segment_predictor_cache import GenerativeSegmenter
from model.segment_anything import sam_model_registry, SamPredictor
# Utility functions for SAM prompt generation
from eval.utils import compute_logits_from_mask, masks_sample_points

# --- Configuration ---
MODEL_PATH = "JiaZL/STAMP-2B-uni"
SAM_PATH = "HCMUE-Research/SAM-vit-h/sam_vit_h_4b8939.pth"
IMAGE_PATH = "images/horses.png"
QUERY = "Please segment the white horse in the image."
USE_SAM = True  # Enable SAM refinement (recommended)

# --- Load Models ---
print(f"Loading STAMP model from {MODEL_PATH}...")
segmenter = GenerativeSegmenter(
    MODEL_PATH,
    device_map="cuda",
    min_pixels=1024 * 28 * 28,
    max_pixels=1280 * 28 * 28,
)

print(f"Loading SAM model from {SAM_PATH}...")
sam = sam_model_registry["vit_h"](checkpoint=SAM_PATH)
sam = sam.to(dtype=torch.float32, device="cuda")
predictor = SamPredictor(sam)

# --- Inference ---
image = Image.open(IMAGE_PATH).convert("RGB")
w_ori, h_ori = image.size

with torch.inference_mode():
    # 1. Set the SAM image embedding (computed once for efficiency)
    if USE_SAM:
        predictor.set_image(np.array(image))

    # 2. Generate a coarse mask with STAMP
    print("Generating coarse mask with STAMP...")
    segmentation_masks, response_text = segmenter.generate_with_segmentation(image, QUERY)
    print(f"Model Response: {response_text}")

    if not segmentation_masks:
        print("No mask generated.")
        exit()

    # Extract the first mask and resize it to the original image size [H, W]
    mask = segmentation_masks[0]
    mask_pred = F.interpolate(
        mask.unsqueeze(0).unsqueeze(0).double(),
        size=(h_ori, w_ori),
        mode="nearest",
    ).squeeze(0).squeeze(0)

    # --- SAM Refinement ---
    final_mask = np.zeros((h_ori, w_ori), dtype=np.float32)
    if USE_SAM:
        print("Refining mask with SAM...")
        # Iterate over all class IDs in the coarse mask (0 is background)
        unique_classes = torch.unique(mask_pred)
        for class_id in unique_classes:
            if class_id == 0:
                continue
            # Binary mask for the current class
            binary_mask = (mask_pred == class_id).double().cpu()
            try:
                # Generate prompts (logits and points) from the coarse mask
                logits = compute_logits_from_mask(binary_mask)
                point_coords, point_labels = masks_sample_points(binary_mask)
                # First-pass prediction
                sam_mask, _, logit = predictor.predict(
                    point_coords=point_coords,
                    point_labels=point_labels,
                    mask_input=logits,
                    multimask_output=False,
                )
                # Iterative refinement (standard cascade: feed the previous
                # logits back in for two more passes)
                for _ in range(2):
                    sam_mask, _, logit = predictor.predict(
                        point_coords=point_coords,
                        point_labels=point_labels,
                        mask_input=logit,
                        multimask_output=False,
                    )
                # Merge the refined mask into the final result
                current_refined_mask = sam_mask[0].astype(np.float32)
                final_mask = np.maximum(final_mask, current_refined_mask)
            except Exception as e:
                print(f"SAM error for class {class_id}: {e}")
                # Fall back to the coarse mask if SAM fails
                final_mask = np.maximum(final_mask, binary_mask.numpy().astype(np.float32))
    else:
        # Use the coarse mask directly if SAM is disabled
        final_mask = mask_pred.cpu().numpy().astype(np.float32)

# --- Save Result ---
# Convert to 0-255 uint8 format for saving
mask_uint8 = (final_mask > 0).astype(np.uint8) * 255
base_name = os.path.splitext(os.path.basename(IMAGE_PATH))[0]
save_name = f"{base_name}_mask_refined.png"
cv2.imwrite(save_name, mask_uint8)
print(f"Saved refined mask to {save_name}")
```
Citation
If you find this work useful, please cite our paper:
```bibtex
@misc{liu2025betterstrongerfastertackling,
      title={Better, Stronger, Faster: Tackling the Trilemma in MLLM-based Segmentation with Simultaneous Textual Mask Prediction},
      author={Jiazhen Liu and Mingkuan Feng and Long Chen},
      year={2025},
      eprint={2512.00395},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2512.00395},
}
```