import torch
from .yolo_utils import subcore
import folder_paths


class UltralyticsDetectorProvider:
    """Node that loads an Ultralytics model and wraps it in detector objects.

    The selectable model names are the files registered under the
    ``ultralytics_bbox`` and ``ultralytics_segm`` folder keys, prefixed with
    ``bbox/`` and ``segm/`` respectively.
    """

    RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR")
    FUNCTION = "doit"
    CATEGORY = "ImpactPack"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: one required ``model_name`` choice list."""
        bbox_names = [
            "bbox/" + name
            for name in folder_paths.get_filename_list("ultralytics_bbox")
        ]
        segm_names = [
            "segm/" + name
            for name in folder_paths.get_filename_list("ultralytics_segm")
        ]
        return {"required": {"model_name": (bbox_names + segm_names, )}}

    def doit(self, model_name):
        """Load ``model_name`` and return ``(bbox_detector, segm_detector)``.

        Args:
            model_name: Entry chosen from the list built by ``INPUT_TYPES``.

        Returns:
            A 2-tuple. The first item is always ``subcore.UltraBBoxDetector``.
            The second is ``subcore.NO_SEGM_DETECTOR()`` when the name starts
            with ``"bbox"``, otherwise ``subcore.UltraSegmDetector``.

        Raises:
            ValueError: If ``folder_paths`` cannot resolve the model file. The
                directories that were searched are printed first.
        """
        model_path = folder_paths.get_full_path("ultralytics", model_name)

        if model_path is not None:
            model = subcore.load_yolo(model_path)
            bbox_detector = subcore.UltraBBoxDetector(model)
            if model_name.startswith("bbox"):
                segm_detector = subcore.NO_SEGM_DETECTOR()
            else:
                segm_detector = subcore.UltraSegmDetector(model)
            return bbox_detector, segm_detector

        # Resolution failed: tell the user where we looked before raising.
        print(f"[Impact Subpack] model file '{model_name}' is not found in one of the following directories:")

        search_dirs = list(folder_paths.get_folder_paths("ultralytics"))
        # Only the folder key matching the name's prefix is added.
        prefix_to_key = (
            ('bbox/', "ultralytics_bbox"),
            ('segm/', "ultralytics_segm"),
        )
        for prefix, folder_key in prefix_to_key:
            if model_name.startswith(prefix):
                search_dirs.extend(folder_paths.get_folder_paths(folder_key))
                break

        formatted_cands = "\n\t".join(search_dirs)
        print(f'\t{formatted_cands}\n')

        raise ValueError(f"[Impact Subpack] model file '{model_name}' is not found.")


class SegmDetectorCombined_batch:
    """Batch version of a combined-mask detector node.

    Runs ``detect_combined`` of the supplied detector on every image of the
    batch and returns the per-image masks concatenated along the batch axis.
    Either a segmentation detector or a bbox detector may be connected; when
    both are connected the segmentation detector is used.
    """

    DESCRIPTION = """
    Original plugin: Impact-Pack
    Original node: SegmDetectorCombined
    change 1: batch detection of masks
    Change 2: Supports both modes simultaneously
    原始插件：Impact-Pack
    原始节点：SegmDetectorCombined
    更改1：支持批次
    更改2：支持两个模式
    """
    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: image, threshold, dilation and two optional detectors."""
        return {"required": {
                    "image": ("IMAGE", ),
                    "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                    "dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                    },
                "optional": {
                    "bbox_detector": ("BBOX_DETECTOR", ),
                    "segm_detector": ("SEGM_DETECTOR", ),
                    }
                }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"
    # The original referenced an undefined CATEGORY_NAME (NameError on import);
    # use the same category literal as the other nodes in this file.
    CATEGORY = "ImpactPack"

    def doit(self, image, threshold, dilation, segm_detector=None, bbox_detector=None):
        """Detect a combined mask for every image in the batch.

        Args:
            image: Image tensor. A 3-D tensor is treated as a single image and
                gets a leading batch axis; the last axis is dropped when
                deriving the mask shape (assumed to be the channel axis).
            threshold: Forwarded unchanged to ``detect_combined``.
            dilation: Forwarded unchanged to ``detect_combined``.
            segm_detector: Optional detector; takes priority when given.
            bbox_detector: Optional detector; used when ``segm_detector`` is None.

        Returns:
            A 1-tuple holding the masks concatenated along axis 0, one per
            input image. Images for which the detector returns ``None`` get an
            all-zero mask. When no detector is connected, an all-zero mask for
            every image is returned and an error line is printed.
        """
        # Image preprocessing: make sure there is a batch axis.
        if image.dim() == 3:
            image = image.unsqueeze(0)
        frame_shape = tuple(image.shape[1:-1])

        # Detector selection: prefer the segmentation detector, fall back to bbox.
        detector = segm_detector if segm_detector is not None else bbox_detector
        if detector is None:
            print("Error: No detector selected, Return empty mask !")
            empty = torch.zeros((image.shape[0], *frame_shape), dtype=torch.float32, device="cpu")
            return (empty,)

        blank = torch.zeros((1, *frame_shape), dtype=torch.float32, device="cpu")

        # Run detection. The zero-length seed keeps the result well-defined for
        # an empty batch and keeps the original dtype-promotion behaviour.
        pieces = [torch.zeros((0, *frame_shape), dtype=torch.float32, device="cpu")]
        for frame in image:
            detected = detector.detect_combined(frame.unsqueeze(0), threshold, dilation)
            # NOTE(review): the detector is expected to return a mask shaped
            # like ``frame_shape`` (or None); concatenation fails otherwise.
            pieces.append(blank if detected is None else detected.unsqueeze(0))

        # Single concatenation instead of growing the tensor inside the loop.
        return (torch.cat(pieces, dim=0),)


class run_yolo_bboxs:
    """Node that runs a bbox detector over an image batch and lists the hits.

    For every image the detector's model is run through
    ``subcore.inference_bbox`` and each detection is reported as a dict with
    its label, box coordinates and confidence.
    """

    DESCRIPTION = """
    使用YOLO模型检测图像序列中的目标，并输出边界框详细信息
    功能：运行yolo模型输入图像序列输出bboxs支持批量处理图像
    """
    RETURN_TYPES = ("bboxs",)
    RETURN_NAMES = ("bboxs",)
    FUNCTION = "doit"
    CATEGORY = "ImpactPack"

    @classmethod
    def INPUT_TYPES(s):
        """Return the input spec: image, threshold and a bbox detector."""
        required = {
            "image": ("IMAGE", ),
            "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "bbox_detector": ("BBOX_DETECTOR", ),
        }
        return {"required": required}

    def doit(self, image, threshold, bbox_detector):
        """Detect boxes on each image and return them grouped per image.

        Args:
            image: Image tensor; a 3-D tensor is treated as a single image.
            threshold: Forwarded unchanged to ``subcore.inference_bbox``.
            bbox_detector: Object exposing the model as ``bbox_model``.

        Returns:
            A 1-tuple holding a list with one entry per input image. Each
            entry is a list of dicts with the keys ``"label"``, ``"bbox"``
            and ``"confidence"`` (empty when nothing was detected).
        """
        # Image preprocessing: make sure there is a batch axis.
        batch = image.unsqueeze(0) if image.dim() == 3 else image

        per_image = []  # detection lists, in input order
        for frame in batch:
            # The single frame is converted to PIL before inference.
            results = subcore.inference_bbox(
                bbox_detector.bbox_model,
                subcore.utils.tensor2pil(frame.unsqueeze(0)),
                threshold,
            )
            # results[0] = labels, results[1] = boxes, results[3] = confidences.
            # NOTE(review): the coordinate order of each box is whatever
            # subcore.inference_bbox produces — confirm before relying on it.
            frame_boxes = [
                {
                    "label": results[0][idx],
                    "bbox": box.tolist(),
                    "confidence": float(results[3][idx]),
                }
                for idx, box in enumerate(results[1])
            ]
            per_image.append(frame_boxes)
        return (per_image,)