from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
import gradio as gr
from transformers import pipeline
import logging

# Show INFO-level log messages (including those emitted by the libraries below).
logging.basicConfig(level=logging.INFO)

# Swahili text corpus from the Hugging Face Hub.
dataset = load_dataset("mwitiderrick/swahili")

# Sanity check: preprocessing below reads the 'text' column.
print(f"Dataset columns: {dataset['train'].column_names}")


def _load_gpt2(checkpoint):
    """Load a GPT-2 tokenizer/model pair, reusing EOS as the padding token.

    The GPT-2 tokenizer has no padding token by default, and the fixed-length
    padding used during preprocessing requires one.
    """
    tok = GPT2Tokenizer.from_pretrained(checkpoint)
    lm = GPT2LMHeadModel.from_pretrained(checkpoint)
    tok.pad_token = tok.eos_token
    tok.pad_token_id = tok.eos_token_id
    return tok, lm


# GPT-2 checkpoint used as the starting point for fine-tuning.
model_name = "gpt2"
tokenizer, model = _load_gpt2(model_name)

# Preprocess the dataset
def preprocess_function(examples):
    """Tokenize a batch of raw text for causal-LM fine-tuning.

    Intended for ``dataset.map(..., batched=True)``, so ``examples['text']`` is
    a list of strings. Every string is truncated/padded to exactly 512 tokens.

    Args:
        examples: A batch mapping that contains a ``'text'`` column.

    Returns:
        The tokenizer encodings (``input_ids``, ``attention_mask``) plus a
        ``labels`` column. Labels copy ``input_ids`` except at padding
        positions, which are set to -100 so the language-modelling loss
        ignores them.
    """
    encodings = tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',  # every example gets the same length
        max_length=512,
    )
    # The pad token is the EOS token, so using input_ids verbatim as labels
    # would train the model to emit EOS over and over at every padded
    # position. -100 is the ignore index of the loss GPT2LMHeadModel computes.
    encodings['labels'] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encodings['input_ids'], encodings['attention_mask'])
    ]
    return encodings

# Tokenize every split of the dataset. A failure here is fatal: the Trainer
# below needs `tokenized_datasets`. So log the error with its traceback and
# re-raise instead of letting the script continue into an unrelated NameError.
try:
    tokenized_datasets = dataset.map(
        preprocess_function,
        batched=True,
    )
except Exception:
    logging.exception("Error during tokenization")
    raise

# Define training arguments.
# No evaluation strategy is configured: the Trainer below only receives a
# training split, and requesting step-wise evaluation without an eval_dataset
# makes the Trainer fail. The old `evaluation_strategy` keyword is also
# deprecated in favour of `eval_strategy`. To evaluate, pass an eval_dataset to
# the Trainer and set eval_strategy here.
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=4,
    num_train_epochs=1,
    logging_dir='./logs',
    logging_steps=500,  # Log every 500 steps
    save_steps=10_000,  # Save checkpoint every 10,000 steps
    save_total_limit=2,  # Keep only the last 2 checkpoints
)

# Fine-tune the model on the tokenized training split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    tokenizer=tokenizer,
)

# Training is best-effort: if it fails, the demo below still launches. It then
# serves whatever weights `model` holds at that point (base GPT-2 or a
# partially trained model), so record the full traceback, not just the message.
try:
    trainer.train()
except Exception:
    logging.exception("Error during training; serving the model as-is")

# Wrap the model and tokenizer in a Hugging Face text-generation pipeline;
# generate_text below calls it for each request.
nlp = pipeline(task="text-generation", tokenizer=tokenizer, model=model)

def generate_text(prompt):
    """Generate a continuation of ``prompt`` for the Gradio UI.

    Args:
        prompt: Text typed by the user.

    Returns:
        The pipeline's ``generated_text`` for the prompt (with up to 50 newly
        generated tokens). If generation fails, returns an error message
        instead, so the error is shown in the UI rather than raised.
    """
    try:
        # max_new_tokens, unlike max_length, does not count the prompt, so long
        # prompts still receive a continuation instead of being cut off.
        return nlp(prompt, max_new_tokens=50)[0]['generated_text']
    except Exception as e:
        # Keep the traceback in the server log; the UI only gets the message.
        logging.exception("Error during text generation")
        return f"Error during text generation: {e}"

def _build_interface():
    """Return a Gradio UI with one text box in and one text box out, backed by generate_text."""
    return gr.Interface(
        fn=generate_text,
        inputs="text",
        outputs="text",
        title="Swahili Language Model",
        description="Generate text in Swahili using a pre-trained language model.",
    )


# Build the web demo and start serving it.
iface = _build_interface()
iface.launch()