from pathlib import Path
import os, sys, json, torch, importlib

from huggingface_hub import snapshot_download
from open_clip.factory import _MODEL_CONFIGS
from open_clip import create_model_and_transforms, get_tokenizer, build_zero_shot_classifier

import safetensors.torch as st
from torchvision.transforms import Normalize
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def loader(study_path: str, num_slices: int):
    """
    study_path: folder containing per-slice tensors saved with torch.save()
                each file is a [C, H, W] or [C, H, W, 1] tensor in [0, 255]
    returns:   image tensor of shape [1, n_scans, 1, D, H, W]
    """
    imgs = []
    for scan in [os.path.join(study_path, p) for p in os.listdir(study_path)]:
        # load image tensor
        img = torch.load(scan, weights_only=True)
        if len(img.shape) == 4:
            # [C, H, W, 1] -> [C, H, W]
            img = img[:, :, :, 0]
        img = img.float() / 255.0  # [C, H, W]
        _, h, w = img.shape

        # pad to square
        size = max(h, w)
        pad_h = size - h
        pad_w = size - w
        left = pad_w // 2
        right = pad_w - left
        top = pad_h // 2
        bottom = pad_h - top
        img = torch.nn.functional.pad(
            img, (left, right, top, bottom), mode="constant", value=0
        )

        # resize to 256, make depth=num_slices, center-crop to 224
        img = torch.nn.functional.interpolate(
            img[None, ...], size=(256, 256), mode="bilinear"
        )[0]
        img = torch.nn.functional.interpolate(
            img[None, None, ...], size=(num_slices, 256, 256), mode="nearest-exact"
        )[0, 0]
        img = img[:, 16:240, 16:240]  # [D, 224, 224]

        # normalize (scalar mean/std across slices-as-channels)
        normalizer = Normalize(
            torch.as_tensor(IMAGENET_DEFAULT_MEAN).mean(),
            torch.as_tensor(IMAGENET_DEFAULT_STD).mean(),
        )
        img = normalizer(img[None, ...])  # [1, D, H, W]
        imgs.append(img)

    # [1, n_scans, 1, D, H, W]
    return torch.stack(imgs, dim=0)[None, ...]


# ---- constants ----
REPO_ID = "Zch0414/hlip-2025-10-08"
MODEL_NAME = "ablate_seqposemb_vit_base_multiscan_h2_dinotxt1568"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -------------------

# 1) download snapshot and make vendored package importable
repo_dir = Path(snapshot_download(repo_id=REPO_ID))
sys.path.append(str(repo_dir))
importlib.invalidate_caches()
print(f"[OK] repo_dir = {repo_dir}")

# 2) import your registry so timm/OpenCLIP knows the custom visual encoder
import hlip.visual_encoder  # registers the custom visual encoder with timm

# 3) load OpenCLIP config and register it under MODEL_NAME
cfg = json.loads((repo_dir / "open_clip_config.json").read_text())
_MODEL_CONFIGS[MODEL_NAME] = cfg["model_cfg"]
print("[OK] registered MODEL_CONFIGS key:", MODEL_NAME)

# 4) build model and tokenizer
model, _, _ = create_model_and_transforms(
    MODEL_NAME,
    device=DEVICE,
    output_dict=True,
)
tokenizer = get_tokenizer(MODEL_NAME)
print("[OK] model built on", DEVICE)
print("[OK] tokenizer ready")

# 5) load pretrained weights from the snapshot (prefer safetensors)
weight_path = None
for fname in ("model.safetensors", "pytorch_model.bin"):
    p = repo_dir / fname
    if p.exists():
        weight_path = p
        break
assert weight_path is not None, "No weights found in repo snapshot."

if weight_path.suffix == ".safetensors":
    state_dict = st.load_file(str(weight_path))
else:
    state_dict = torch.load(str(weight_path), map_location="cpu")

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(
    f"[OK] loaded weights: {weight_path.name} | "
    f"missing={len(missing)} unexpected={len(unexpected)}"
)

# 6) build zero-shot classifier for brain MRI labels
from hlip.zeroshot_metadata_pubbrain5 import PROMPTS, TEMPLATES

classifier = build_zero_shot_classifier(
    model,
    tokenizer=tokenizer,
    classnames=PROMPTS["prompt"],
    templates=TEMPLATES["template"],
    num_classes_per_batch=None,  # use all classes
    device=DEVICE,
    use_tqdm=False,
)

# 7) example data and inference
# Here we use an example study stored inside the repo under docs/BraTS-Glioma
# Replace this with your own study folder of per-slice tensors.
study_dir = repo_dir / "docs" / "BraTS-Glioma"
image = loader(str(study_dir), num_slices=48).to(DEVICE, non_blocking=True)

model.eval()
with torch.no_grad():
    output = model(image=image)  # image: [1, n_scans, 1, D, H, W]
    image_features = output["image_features"]  # [1, feature_dim]

    # use the model's learned logit scale
    logit_scale = model.logit_scale.exp()
    logits_per_image = logit_scale * (image_features @ classifier)  # [1, num_classes]
    probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()

print("Zero-shot class probabilities:")
for i, prompt in enumerate(PROMPTS["prompt"]):
    print(f"{prompt}: {probs[0, i]:.4f}")

# output:
# no significant abnormalities: 0.0001
# acute stroke: 0.0114
# glioma: 0.9860
# meningioma: 0.0022
# metastasis: 0.0003
Downloads last month
3
Inference Providers NEW
This model isn't deployed by any Inference Provider. ๐Ÿ™‹ Ask for provider support