import torch

seed = 0

def generate_text(model_data, input_text, max_new_token):
    """
    Generate text using the given model and tokenizer.
    """
    if "pipeline" in model_data:
        # اگر مدل از pipeline پشتیبانی می‌کند
        model_pipeline = model_data["pipeline"]
        generated_text = model_pipeline(
            input_text,
            max_length=max_new_token + len(input_text.split()),  # افزایش max_length
            do_sample=False,  # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
            truncation=True,  # فعال کردن truncation
            repetition_penalty=1.5,
            no_repeat_ngram_size=3,
        )[0]["generated_text"]
        return generated_text
    else:
        # روش قدیمی برای مدل‌هایی که از pipeline پشتیبانی نمی‌کنند
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        encodings = tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,  # فعال کردن truncation
            max_length=512
        )
        input_ids = encodings.input_ids
        attention_mask = encodings.attention_mask

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_token,  # استفاده از max_new_tokens
            do_sample=False,  # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.5,
            no_repeat_ngram_size=3,
        )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_code(model_data, prompt, max_new_tokens):
    """
    Generate code based on the provided prompt using a code-specific model.
    """
    model = model_data["model"]
    tokenizer = model_data["tokenizer"]

    # تنظیم seed برای خروجی ثابت
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # توکنایز کردن ورودی
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # ایجاد attention mask
    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)

    # تولید کد
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # غیرفعال کردن نمونه‌گیری (حالت حریصانه)
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.5,
        no_repeat_ngram_size=3,
    )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)