class Load_CLIP_Vision_With_Projection:
    """Loader node for a CLIP vision model with a projection head.

    The class attributes below (input spec, return types/names, entry-point
    name and category) describe the node to the host graph framework; the
    layout looks like the ComfyUI custom-node convention.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: one required single-line string (the weights folder)."""
        folder_widget = (
            "STRING",
            {"default": "./pretrained_weights/image_encoder/", "multiline": False},
        )
        required_inputs = {"pretrained_clip_vision_folder_path": folder_widget}
        return {"required": required_inputs}

    RETURN_TYPES = ("CLIP_VISION",)
    RETURN_NAMES = ("clip_vision",)
    # Name of the method invoked as the node's entry point.
    FUNCTION = "load_clip_vision"

    CATEGORY = "AnimateAnyone-Evolved/loaders"

    def load_clip_vision(self, pretrained_clip_vision_folder_path):
        """Load the pretrained model and move it to the configured dtype/device.

        Args:
            pretrained_clip_vision_folder_path: Folder holding the pretrained
                weights. A relative path is resolved against ``ROOT_PATH``;
                an absolute path is used unchanged.

        Returns:
            A 1-tuple containing the loaded model.
        """
        requested_dir = pretrained_clip_vision_folder_path
        if os.path.isabs(requested_dir):
            weights_dir = requested_dir
        else:
            # Relative paths are anchored at the project root, not the CWD.
            weights_dir = os.path.join(ROOT_PATH, requested_dir)

        model = CLIPVisionModelWithProjection.from_pretrained(weights_dir)
        # WEIGHT_DETYPE (sic) and DEVICE are module-level settings defined elsewhere.
        model = model.to(dtype=WEIGHT_DETYPE, device=DEVICE)
        return (model,)
    
class CLIP_Vision_With_Projection_Encode:
    """Node that encodes a reference image into projected CLIP image embeddings.

    The class attributes below (input spec, return types/names, entry-point
    name, category and output-node flag) describe the node to the host graph
    framework; the layout looks like the ComfyUI custom-node convention.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: a CLIP vision model and a reference image."""
        return {
            "required": {
                "clip_vision": ("CLIP_VISION",),
                "reference_image": ("IMAGE",),
            },
        }

    RETURN_TYPES = (
        "CLIP_VISION_OUTPUT",
    )
    RETURN_NAMES = (
        "clip_vision_output",
    )
    # Name of the method invoked as the node's entry point.
    FUNCTION = "clip_vision_encode"

    CATEGORY = "AnimateAnyone-Evolved/processors"
    OUTPUT_NODE = True

    def clip_vision_encode(self, clip_vision, reference_image):
        """Preprocess ``reference_image`` and return its projected CLIP embeddings.

        Side effect: overwrites ``_Test_CLIP_VISION_Projected_OUTPUT.txt`` in
        the current working directory with the shape and values of the
        embeddings.

        Args:
            clip_vision: Callable vision model exposing ``dtype`` and returning
                an object with an ``image_embeds`` attribute.
            reference_image: Image input handed to ``CLIPImageProcessor``.

        Returns:
            A 1-tuple containing the ``image_embeds`` tensor.

        Raises:
            OSError: If the dump file cannot be written.
        """
        clip_image_processor = CLIPImageProcessor()

        # NOTE(review): if reference_image is already a float image in [0, 1],
        # the processor's default rescaling may be applied on top of it --
        # confirm against the caller.
        clip_image = clip_image_processor.preprocess(
            reference_image,
            do_resize=True,
            size={"height": CONFIG.clip_img_height, "width": CONFIG.clip_img_width},
            return_tensors="pt",
        ).pixel_values

        # Pure inference: disable autograd so the returned embeddings do not
        # keep the forward graph (and its activations) alive.
        with torch.no_grad():
            clip_image_embeds = clip_vision(
                clip_image.to(DEVICE, dtype=clip_vision.dtype)
            ).image_embeds
        # Shape observed by the original author for one reference image:
        # clip_image_embeds.shape == torch.Size([1, 768]).

        # Explicit encoding keeps the dump independent of the platform locale.
        with open("_Test_CLIP_VISION_Projected_OUTPUT.txt", "w", encoding="utf-8") as dump_file:
            dump_file.write(f"[CLIP_VISION_Projected_OUTPUT]: \nclip_image_embeds shape: {clip_image_embeds.shape}\nclip_image_embeds: {clip_image_embeds}")

        return (clip_image_embeds,)