from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
import torch
import random
import numpy as np
from typing import Dict, Any

# 基礎類別 - 共用核心功能
class BaseLLMNode:
    """Shared core for the LLM nodes: model/tokenizer caching and generation.

    Each instance keeps its own caches, keyed by the ``model_id`` string that
    was passed to :meth:`_load_model`.
    """

    def __init__(self):
        # model_id -> loaded causal language model
        self.model_cache: Dict[str, Any] = {}
        # model_id -> tokenizer loaded alongside that model
        self.tokenizer_cache: Dict[str, Any] = {}

    def _load_model(self, model_id: str):
        """Return ``(model, tokenizer)`` for *model_id*, loading on first use.

        Args:
            model_id: Identifier or path handed to ``from_pretrained``.

        Returns:
            The cached ``(model, tokenizer)`` pair for ``model_id``.

        Raises:
            Whatever ``from_pretrained`` raises; nothing is cached in that case.
        """
        if model_id not in self.model_cache:
            print(f"Loading model: {model_id}")

            # NOTE(security): trust_remote_code=True lets the checkpoint run
            # its own Python code at load time; only use model ids you trust.
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
                local_files_only=False,
            )

            # Generation needs a pad token; reuse EOS when the tokenizer
            # defines none, otherwise register a dedicated one.
            if tokenizer.pad_token is None:
                if tokenizer.eos_token:
                    tokenizer.pad_token = tokenizer.eos_token
                else:
                    # NOTE(review): a newly added token may need the model's
                    # embeddings resized before its id is fed to the model;
                    # confirm for checkpoints that reach this branch.
                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                local_files_only=False,
            )

            self.model_cache[model_id] = model
            self.tokenizer_cache[model_id] = tokenizer
            print("Model loaded successfully")

        return self.model_cache[model_id], self.tokenizer_cache[model_id]

    def _generate(self, model, tokenizer, prompt, max_new_tokens, temperature, seed):
        """Generate a completion for *prompt* and return only the new text.

        Args:
            model: Object exposing ``device`` and ``generate``.
            tokenizer: Callable tokenizer exposing ``decode``, ``pad_token_id``
                and ``eos_token_id``.
            prompt: Fully formatted prompt string.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature; values ``<= 0`` disable sampling.
            seed: Seed applied to torch, ``random`` and NumPy before generating.

        Returns:
            The decoded continuation (prompt tokens removed), whitespace-stripped.
        """
        # Seed every RNG in play so repeated runs with one seed match.
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        # Calling the tokenizer (rather than .encode) also yields the
        # attention mask, which generate() cannot infer when pad == eos.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(model.device)
        attention_mask = encoded.get("attention_mask")

        # Compare against None explicitly: 0 is a valid pad token id and must
        # not be replaced by the EOS id.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        sampling = temperature > 0
        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": sampling,
            "pad_token_id": pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask.to(model.device)
        if sampling:
            # Temperature only applies to sampling; omit it for greedy decoding.
            generate_kwargs["temperature"] = temperature

        with torch.no_grad():
            outputs = model.generate(input_ids, **generate_kwargs)

        # Drop the echoed prompt tokens and keep only the continuation.
        new_tokens = outputs[0][input_ids.shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Llama 節點 (Llama 2/3, Code Llama, Alpaca, Vicuna)
class LlamaLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Llama-style checkpoints.

    Supports three prompt layouts (``llama_instruct``, ``alpaca``, ``vicuna``)
    and exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Llama"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "Tell me about machine learning."})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/llama-2-7b-chat"})
        required["template_style"] = (["llama_instruct", "alpaca", "vicuna"], {"default": "llama_instruct"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``llama_instruct``.
        """
        if template_style == "alpaca":
            return f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{system_prompt}\n\n{user_prompt}\n\n### Response:\n"
        if template_style == "vicuna":
            return f"A chat between a curious user and an artificial assistant.\n\nUSER: {system_prompt}\n\n{user_prompt}\nASSISTANT:"
        # NOTE(review): the literal "<s>" may duplicate a BOS token that the
        # tokenizer adds on its own; confirm against the tokenizer in use.
        return f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt} [/INST]"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Mistral 節點 (Mistral 7B, Mixtral)
class MistralLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Mistral-style checkpoints.

    Supports the ``mistral_instruct`` and ``mistral_simple`` prompt layouts
    and exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Mistral"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "Explain quantum computing."})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/mistral-7b-instruct"})
        required["template_style"] = (["mistral_instruct", "mistral_simple"], {"default": "mistral_instruct"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``mistral_instruct``.
        """
        if template_style == "mistral_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\nResponse:"
        # NOTE(review): the literal "<s>" may duplicate a BOS token that the
        # tokenizer adds on its own; confirm against the tokenizer in use.
        return f"<s>[INST] {system_prompt}\n\n{user_prompt} [/INST]"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Phi 節點 (Phi-3, Phi-4)
class PhiLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Phi-style checkpoints.

    Supports the ``phi_instruct`` and ``phi_simple`` prompt layouts and
    exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Phi"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "What is the meaning of life?"})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/phi-4-mini"})
        required["template_style"] = (["phi_instruct", "phi_simple"], {"default": "phi_instruct"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``phi_instruct``.
        """
        if template_style == "phi_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\nAnswer:"
        return f"<|user|>\n{system_prompt}\n\n{user_prompt}<|end|>\n<|assistant|>"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Qwen 節點 (Qwen 1.5, Qwen 2)
class QwenLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Qwen-style checkpoints.

    Supports the ``qwen_chatml`` and ``qwen_simple`` prompt layouts and
    exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Qwen"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "介紹一下人工智能的發展歷史"})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "你是一個樂於助人的助手。"})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/qwen-7b-chat"})
        required["template_style"] = (["qwen_chatml", "qwen_simple"], {"default": "qwen_chatml"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``qwen_chatml``.
        """
        if template_style == "qwen_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\n回答："
        return f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Gemma 節點 (Gemma 2B, 7B)
class GemmaLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Gemma-style checkpoints.

    Supports the ``gemma_chat`` and ``gemma_simple`` prompt layouts and
    exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Gemma"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "Explain the concept of neural networks."})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/gemma-7b-it"})
        required["template_style"] = (["gemma_chat", "gemma_simple"], {"default": "gemma_chat"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``gemma_chat``.
        """
        if template_style == "gemma_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\nResponse:"
        return f"<start_of_turn>user\n{system_prompt}\n\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Yi 節點 (Yi-6B, Yi-34B)
class YiLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Yi-style checkpoints.

    Supports the ``yi_chatml`` and ``yi_simple`` prompt layouts and exposes
    a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Yi"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "請介紹深度學習的基本概念"})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "你是一個AI助手，能夠提供準確和有幫助的回答。"})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/yi-6b-chat"})
        required["template_style"] = (["yi_chatml", "yi_simple"], {"default": "yi_chatml"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``yi_chatml``.
        """
        if template_style == "yi_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\n回答："
        return f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# ComfyUI 節點註冊
# Node registry: each node class is keyed by its own class name.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        LlamaLLMNode,
        MistralLLMNode,
        PhiLLMNode,
        QwenLLMNode,
        GemmaLLMNode,
        YiLLMNode,
    )
}

# Display titles for the nodes, keyed by the same identifiers used in
# NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "LlamaLLMNode": "🦙 Llama LLM Node",
    "MistralLLMNode": "🌪️ Mistral LLM Node",
    "PhiLLMNode": "⚡ Phi LLM Node",
    "QwenLLMNode": "🇨🇳 Qwen LLM Node",
    "GemmaLLMNode": "💎 Gemma LLM Node",
    "YiLLMNode": "🔥 Yi LLM Node",
}

# NOTE: the definitions below rely on AutoTokenizer, AutoModelForCausalLM,
# torch, random, np, Dict and Any, all of which are imported once at the top
# of this file. A second import statement must never share a line with the
# closing brace of the dict above -- "}from ..." is a SyntaxError.

# 基礎類別 - 共用核心功能
class BaseLLMNode:
    """Shared core for the LLM nodes: model/tokenizer caching and generation.

    Each instance keeps its own caches, keyed by the ``model_id`` string that
    was passed to :meth:`_load_model`.
    """

    def __init__(self):
        # model_id -> loaded causal language model
        self.model_cache: Dict[str, Any] = {}
        # model_id -> tokenizer loaded alongside that model
        self.tokenizer_cache: Dict[str, Any] = {}

    def _load_model(self, model_id: str):
        """Return ``(model, tokenizer)`` for *model_id*, loading on first use.

        Args:
            model_id: Identifier or path handed to ``from_pretrained``.

        Returns:
            The cached ``(model, tokenizer)`` pair for ``model_id``.

        Raises:
            Whatever ``from_pretrained`` raises; nothing is cached in that case.
        """
        if model_id not in self.model_cache:
            print(f"Loading model: {model_id}")

            # NOTE(security): trust_remote_code=True lets the checkpoint run
            # its own Python code at load time; only use model ids you trust.
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
                local_files_only=False,
            )

            # Generation needs a pad token; reuse EOS when the tokenizer
            # defines none, otherwise register a dedicated one.
            if tokenizer.pad_token is None:
                if tokenizer.eos_token:
                    tokenizer.pad_token = tokenizer.eos_token
                else:
                    # NOTE(review): a newly added token may need the model's
                    # embeddings resized before its id is fed to the model;
                    # confirm for checkpoints that reach this branch.
                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                local_files_only=False,
            )

            self.model_cache[model_id] = model
            self.tokenizer_cache[model_id] = tokenizer
            print("Model loaded successfully")

        return self.model_cache[model_id], self.tokenizer_cache[model_id]

    def _generate(self, model, tokenizer, prompt, max_new_tokens, temperature, seed):
        """Generate a completion for *prompt* and return only the new text.

        Args:
            model: Object exposing ``device`` and ``generate``.
            tokenizer: Callable tokenizer exposing ``decode``, ``pad_token_id``
                and ``eos_token_id``.
            prompt: Fully formatted prompt string.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature; values ``<= 0`` disable sampling.
            seed: Seed applied to torch, ``random`` and NumPy before generating.

        Returns:
            The decoded continuation (prompt tokens removed), whitespace-stripped.
        """
        # Seed every RNG in play so repeated runs with one seed match.
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        # Calling the tokenizer (rather than .encode) also yields the
        # attention mask, which generate() cannot infer when pad == eos.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(model.device)
        attention_mask = encoded.get("attention_mask")

        # Compare against None explicitly: 0 is a valid pad token id and must
        # not be replaced by the EOS id.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        sampling = temperature > 0
        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": sampling,
            "pad_token_id": pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask.to(model.device)
        if sampling:
            # Temperature only applies to sampling; omit it for greedy decoding.
            generate_kwargs["temperature"] = temperature

        with torch.no_grad():
            outputs = model.generate(input_ids, **generate_kwargs)

        # Drop the echoed prompt tokens and keep only the continuation.
        new_tokens = outputs[0][input_ids.shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Llama 節點 (Llama 2/3, Code Llama, Alpaca, Vicuna)
class LlamaLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Llama-style checkpoints.

    Supports three prompt layouts (``llama_instruct``, ``alpaca``, ``vicuna``)
    and exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Llama"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "Tell me about machine learning."})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/llama-2-7b-chat"})
        required["template_style"] = (["llama_instruct", "alpaca", "vicuna"], {"default": "llama_instruct"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``llama_instruct``.
        """
        if template_style == "alpaca":
            return f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{system_prompt}\n\n{user_prompt}\n\n### Response:\n"
        if template_style == "vicuna":
            return f"A chat between a curious user and an artificial assistant.\n\nUSER: {system_prompt}\n\n{user_prompt}\nASSISTANT:"
        # NOTE(review): the literal "<s>" may duplicate a BOS token that the
        # tokenizer adds on its own; confirm against the tokenizer in use.
        return f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt} [/INST]"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Mistral 節點 (Mistral 7B, Mixtral)
class MistralLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Mistral-style checkpoints.

    Supports the ``mistral_instruct`` and ``mistral_simple`` prompt layouts
    and exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Mistral"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "Explain quantum computing."})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/mistral-7b-instruct"})
        required["template_style"] = (["mistral_instruct", "mistral_simple"], {"default": "mistral_instruct"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``mistral_instruct``.
        """
        if template_style == "mistral_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\nResponse:"
        # NOTE(review): the literal "<s>" may duplicate a BOS token that the
        # tokenizer adds on its own; confirm against the tokenizer in use.
        return f"<s>[INST] {system_prompt}\n\n{user_prompt} [/INST]"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Phi 節點 (Phi-3, Phi-4)
class PhiLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Phi-style checkpoints.

    Supports the ``phi_instruct`` and ``phi_simple`` prompt layouts and
    exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Phi"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "What is the meaning of life?"})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/phi-4-mini"})
        required["template_style"] = (["phi_instruct", "phi_simple"], {"default": "phi_instruct"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``phi_instruct``.
        """
        if template_style == "phi_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\nAnswer:"
        return f"<|user|>\n{system_prompt}\n\n{user_prompt}<|end|>\n<|assistant|>"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Qwen 節點 (Qwen 1.5, Qwen 2)
class QwenLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Qwen-style checkpoints.

    Supports the ``qwen_chatml`` and ``qwen_simple`` prompt layouts and
    exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Qwen"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "介紹一下人工智能的發展歷史"})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "你是一個樂於助人的助手。"})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/qwen-7b-chat"})
        required["template_style"] = (["qwen_chatml", "qwen_simple"], {"default": "qwen_chatml"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``qwen_chatml``.
        """
        if template_style == "qwen_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\n回答："
        return f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Gemma 節點 (Gemma 2B, 7B)
class GemmaLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Gemma-style checkpoints.

    Supports the ``gemma_chat`` and ``gemma_simple`` prompt layouts and
    exposes a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Gemma"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "Explain the concept of neural networks."})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "You are a helpful assistant."})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/gemma-7b-it"})
        required["template_style"] = (["gemma_chat", "gemma_simple"], {"default": "gemma_chat"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``gemma_chat``.
        """
        if template_style == "gemma_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\nResponse:"
        return f"<start_of_turn>user\n{system_prompt}\n\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# Yi 節點 (Yi-6B, Yi-34B)
class YiLLMNode(BaseLLMNode):
    """ComfyUI text-generation node for Yi-style checkpoints.

    Supports the ``yi_chatml`` and ``yi_simple`` prompt layouts and exposes
    a single STRING output named ``generated_text``.
    """

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_text",)
    FUNCTION = "infer"
    CATEGORY = "MyNodes/LLM/Yi"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every field is listed under ``required``."""
        required = {}
        required["prompt"] = ("STRING", {"multiline": True, "default": "請介紹深度學習的基本概念"})
        required["system_prompt"] = ("STRING", {"multiline": True, "default": "你是一個AI助手，能夠提供準確和有幫助的回答。"})
        required["model_id"] = ("STRING", {"multiline": False, "default": "/workspace/ComfyUI/models/yi-6b-chat"})
        required["template_style"] = (["yi_chatml", "yi_simple"], {"default": "yi_chatml"})
        required["max_new_tokens"] = ("INT", {"default": 256, "min": 1, "max": 4096})
        required["temperature"] = ("FLOAT", {"default": 0.7, "min": 0.01, "max": 2.0})
        required["seed"] = ("INT", {"default": 42, "min": 1, "max": 2147483647})
        return {"required": required}

    def _build_prompt(self, system_prompt, user_prompt, template_style):
        """Wrap the system and user text in the selected template.

        Any unrecognised ``template_style`` falls back to ``yi_chatml``.
        """
        if template_style == "yi_simple":
            return f"{system_prompt}\n\n{user_prompt}\n\n回答："
        return f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

    def infer(self, prompt, system_prompt, model_id, template_style, max_new_tokens, temperature, seed):
        """Load the model, build the prompt and generate.

        Returns:
            A 1-tuple holding the generated text, or ``"Error: <message>"``
            when any step raises.
        """
        try:
            model, tokenizer = self._load_model(model_id)
            text = self._generate(
                model,
                tokenizer,
                self._build_prompt(system_prompt, prompt, template_style),
                max_new_tokens,
                temperature,
                seed,
            )
        except Exception as exc:
            # Failures are reported through the output socket, not raised.
            return ("Error: " + str(exc),)
        return (text,)

# ComfyUI 節點註冊
# Node registry: each node class is keyed by its own class name.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        LlamaLLMNode,
        MistralLLMNode,
        PhiLLMNode,
        QwenLLMNode,
        GemmaLLMNode,
        YiLLMNode,
    )
}

# Display titles for the nodes, keyed by the same identifiers used in
# NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    LlamaLLMNode="🦙 Llama LLM Node",
    MistralLLMNode="🌪️ Mistral LLM Node",
    PhiLLMNode="⚡ Phi LLM Node",
    QwenLLMNode="🇨🇳 Qwen LLM Node",
    GemmaLLMNode="💎 Gemma LLM Node",
    YiLLMNode="🔥 Yi LLM Node",
)