Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,63 +1,90 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
)
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
""
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
demo.launch()
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
|
3 |
import gradio as gr
|
4 |
+
from transformers import pipeline
|
5 |
+
import logging
|
6 |
+
|
7 |
+
# Enable detailed logging
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
+
|
10 |
+
# Load dataset
|
11 |
+
dataset = load_dataset("mwitiderrick/swahili")
|
12 |
+
|
13 |
+
# Print dataset columns for verification
|
14 |
+
print(f"Dataset columns: {dataset['train'].column_names}")
|
15 |
+
|
16 |
+
# Initialize the tokenizer and model
|
17 |
+
model_name = "gpt2" # Use GPT-2 for text generation
|
18 |
+
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
19 |
+
model = GPT2LMHeadModel.from_pretrained(model_name)
|
20 |
+
|
21 |
+
# Add a padding token to the tokenizer
|
22 |
+
tokenizer.pad_token = tokenizer.eos_token # Use eos_token as pad_token
|
23 |
+
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
|
24 |
+
|
25 |
+
# Preprocess the dataset
|
26 |
+
def preprocess_function(examples):
|
27 |
+
# Tokenize and format the dataset
|
28 |
+
encodings = tokenizer(
|
29 |
+
examples['text'], # Use 'text' column from your dataset
|
30 |
+
truncation=True,
|
31 |
+
padding='max_length', # Ensure consistent length
|
32 |
+
max_length=512
|
33 |
+
)
|
34 |
+
encodings['labels'] = encodings['input_ids'] # Use input_ids directly as labels
|
35 |
+
return encodings
|
36 |
+
|
37 |
+
# Tokenize the dataset
|
38 |
+
try:
|
39 |
+
tokenized_datasets = dataset.map(
|
40 |
+
preprocess_function,
|
41 |
+
batched=True
|
42 |
+
)
|
43 |
+
except Exception as e:
|
44 |
+
print(f"Error during tokenization: {e}")
|
45 |
+
|
46 |
+
# Define training arguments
|
47 |
+
training_args = TrainingArguments(
|
48 |
+
output_dir='./results',
|
49 |
+
per_device_train_batch_size=4,
|
50 |
+
num_train_epochs=1,
|
51 |
+
logging_dir='./logs',
|
52 |
+
logging_steps=500, # Log every 500 steps
|
53 |
+
evaluation_strategy="steps", # Use evaluation strategy
|
54 |
+
save_steps=10_000, # Save checkpoint every 10,000 steps
|
55 |
+
save_total_limit=2, # Keep only the last 2 checkpoints
|
56 |
+
)
|
57 |
+
|
58 |
+
# Define Trainer
|
59 |
+
trainer = Trainer(
|
60 |
+
model=model,
|
61 |
+
args=training_args,
|
62 |
+
train_dataset=tokenized_datasets["train"],
|
63 |
+
tokenizer=tokenizer,
|
64 |
)
|
65 |
|
66 |
+
# Start training
|
67 |
+
try:
|
68 |
+
trainer.train()
|
69 |
+
except Exception as e:
|
70 |
+
print(f"Error during training: {e}")
|
71 |
+
|
72 |
+
# Define the Gradio interface function
|
73 |
+
nlp = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
74 |
+
|
75 |
+
def generate_text(prompt):
|
76 |
+
try:
|
77 |
+
return nlp(prompt, max_length=50)[0]['generated_text']
|
78 |
+
except Exception as e:
|
79 |
+
return f"Error during text generation: {e}"
|
80 |
+
|
81 |
+
# Create and launch the Gradio interface
|
82 |
+
iface = gr.Interface(
|
83 |
+
fn=generate_text,
|
84 |
+
inputs="text",
|
85 |
+
outputs="text",
|
86 |
+
title="Swahili Language Model",
|
87 |
+
description="Generate text in Swahili using a pre-trained language model."
|
88 |
+
)
|
89 |
|
90 |
+
iface.launch()
|
|