Aananda-giri committed on
Commit 29f1f37 · 1 Parent(s): 348f920

debug run in cpu

Files changed (1)
  1. app.py +111 -171
app.py CHANGED
@@ -718,7 +718,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
  # --------------------------------------------------------------------------------
  # -------------------------- New Chat function ---------------------
  # --------------------------------------------------------------------------------
- def generate_chat(
  model,
  prompt,
  tokenizer,
@@ -727,189 +727,126 @@ def generate_chat(
  context_size,
  temperature=0.7,
  top_k=50,
- top_p=None, # Nucleus sampling
  eos_id=None,
  repetition_penalty=1.2,
  penalize_len_below=50,
- device=None
  ):
  if device is None:
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  idx = text_to_token_ids(prompt, chat_tokenizer).to(device)

- if not eos_id and "<|endoftext|>" in tokenizer.get_vocab():
- encoded_endoftext = tokenizer.encode("<|endoftext|>")
- eos_id = encoded_endoftext[0] if encoded_endoftext else None
- elif not eos_id and "<|eot_id|>" in tokenizer.get_vocab():
- encoded_endoftext = tokenizer.encode("<|eot_id|>")
- eos_id = encoded_endoftext[0] if encoded_endoftext else None
-
  token_freq = {}
-
- for step in range(max_new_tokens):
- idx_cond = idx[:, -context_size:]
- with torch.no_grad():
  logits = model(idx_cond)
- logits = logits[:, -1, :]
-
- # Apply repetition penalty
- for token_id in idx[0].tolist():
- if token_id in token_freq:
- token_freq[token_id] += 1
- logits[0, token_id] /= repetition_penalty
- else:
- token_freq[token_id] = 1
-
- # Penalize EOT token for shorter sequences
- if eos_id is not None and step < penalize_len_below:
- penalty_factor = 1.0 + (penalize_len_below - step) / penalize_len_below
- logits[0, eos_id] /= penalty_factor
-
- # Apply temperature scaling
- if temperature > 0.0:
- logits = logits / temperature
-
- # Convert logits to probabilities
- probs = torch.softmax(logits, dim=-1)
-
- # Apply top-p (nucleus) sampling if specified
- if top_p and top_p > 0.0:
- sorted_probs, sorted_indices = torch.sort(probs, descending=True)
- cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

- # Remove tokens with cumulative probability above the threshold
- sorted_indices_to_remove = cumulative_probs > top_p
- # Shift the indices to the right to keep also the first token above the threshold
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
- sorted_indices_to_remove[..., 0] = 0
-
- # Create a mask for indices to remove
- indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
- probs = probs.masked_fill(indices_to_remove, 0.0)

- # Renormalize probabilities
- probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8) # Avoid division by zero
-
- # If top_p is None or 0, apply top-k sampling
- elif top_k and top_k > 0:
- top_probs, top_indices = torch.topk(probs, min(top_k, probs.size(-1)))
- probs = torch.zeros_like(probs).scatter_(-1, top_indices, top_probs)
- # Renormalize probabilities
- probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8) # Avoid division by zero
-
- # Sample from the filtered distribution
- if temperature > 0.0:
- idx_next = torch.multinomial(probs, num_samples=1)
- else:
- idx_next = torch.argmax(probs, dim=-1, keepdim=True)
-
- # Add the next token to the sequence
- idx = torch.cat((idx, idx_next), dim=1)
-
- # Check for end of sequence token
- if idx_next.item() == eos_id:
- break
-
- return idx
-
- def clean_chat_output(text):
- """Clean up the generated text to remove repetition and artifacts."""
- # Remove repetitive patterns (like the example with repeated "म त यो देशको प्रधानमन्त्री हुँ")
- import re
-
- # Handle endoftext markers
- text = re.sub(r'<\|endoftext\|>.*$', '', text, flags=re.DOTALL)
- text = re.sub(r'<\|eot_id\|>.*$', '', text, flags=re.DOTALL)
-
- # Remove excessive repetition (3 or more identical sentences)
- lines = text.split('\n')
- cleaned_lines = []
- prev_line = None
- repetition_count = 0
-
- for line in lines:
- if line == prev_line:
- repetition_count += 1
- if repetition_count > 2: # Skip if this is the 3rd or more repetition
- continue
- else:
- repetition_count = 0
-
- cleaned_lines.append(line)
- prev_line = line
-
- text = '\n'.join(cleaned_lines)
-
- # Also clean repetitive phrases within a single line
- words = text.split()
- cleaned_words = []
- repetition_window = []
-
- for word in words:
- if len(repetition_window) >= 3 and all(w == word for w in repetition_window[-3:]):
- continue # Skip this word if the last 3 words were identical to it
- cleaned_words.append(word)
- repetition_window.append(word)
- if len(repetition_window) > 10: # Keep a limited window
- repetition_window.pop(0)
-
- return ' '.join(cleaned_words).strip()
-
- def generate_and_print_chat(
- prompt,
- tokenizer,
- chat_tokenizer,
- model,
- device=None,
- max_new_tokens=150,
- context_length=None,
- temperature=0.7,
- top_k=50,
- top_p=0.9,
- repetition_penalty=1.2,
- clean_the_text=False,
- print_output=True
- ):
- if device is None:
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- if context_length is None:
- # Try to get from model config or use default
- context_length = getattr(model, "context_length", 2048)
-
- # Generate tokens
- token_ids = generate_chat(
- model=model,
- prompt=prompt,
- tokenizer=tokenizer,
- chat_tokenizer=chat_tokenizer,
- max_new_tokens=max_new_tokens,
- context_size=context_length,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- device=device
- )
-
- # Convert tokens to text
- output_text = token_ids_to_text(token_ids, tokenizer)
  if clean_the_text:
  # Clean the output
  # cleaned_text = clean_chat_output(output_text)
  cleaned_text = clean_text(output_text)
  if '<|eot_id|>' in cleaned_text:
  cleaned_text = cleaned_text.replace('<|eot_id|>','')
- print("Generated text:\n", cleaned_text)

  return cleaned_text
- else:
- print("Generated text:\n", output_text)
-
- return output_text
-
  # ==================================================================================-
  # ==================================================================================-
  # ==================================================================================-
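Both the removed generate_chat above and the generate_chat_optimized added further down filter the next-token distribution with the same top-p (nucleus) step before sampling. A minimal, self-contained sketch of that step alone (the helper name top_p_filter and the toy probabilities are illustrative, not part of app.py):

import torch

def top_p_filter(probs, top_p=0.9):
    # Keep the smallest set of tokens whose cumulative probability reaches top_p,
    # zero out the rest, and renormalize. probs has shape (1, vocab_size).
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right so the first token that crosses the threshold is still kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    probs = probs.masked_fill(indices_to_remove, 0.0)
    return probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)  # avoid division by zero

# Toy 5-token vocabulary: with top_p=0.8 only the three most likely tokens survive
probs = torch.tensor([[0.5, 0.2, 0.15, 0.1, 0.05]])
print(top_p_filter(probs, top_p=0.8))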
@@ -1093,19 +1030,22 @@ clean_the_text=False,
  print_output=True

  def generate_text(prompt, max_new_tokens, top_k, top_p, temperature, repetition_penalty, penalize_len_below):
-
- return generate_and_print_chat(
  prompt=prompt,
  tokenizer=tokenizer,
  chat_tokenizer=chat_tokenizer,
- model=model,
- device=device,
  max_new_tokens=max_new_tokens,
  top_k=top_k,
  top_p=top_p,
- temperature=temperature,
  repetition_penalty=repetition_penalty,
- penalize_len_below=penalize_len_below
  )

@@ -1136,13 +1076,13 @@ with gr.Blocks(title="Nepali GPT-2 Text Generator", css=css) as interface:

  with gr.Row():
  with gr.Column():
- temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
  repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")
  with gr.Column():
- top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top K (set to 0 to use Top P)")
  top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, label="Top P (set above 0 to use instead of Top K)")

- min_length = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Minimum Length Penalty")
  generate_btn = gr.Button("Generate Text")

  with gr.Column():
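The Repetition Penalty and Minimum Length Penalty sliders above map onto two logit rescalings inside the generation loop: token ids that have already been generated are down-weighted, and the end-of-sequence logit is suppressed while the output is still shorter than the minimum length. A simplified sketch of both adjustments, assuming logits of shape (1, vocab_size); unlike app.py it applies the repetition penalty once per distinct token rather than tracking frequencies:

import torch

def rescale_logits(logits, generated_ids, eos_id, step,
                   repetition_penalty=1.2, penalize_len_below=50):
    # Down-weight every token id that already appears in the generated sequence
    for token_id in set(generated_ids):
        logits[0, token_id] /= repetition_penalty
    # Suppress EOS while the sequence is below the minimum length
    if eos_id is not None and step < penalize_len_below:
        penalty_factor = 1.0 + (penalize_len_below - step) / penalize_len_below
        logits[0, eos_id] /= penalty_factor
    return logits

# Toy example: ids 3 and 5 drop from 2.0 to ~1.67, the EOS id 7 drops to 1.0 at step 0
logits = torch.full((1, 8), 2.0)
print(rescale_logits(logits, generated_ids=[3, 3, 5], eos_id=7, step=0))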
 
  # --------------------------------------------------------------------------------
  # -------------------------- New Chat function ---------------------
  # --------------------------------------------------------------------------------
+ def generate_chat_optimized(
  model,
  prompt,
  tokenizer,

  context_size,
  temperature=0.7,
  top_k=50,
+ top_p=None,
  eos_id=None,
  repetition_penalty=1.2,
  penalize_len_below=50,
+ device=None,
+ batch_size=1, # Added parameter
+ clean_the_text=True
  ):
  if device is None:
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  idx = text_to_token_ids(prompt, chat_tokenizer).to(device)

+ # Find EOS token once instead of checking every time
+ if not eos_id:
+ if "<|endoftext|>" in tokenizer.get_vocab():
+ encoded_endoftext = tokenizer.encode("<|endoftext|>")
+ eos_id = encoded_endoftext[0] if encoded_endoftext else None
+ elif "<|eot_id|>" in tokenizer.get_vocab():
+ encoded_endoftext = tokenizer.encode("<|eot_id|>")
+ eos_id = encoded_endoftext[0] if encoded_endoftext else None
+
+ # Pre-compute token frequencies for the initial context
  token_freq = {}
+ for token_id in idx[0].tolist():
+ if token_id in token_freq:
+ token_freq[token_id] += 1
+ else:
+ token_freq[token_id] = 1
+
+ # Process tokens in batches for efficiency
+ with torch.no_grad(): # Move this outside the loop
+ for step in range(0, max_new_tokens, batch_size):
+ batch_end = min(step + batch_size, max_new_tokens)
+ current_batch_size = batch_end - step
+
+ idx_cond = idx[:, -context_size:]
  logits = model(idx_cond)
+ logits = logits[:, -1, :]

+ # Apply repetition penalty once for the batch
+ for token_id in idx[0].tolist()[-current_batch_size:]:
+ if token_id in token_freq:
+ token_freq[token_id] += 1
+ logits[0, token_id] /= repetition_penalty
+ else:
+ token_freq[token_id] = 1

+ # Process each token in the batch
+ for i in range(current_batch_size):
+ current_step = step + i
+
+ # Penalize EOT token for shorter sequences
+ current_logits = logits.clone() # Work with a copy
+ if eos_id is not None and current_step < penalize_len_below:
+ penalty_factor = 1.0 + (penalize_len_below - current_step) / penalize_len_below
+ current_logits[0, eos_id] /= penalty_factor
+
+ # Apply temperature scaling
+ if temperature > 0.0:
+ current_logits = current_logits / temperature
+
+ # Convert logits to probabilities
+ probs = torch.softmax(current_logits, dim=-1)
+
+ # Apply sampling strategies
+ if top_p and top_p > 0.0:
+ # Nucleus sampling implementation
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
+ probs = probs.masked_fill(indices_to_remove, 0.0)
+ probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
+
+ elif top_k and top_k > 0:
+ # Top-k sampling implementation
+ top_probs, top_indices = torch.topk(probs, min(top_k, probs.size(-1)))
+ probs = torch.zeros_like(probs).scatter_(-1, top_indices, top_probs)
+ probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
+
+ # Sample from the filtered distribution
+ if temperature > 0.0:
+ idx_next = torch.multinomial(probs, num_samples=1)
+ else:
+ idx_next = torch.argmax(probs, dim=-1, keepdim=True)
+
+ # Add the next token to the sequence
+ idx = torch.cat((idx, idx_next), dim=1)
+
+ # Check for end of sequence token and break if needed
+ if idx_next.item() == eos_id:
+ output_text = token_ids_to_text(idx, tokenizer)
+ if clean_the_text:
+ # Clean the output
+ # cleaned_text = clean_chat_output(output_text)
+ cleaned_text = clean_text(output_text)
+ if '<|eot_id|>' in cleaned_text:
+ cleaned_text = cleaned_text.replace('<|eot_id|>','')
+ # print("Generated text:\n", cleaned_text)
+
+ return cleaned_text
+ return output_text

+ # Not end of text token. terminate early since it exceeds max_new_tokens
+ output_text = token_ids_to_text(idx, tokenizer)
  if clean_the_text:
  # Clean the output
  # cleaned_text = clean_chat_output(output_text)
  cleaned_text = clean_text(output_text)
  if '<|eot_id|>' in cleaned_text:
  cleaned_text = cleaned_text.replace('<|eot_id|>','')
+ # print("Generated text:\n", cleaned_text)

  return cleaned_text
+ return output_text
  # ==================================================================================-
  # ==================================================================================-
  # ==================================================================================-
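generate_chat_optimized resolves the end-of-sequence id once, up front, by probing the tokenizer vocabulary for the <|endoftext|> or <|eot_id|> markers. A standalone sketch of that lookup, assuming a Hugging Face-style tokenizer that exposes get_vocab() and encode() (the helper name resolve_eos_id is illustrative):

def resolve_eos_id(tokenizer, eos_id=None):
    # Prefer an explicitly supplied id, otherwise probe for known end markers
    if eos_id is not None:
        return eos_id
    for marker in ("<|endoftext|>", "<|eot_id|>"):
        if marker in tokenizer.get_vocab():
            encoded = tokenizer.encode(marker)
            return encoded[0] if encoded else None
    return None

# Usage: eos_id = resolve_eos_id(tokenizer)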
 
  print_output=True

  def generate_text(prompt, max_new_tokens, top_k, top_p, temperature, repetition_penalty, penalize_len_below):
+ return generate_chat_optimized(
+ model=model,
  prompt=prompt,
  tokenizer=tokenizer,
  chat_tokenizer=chat_tokenizer,
  max_new_tokens=max_new_tokens,
+ context_size=context_length,
+ temperature=temperature,
  top_k=top_k,
  top_p=top_p,
+ eos_id=None,
  repetition_penalty=repetition_penalty,
+ penalize_len_below=penalize_len_below,
+ device=device,
+ batch_size=1, # Added parameter
+ clean_the_text=clean_the_text
  )


  with gr.Row():
  with gr.Column():
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.3, step=0.1, label="Temperature")
  repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")
  with gr.Column():
+ top_k = gr.Slider(minimum=0, maximum=100, value=5, step=1, label="Top K (set to 0 to use Top P)")
  top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, label="Top P (set above 0 to use instead of Top K)")

+ min_length = gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Minimum Length Penalty")
  generate_btn = gr.Button("Generate Text")

  with gr.Column():
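These sliders ultimately feed generate_text, which forwards them to generate_chat_optimized. A hedged sketch of how that wiring typically looks with a gr.Button click handler; the prompt box, max-tokens slider, and output box names are assumptions, since the actual handler sits outside this hunk:

# Hypothetical wiring; only the slider and button names above come from app.py.
prompt_box = gr.Textbox(label="Prompt")
max_new_tokens = gr.Slider(minimum=10, maximum=500, value=150, step=10, label="Max New Tokens")
output_box = gr.Textbox(label="Generated Text")

generate_btn.click(
    fn=generate_text,
    inputs=[prompt_box, max_new_tokens, top_k, top_p,
            temperature, repetition_penalty, min_length],
    outputs=output_box,
)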