Twelve2five committed
Commit 19ba848 · verified · 1 Parent(s): 0591b3c

Update app.py

Files changed (1)
  1. app.py +105 -75
app.py CHANGED
@@ -630,90 +630,120 @@ def train_model(
        log.append(error_msg)
        return "\n".join(log)

-    # --- Download and Process Dataset ---
-    progress(0.4, desc="Downloading dataset...")
+    # --- Download and Load Dataset ---
+    progress(0.45, desc="Downloading dataset...")
+    log.append(f"Downloading dataset from {hf_dataset_repo_id}...")

    try:
-        dataset_path = "./downloaded_dataset_files"
+        # Download the dataset files
+        local_dataset_path = "./downloaded_dataset_files"
+
+        # Correctly specify repo_type as "dataset"
        snapshot_download(
            repo_id=hf_dataset_repo_id,
-            local_dir=dataset_path,
-            use_auth_token=False,
+            local_dir=local_dataset_path,
+            repo_type="dataset",  # Important! Specifies this is a dataset repo
+            token=hf_token if hf_token and hf_token.strip() else None,  # Use token for auth
            resume_download=True
        )
-        log.append(f"Dataset repository content downloaded to: {dataset_path}")
-
-        # Load dataset from PT files
-        progress(0.5, desc="Processing dataset...")
-
-        # Load RVQ pairs
-        pair_files = glob.glob(f"{dataset_path}/*_rvq_pairs.pt")
-        log.append(f"Found {len(pair_files)} RVQ pair files.")
-
-        all_pairs = []
-        for file in pair_files:
-            pairs = torch.load(file)
-            all_pairs.extend(pairs)
-
-        log.append(f"Loaded a total of {len(all_pairs)} training pairs into memory.")
-
-        # Process pairs into a format suitable for training
-        all_texts = []
-        for pair in all_pairs:
-            # Create instruction format
-            if isinstance(pair, dict):
-                instruction = pair.get("instruction", "")
-                input_text = pair.get("input", "")
-                output = pair.get("output", "")
+        log.append(f"Dataset files downloaded to {local_dataset_path}")
+
+        # Check the structure of the downloaded files
+        log.append("Checking downloaded dataset structure...")
+        downloaded_files = glob.glob(f"{local_dataset_path}/**/*.pt", recursive=True)
+        log.append(f"Found {len(downloaded_files)} .pt files in the dataset directory")
+
+        if len(downloaded_files) == 0:
+            log.append("No .pt files found. Checking for other file types...")
+            all_files = glob.glob(f"{local_dataset_path}/**/*.*", recursive=True)
+            log.append(f"All files found: {', '.join(all_files[:10])}")
+            if len(all_files) > 10:
+                log.append(f"...and {len(all_files) - 10} more files")
+
+        # Look for the pairs directory
+        pairs_dir = os.path.join(local_dataset_path, "final_rvq_pairs")
+        if not os.path.exists(pairs_dir):
+            log.append(f"final_rvq_pairs directory not found. Looking for other possible directories...")
+            possible_dirs = [d for d in glob.glob(f"{local_dataset_path}/**/") if os.path.isdir(d)]
+            log.append(f"Available directories: {', '.join(possible_dirs)}")
+
+            # Try to find any directory containing .pt files
+            for dir_path in possible_dirs:
+                if glob.glob(f"{dir_path}/*.pt"):
+                    pairs_dir = dir_path
+                    log.append(f"Using {pairs_dir} as the pairs directory.")
+                    break
+
+        # If we found the pairs directory, we're good to go
+        if pairs_dir and os.path.exists(pairs_dir):
+            log.append(f"Using pairs directory: {pairs_dir}")
+            pt_files = glob.glob(f"{pairs_dir}/*.pt")
+            log.append(f"Found {len(pt_files)} .pt files in pairs directory")
+
+            # Load the dataset from the files
+            progress(0.5, desc="Loading pairs from dataset files...")
+            log.append("Loading dataset pairs...")
+
+            try:
+                # Load pairs from .pt files
+                pairs = []
+                for pt_file in tqdm(pt_files, desc="Loading .pt files"):
+                    pair_data = torch.load(pt_file)
+                    pairs.append(pair_data)

-                # ALPACA format
-                if instruction and input_text:
-                    text = f"### Instruction: {instruction}\n### Input: {input_text}\n### Response: {output}"
-                elif instruction:
-                    text = f"### Instruction: {instruction}\n### Response: {output}"
-                else:
-                    text = output
-            else:
-                # Simple prompt-completion format
-                if isinstance(pair, tuple) and len(pair) == 2:
-                    prompt, completion = pair
-                    text = f"{prompt}{completion}"
+                log.append(f"Loaded {len(pairs)} conversation pairs")
+
+                # Create a dataset from the pairs
+                dataset = Dataset.from_dict({
+                    "input_ids": [pair[0].tolist() for pair in pairs],
+                    "labels": [pair[1].tolist() for pair in pairs]
+                })
+
+                # Split into training and validation sets
+                train_test_split = dataset.train_test_split(test_size=0.05)
+                train_dataset = train_test_split["train"]
+
+                log.append(f"Created dataset with {len(train_dataset)} training examples")
+
+            except Exception as e:
+                log.append(f"Error loading pair data: {e}")
+
+                # Try an alternative approach - look for JSON or other formats
+                log.append("Attempting alternative dataset loading approaches...")
+
+                # Search for JSON files
+                json_files = glob.glob(f"{local_dataset_path}/**/*.json", recursive=True)
+                if json_files:
+                    log.append(f"Found {len(json_files)} JSON files. Trying to load from these...")
+
+                    # Load from JSON
+                    combined_data = []
+                    for json_file in json_files[:5]:  # Start with a few files
+                        try:
+                            with open(json_file, 'r') as f:
+                                file_data = json.load(f)
+                            log.append(f"Successfully loaded {json_file}")
+                            # Print sample of the data structure
+                            log.append(f"Sample data structure: {str(file_data)[:500]}...")
+                            combined_data.append(file_data)
+                        except Exception as je:
+                            log.append(f"Error loading {json_file}: {je}")
+
+                    # If we loaded any data, try to create a dataset from it
+                    if combined_data:
+                        log.append("Attempting to create dataset from JSON data...")
+                        # This will need adapting based on the actual JSON structure
                else:
-                    text = str(pair)
+                    log.append("No JSON files found. Looking for other formats...")
+                    # Add code for other formats if needed
+
+                log.append("Failed to load dataset after multiple attempts.")
+                return "\n".join(log)
+
+        else:
+            log.append("Could not locate pairs directory or any directory with .pt files.")
+            return "\n".join(log)

-            all_texts.append({"text": text})
-
-        # Create HF dataset
-        train_dataset = Dataset.from_list(all_texts)
-
-        # Function to tokenize the dataset
-        def tokenize_function(examples):
-            return tokenizer(
-                examples["text"],
-                padding=False,
-                truncation=True,
-                max_length=2048,
-                return_tensors=None,
-            )
-
-        # Tokenize the dataset
-        tokenized_dataset = train_dataset.map(
-            tokenize_function,
-            batched=True,
-            remove_columns=["text"],
-            desc="Tokenizing dataset",
-        )
-
-        train_dataset = tokenized_dataset
-
-        # Data collator
-        from transformers import DataCollatorForLanguageModeling
-
-        data_collator = DataCollatorForLanguageModeling(
-            tokenizer=tokenizer,
-            mlm=False
-        )
-
    except Exception as e:
        error_msg = f"Error loading dataset: {str(e)}"
        log.append(error_msg)
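
The core of the change is that snapshot_download is now told it is fetching a dataset repository (repo_type="dataset") and is given a token for private repos, after which the .pt pair files are loaded directly into a datasets.Dataset rather than being reformatted and re-tokenized. Below is a minimal standalone sketch of that flow; the function name load_pair_dataset and the example repo id are illustrative only, and it assumes, as the diff does, that each .pt file holds an (input_ids, labels) tensor pair.

import glob

import torch
from datasets import Dataset
from huggingface_hub import snapshot_download


def load_pair_dataset(repo_id, token=None, test_size=0.05):
    # repo_type="dataset" is required; otherwise the Hub looks for a model repo with this id.
    local_dir = snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir="./downloaded_dataset_files",
        token=token or None,  # pass a token only when one is provided (private repos)
    )

    # Search recursively, since the pairs may sit in a subdirectory such as final_rvq_pairs/.
    pt_files = sorted(glob.glob(f"{local_dir}/**/*.pt", recursive=True))
    if not pt_files:
        raise FileNotFoundError(f"No .pt files found under {local_dir}")

    # Assumption: each file stores one (input_ids, labels) pair of tensors.
    pairs = [torch.load(path) for path in pt_files]
    dataset = Dataset.from_dict({
        "input_ids": [p[0].tolist() for p in pairs],
        "labels": [p[1].tolist() for p in pairs],
    })

    # Hold out a small validation split, mirroring the 5% split used in the commit.
    return dataset.train_test_split(test_size=test_size)


# Hypothetical usage:
# splits = load_pair_dataset("some-user/some-rvq-pairs", token=None)
# train_dataset = splits["train"]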