Update app.py
app.py CHANGED
@@ -10,7 +10,8 @@ from transformers import (
     Trainer,
     DataCollatorForLanguageModeling,
     AutoTokenizer,
-    LlamaConfig
+    LlamaConfig,
+    AutoConfig
 )
 from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 from datasets import Dataset
@@ -333,24 +334,24 @@ def train_model(
     progress=gr.Progress()
 ):
     progress(0, desc="Installing dependencies...")
-
-
-
-
-
-
-
-
-
-
-
+    log = []
+
+    # Force reinstallation of transformers with specific version
+    log.append("Installing dependencies with specific versions...")
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", "transformers==4.36.2"])
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-U", "accelerate", "bitsandbytes", "peft", "datasets", "huggingface_hub", "deepspeed"])
+
+    # Now import everything after installation to ensure we use the correct versions
+    from datasets import Dataset
+    from huggingface_hub import snapshot_download
+    from transformers import AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, TrainingArguments, Trainer
+    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
 
     # --- Configuration ---
     progress(0.05, desc="Setting up configuration...")
     hf_model_repo_id = f"{hf_username}/{model_repo_name}"
     hf_dataset_repo_id = f"{hf_username}/{dataset_repo_name}"
 
-    log = []
     log.append(f"Model repo: {hf_model_repo_id}")
     log.append(f"Dataset repo: {hf_dataset_repo_id}")
@@ -369,83 +370,77 @@ def train_model(
     # --- Load Base Model (with quantization) ---
     progress(0.1, desc="Loading base model...")
     try:
-        # First
-
+        # First try to download the repo without loading the model
+        # to see what files are available
+        local_model_path = "./model_files"
+        snapshot_download(
+            repo_id=hf_model_repo_id,
+            local_dir=local_model_path,
+            local_dir_use_symlinks=False
+        )
 
-
-        from transformers import LlamaConfig
+        log.append(f"Model files downloaded to {local_model_path}")
 
-        #
-
-
-
+        # Check if this is a Llama model by looking at config.json
+        if os.path.exists(os.path.join(local_model_path, "config.json")):
+            with open(os.path.join(local_model_path, "config.json"), "r") as f:
+                config_data = json.load(f)
+            log.append(f"Model architecture type: {config_data.get('model_type', 'unknown')}")
+
+            # Force model_type to llama if needed
+            if "architectures" in config_data and "LlamaForCausalLM" in config_data["architectures"]:
+                config_data["model_type"] = "llama"
+                with open(os.path.join(local_model_path, "config.json"), "w") as f:
+                    json.dump(config_data, f)
+                log.append("Updated config.json to use llama model_type")
+
+        # Now try to load the config and model from local path
+        config = AutoConfig.from_pretrained(
+            local_model_path,
+            trust_remote_code=False  # Set to False to avoid custom model code loading
         )
 
-
+        log.append(f"Successfully loaded config: {config.model_type}")
+
+        # Load model with the config
         model = AutoModelForCausalLM.from_pretrained(
-            hf_model_repo_id,
+            local_model_path,
             config=config,
             quantization_config=bnb_config,
             device_map="auto",
-            trust_remote_code=True
+            trust_remote_code=False,
+            torch_dtype=torch.bfloat16
         )
+
         log.append(f"Loaded model vocab size: {model.config.vocab_size}")
         log.append(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
     except Exception as e:
-        error_msg = f"Error loading model
+        error_msg = f"Error loading model: {str(e)}"
         log.append(error_msg)
-
-        try:
-            log.append("Attempting alternative loading method...")
-            # Try loading without auto detection
-            model = AutoModelForCausalLM.from_pretrained(
-                hf_model_repo_id,
-                quantization_config=bnb_config,
-                device_map="auto",
-                trust_remote_code=True,
-                torch_dtype=torch.bfloat16,
-                # Add these to help with the loading
-                revision="main",
-                low_cpu_mem_usage=True,
-            )
-            log.append("Alternative loading successful!")
-            log.append(f"Loaded model vocab size: {model.config.vocab_size}")
-        except Exception as e2:
-            log.append(f"Alternative loading also failed: {e2}")
-            return "\n".join(log)
-
-    # Load the official Meta tokenizer for LLaMA 3
-    tokenizer = AutoTokenizer.from_pretrained(
-        "meta-llama/Llama-3-8B",  # Use the official Meta tokenizer
-        use_auth_token=os.environ.get("HF_TOKEN", None)  # In case it's needed
-    )
-
-    if tokenizer is None:
-        # Fallback to another common foundation model tokenizer
-        print("Falling back to another tokenizer as Meta tokenizer requires auth token")
-        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
-
-    print(f"Loaded tokenizer vocabulary size: {len(tokenizer)}")
-
-    # Print information about input embeddings
-    print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
-
-    # Prepare model for k-bit training
-    model = prepare_model_for_kbit_training(model)
+        return "\n".join(log)
 
-    #
-
-
-
-
-
-
-
-
-
-
-
-
+    # --- Prepare for K-bit Training & Apply LoRA ---
+    progress(0.15, desc="Preparing model for fine-tuning...")
+    try:
+        model = prepare_model_for_kbit_training(model)
+        log.append("Model prepared for k-bit training")
+
+        lora_config = LoraConfig(
+            task_type=TaskType.CAUSAL_LM,
+            r=16,
+            lora_alpha=32,
+            lora_dropout=0.05,
+            bias="none",
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+        )
+        peft_model = get_peft_model(model, lora_config)
+        peft_model.print_trainable_parameters()
+        log.append("LoRA applied to model")
+        model_to_train = peft_model
+    except Exception as e:
+        error_msg = f"Error preparing model for training: {str(e)}"
+        log.append(error_msg)
+        return "\n".join(log)
 
     # Cleanup
     gc.collect()
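The pip-then-import pattern in the second hunk works because these modules are imported for the first time only after the reinstall; a module already present in sys.modules keeps running its old code regardless of what pip just installed. A minimal, more defensive variant of the same pattern (the install_pinned helper is illustrative, not part of app.py):

import importlib
import subprocess
import sys

def install_pinned(package_spec: str, module_name: str) -> None:
    # Reinstall the pinned version, then make sure a later import
    # actually resolves to the freshly installed files.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--force-reinstall", package_spec])
    importlib.invalidate_caches()
    if module_name in sys.modules:
        # Best effort only: reloading a large package with many submodules
        # is unreliable; restarting the process is the robust alternative.
        importlib.reload(sys.modules[module_name])

install_pinned("transformers==4.36.2", "transformers")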
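Both the old and new loading paths reference a bnb_config that is defined elsewhere in app.py, outside the hunks shown above. Its exact values are not visible in this diff; a typical 4-bit setup for this kind of QLoRA flow, offered here as an assumption, looks like:

import torch
from transformers import BitsAndBytesConfig

# Assumed 4-bit quantization setup; app.py's actual bnb_config
# is defined outside the hunks shown in this diff.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)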
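The core of the new loading path is normalizing model_type in the downloaded config.json so that AutoConfig resolves the checkpoint to the built-in Llama classes rather than failing on an unknown or custom type. A standalone sketch of just that step, assuming the checkpoint directory was already fetched with snapshot_download (the path is a placeholder):

import json
import os

from transformers import AutoConfig

local_model_path = "./model_files"  # placeholder: directory populated by snapshot_download
config_path = os.path.join(local_model_path, "config.json")

# Read the checkpoint's config and normalize model_type so AutoConfig
# maps the checkpoint onto the built-in Llama implementation.
with open(config_path) as f:
    config_data = json.load(f)

if "LlamaForCausalLM" in config_data.get("architectures", []):
    config_data["model_type"] = "llama"
    with open(config_path, "w") as f:
        json.dump(config_data, f)

config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=False)
print(type(config).__name__)  # expected: LlamaConfig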