Update app.py
app.py CHANGED
@@ -630,90 +630,120 @@ def train_model(
         log.append(error_msg)
         return "\n".join(log)
 
-    # --- Download and
-    progress(0.
+    # --- Download and Load Dataset ---
+    progress(0.45, desc="Downloading dataset...")
+    log.append(f"Downloading dataset from {hf_dataset_repo_id}...")
 
     try:
+        # Download the dataset files
+        local_dataset_path = "./downloaded_dataset_files"
+
+        # Correctly specify repo_type as "dataset"
         snapshot_download(
             repo_id=hf_dataset_repo_id,
-            local_dir=
+            local_dir=local_dataset_path,
+            repo_type="dataset",  # Important! Specifies this is a dataset repo
+            token=hf_token if hf_token and hf_token.strip() else None,  # Use token for auth
             resume_download=True
         )
-        log.append(f"Dataset
+        log.append(f"Dataset files downloaded to {local_dataset_path}")
+
+        # Check the structure of the downloaded files
+        log.append("Checking downloaded dataset structure...")
+        downloaded_files = glob.glob(f"{local_dataset_path}/**/*.pt", recursive=True)
+        log.append(f"Found {len(downloaded_files)} .pt files in the dataset directory")
+
+        if len(downloaded_files) == 0:
+            log.append("No .pt files found. Checking for other file types...")
+            all_files = glob.glob(f"{local_dataset_path}/**/*.*", recursive=True)
+            log.append(f"All files found: {', '.join(all_files[:10])}")
+            if len(all_files) > 10:
+                log.append(f"...and {len(all_files) - 10} more files")
+
+        # Look for the pairs directory
+        pairs_dir = os.path.join(local_dataset_path, "final_rvq_pairs")
+        if not os.path.exists(pairs_dir):
+            log.append(f"final_rvq_pairs directory not found. Looking for other possible directories...")
+            possible_dirs = [d for d in glob.glob(f"{local_dataset_path}/**/") if os.path.isdir(d)]
+            log.append(f"Available directories: {', '.join(possible_dirs)}")
+
+            # Try to find any directory containing .pt files
+            for dir_path in possible_dirs:
+                if glob.glob(f"{dir_path}/*.pt"):
+                    pairs_dir = dir_path
+                    log.append(f"Using {pairs_dir} as the pairs directory.")
+                    break
+
+        # If we found the pairs directory, we're good to go
+        if pairs_dir and os.path.exists(pairs_dir):
+            log.append(f"Using pairs directory: {pairs_dir}")
+            pt_files = glob.glob(f"{pairs_dir}/*.pt")
+            log.append(f"Found {len(pt_files)} .pt files in pairs directory")
+
+            # Load the dataset from the files
+            progress(0.5, desc="Loading pairs from dataset files...")
+            log.append("Loading dataset pairs...")
+
+            try:
+                # Load pairs from .pt files
+                pairs = []
+                for pt_file in tqdm(pt_files, desc="Loading .pt files"):
+                    pair_data = torch.load(pt_file)
+                    pairs.append(pair_data)
 
+                log.append(f"Loaded {len(pairs)} conversation pairs")
+
+                # Create a dataset from the pairs
+                dataset = Dataset.from_dict({
+                    "input_ids": [pair[0].tolist() for pair in pairs],
+                    "labels": [pair[1].tolist() for pair in pairs]
+                })
+
+                # Split into training and validation sets
+                train_test_split = dataset.train_test_split(test_size=0.05)
+                train_dataset = train_test_split["train"]
+
+                log.append(f"Created dataset with {len(train_dataset)} training examples")
+
+            except Exception as e:
+                log.append(f"Error loading pair data: {e}")
+
+                # Try an alternative approach - look for JSON or other formats
+                log.append("Attempting alternative dataset loading approaches...")
+
+                # Search for JSON files
+                json_files = glob.glob(f"{local_dataset_path}/**/*.json", recursive=True)
+                if json_files:
+                    log.append(f"Found {len(json_files)} JSON files. Trying to load from these...")
+
+                    # Load from JSON
+                    combined_data = []
+                    for json_file in json_files[:5]:  # Start with a few files
+                        try:
+                            with open(json_file, 'r') as f:
+                                file_data = json.load(f)
+                            log.append(f"Successfully loaded {json_file}")
+                            # Print sample of the data structure
+                            log.append(f"Sample data structure: {str(file_data)[:500]}...")
+                            combined_data.append(file_data)
+                        except Exception as je:
+                            log.append(f"Error loading {json_file}: {je}")
+
+                    # If we loaded any data, try to create a dataset from it
+                    if combined_data:
+                        log.append("Attempting to create dataset from JSON data...")
+                        # This will need adapting based on the actual JSON structure
                 else:
+                    log.append("No JSON files found. Looking for other formats...")
+                    # Add code for other formats if needed
+
+                log.append("Failed to load dataset after multiple attempts.")
+                return "\n".join(log)
+
+        else:
+            log.append("Could not locate pairs directory or any directory with .pt files.")
+            return "\n".join(log)
 
-                    all_texts.append({"text": text})
-
-        # Create HF dataset
-        train_dataset = Dataset.from_list(all_texts)
-
-        # Function to tokenize the dataset
-        def tokenize_function(examples):
-            return tokenizer(
-                examples["text"],
-                padding=False,
-                truncation=True,
-                max_length=2048,
-                return_tensors=None,
-            )
-
-        # Tokenize the dataset
-        tokenized_dataset = train_dataset.map(
-            tokenize_function,
-            batched=True,
-            remove_columns=["text"],
-            desc="Tokenizing dataset",
-        )
-
-        train_dataset = tokenized_dataset
-
-        # Data collator
-        from transformers import DataCollatorForLanguageModeling
-
-        data_collator = DataCollatorForLanguageModeling(
-            tokenizer=tokenizer,
-            mlm=False
-        )
-
     except Exception as e:
         error_msg = f"Error loading dataset: {str(e)}"
         log.append(error_msg)
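
For reference, here is a minimal standalone sketch of the loading path this change introduces: download the dataset repo with snapshot_download(repo_type="dataset"), collect the .pt pair files, and build a datasets.Dataset from them. The repo id and local directory are placeholders, and the sketch assumes each .pt file holds an (input_ids, labels) tensor pair, as the pair[0]/pair[1] indexing in the diff implies.

import glob
import torch
from datasets import Dataset
from huggingface_hub import snapshot_download

# Placeholder repo id; replace with the actual dataset repo.
local_dir = snapshot_download(
    repo_id="your-username/your-rvq-pairs",
    repo_type="dataset",  # a dataset repo, not the default model repo
    local_dir="./downloaded_dataset_files",
)

# Assumes each .pt file stores an (input_ids, labels) tensor pair.
pt_files = glob.glob(f"{local_dir}/**/*.pt", recursive=True)
pairs = [torch.load(f, map_location="cpu") for f in pt_files]

dataset = Dataset.from_dict({
    "input_ids": [p[0].tolist() for p in pairs],
    "labels": [p[1].tolist() for p in pairs],
})
train_dataset = dataset.train_test_split(test_size=0.05)["train"]
print(train_dataset)

Passing repo_type="dataset" is the key fix: without it, snapshot_download resolves the id as a model repo and typically fails with a repository-not-found error.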