Commit 29f1f37 · 1 parent: 348f920
debug run in cpu

app.py CHANGED
@@ -718,7 +718,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
 # --------------------------------------------------------------------------------
 # -------------------------- New Chat function ---------------------
 # --------------------------------------------------------------------------------
-def generate_chat(
     model,
     prompt,
     tokenizer,

@@ -727,189 +727,126 @@ def generate_chat(
     context_size,
     temperature=0.7,
     top_k=50,
-    top_p=None,
     eos_id=None,
     repetition_penalty=1.2,
     penalize_len_below=50,
-    device=None
 ):
     if device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     idx = text_to_token_ids(prompt, chat_tokenizer).to(device)

-    … (removed lines 741-747 not captured in this view)
     token_freq = {}
-    … (removed lines 749-752 not captured in this view)
            logits = model(idx_cond)
-    … (removed lines 754-755 not captured in this view)
-        # Apply repetition penalty
-        for token_id in idx[0].tolist():
-            if token_id in token_freq:
-                token_freq[token_id] += 1
-                logits[0, token_id] /= repetition_penalty
-            else:
-                token_freq[token_id] = 1
-
-        # Penalize EOT token for shorter sequences
-        if eos_id is not None and step < penalize_len_below:
-            penalty_factor = 1.0 + (penalize_len_below - step) / penalize_len_below
-            logits[0, eos_id] /= penalty_factor
-
-        # Apply temperature scaling
-        if temperature > 0.0:
-            logits = logits / temperature
-
-        # Convert logits to probabilities
-        probs = torch.softmax(logits, dim=-1)
-
-        # Apply top-p (nucleus) sampling if specified
-        if top_p and top_p > 0.0:
-            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
-            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

-            #
-    … (removed lines 782-787 not captured in this view)
-            indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
-            probs = probs.masked_fill(indices_to_remove, 0.0)

-        #
-    … (removed lines 792-849 not captured in this view)
-        if len(repetition_window) >= 3 and all(w == word for w in repetition_window[-3:]):
-            continue  # Skip this word if the last 3 words were identical to it
-        cleaned_words.append(word)
-        repetition_window.append(word)
-        if len(repetition_window) > 10:  # Keep a limited window
-            repetition_window.pop(0)
-
-    return ' '.join(cleaned_words).strip()
-
-def generate_and_print_chat(
-    prompt,
-    tokenizer,
-    chat_tokenizer,
-    model,
-    device=None,
-    max_new_tokens=150,
-    context_length=None,
-    temperature=0.7,
-    top_k=50,
-    top_p=0.9,
-    repetition_penalty=1.2,
-    clean_the_text=False,
-    print_output=True
-):
-    if device is None:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    if context_length is None:
-        # Try to get from model config or use default
-        context_length = getattr(model, "context_length", 2048)
-
-    # Generate tokens
-    token_ids = generate_chat(
-        model=model,
-        prompt=prompt,
-        tokenizer=tokenizer,
-        chat_tokenizer=chat_tokenizer,
-        max_new_tokens=max_new_tokens,
-        context_size=context_length,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        device=device
-    )
-
-    # Convert tokens to text
-    output_text = token_ids_to_text(token_ids, tokenizer)

     if clean_the_text:
         # Clean the output
         # cleaned_text = clean_chat_output(output_text)
         cleaned_text = clean_text(output_text)
         if '<|eot_id|>' in cleaned_text:
             cleaned_text = cleaned_text.replace('<|eot_id|>','')
-        print("Generated text:\n", cleaned_text)

         return cleaned_text
-
-    print("Generated text:\n", output_text)
-
-    return output_text
-
 # ==================================================================================-
 # ==================================================================================-
 # ==================================================================================-

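Both the removed generate_chat and the new generate_chat_optimized apply the same two logit adjustments: the logit of every already-generated token is divided by repetition_penalty, and the end-of-sequence logit is divided by a factor that decays from 2.0 toward 1.0 over the first penalize_len_below steps. A minimal self-contained sketch with toy values (not app.py's tensors) illustrates the effect. One caveat worth noting: dividing a negative logit by a penalty greater than 1 actually makes that token more likely, which is why sign-aware implementations multiply negative logits instead.

```python
import torch

# Toy values, not app.py's: 5-token vocabulary, token 2 already generated, token 4 = EOS.
logits = torch.tensor([[1.0, 0.5, 2.0, -0.5, 1.5]])
repetition_penalty = 1.2
penalize_len_below = 50
step = 10          # current generation step
eos_id = 4

for token_id in [2]:                       # tokens seen so far
    # Same rule as the diff: damp tokens that were already generated.
    logits[0, token_id] /= repetition_penalty

if step < penalize_len_below:
    # Early in the sequence, shrink the EOS logit so generation keeps going.
    penalty_factor = 1.0 + (penalize_len_below - step) / penalize_len_below
    logits[0, eos_id] /= penalty_factor

print(torch.softmax(logits, dim=-1))       # token 2 and the EOS token lose probability mass
```
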
@@ -1093,19 +1030,22 @@ clean_the_text=False,
     print_output=True

 def generate_text(prompt, max_new_tokens, top_k, top_p, temperature, repetition_penalty, penalize_len_below):
-    … (removed lines 1096-1097 not captured in this view)
         prompt=prompt,
         tokenizer=tokenizer,
         chat_tokenizer=chat_tokenizer,
-        model=model,
-        device=device,
         max_new_tokens=max_new_tokens,
         top_k=top_k,
         top_p=top_p,
-    … (removed line 1106 not captured in this view)
         repetition_penalty=repetition_penalty,
-        penalize_len_below=penalize_len_below
     )

@@ -1136,13 +1076,13 @@ with gr.Blocks(title="Nepali GPT-2 Text Generator", css=css) as interface:

     with gr.Row():
         with gr.Column():
-            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.
             repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")
         with gr.Column():
-            top_k = gr.Slider(minimum=0, maximum=100, value=
             top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, label="Top P (set above 0 to use instead of Top K)")

-            min_length = gr.Slider(minimum=1, maximum=200, value=
             generate_btn = gr.Button("Generate Text")

         with gr.Column():

New side of the same hunks (added lines):

@@ -718,7 +718,7 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses, output_dir):
 # --------------------------------------------------------------------------------
 # -------------------------- New Chat function ---------------------
 # --------------------------------------------------------------------------------
+def generate_chat_optimized(
     model,
     prompt,
     tokenizer,

@@ -727,189 +727,126 @@ def generate_chat(
     context_size,
     temperature=0.7,
     top_k=50,
+    top_p=None,
     eos_id=None,
     repetition_penalty=1.2,
     penalize_len_below=50,
+    device=None,
+    batch_size=1,  # Added parameter
+    clean_the_text=True
 ):
     if device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     idx = text_to_token_ids(prompt, chat_tokenizer).to(device)

+    # Find EOS token once instead of checking every time
+    if not eos_id:
+        if "<|endoftext|>" in tokenizer.get_vocab():
+            encoded_endoftext = tokenizer.encode("<|endoftext|>")
+            eos_id = encoded_endoftext[0] if encoded_endoftext else None
+        elif "<|eot_id|>" in tokenizer.get_vocab():
+            encoded_endoftext = tokenizer.encode("<|eot_id|>")
+            eos_id = encoded_endoftext[0] if encoded_endoftext else None
+
+    # Pre-compute token frequencies for the initial context
     token_freq = {}
+    for token_id in idx[0].tolist():
+        if token_id in token_freq:
+            token_freq[token_id] += 1
+        else:
+            token_freq[token_id] = 1
+
+    # Process tokens in batches for efficiency
+    with torch.no_grad():  # Move this outside the loop
+        for step in range(0, max_new_tokens, batch_size):
+            batch_end = min(step + batch_size, max_new_tokens)
+            current_batch_size = batch_end - step
+
+            idx_cond = idx[:, -context_size:]
            logits = model(idx_cond)
+            logits = logits[:, -1, :]

+            # Apply repetition penalty once for the batch
+            for token_id in idx[0].tolist()[-current_batch_size:]:
+                if token_id in token_freq:
+                    token_freq[token_id] += 1
+                    logits[0, token_id] /= repetition_penalty
+                else:
+                    token_freq[token_id] = 1

+            # Process each token in the batch
+            for i in range(current_batch_size):
+                current_step = step + i
+
+                # Penalize EOT token for shorter sequences
+                current_logits = logits.clone()  # Work with a copy
+                if eos_id is not None and current_step < penalize_len_below:
+                    penalty_factor = 1.0 + (penalize_len_below - current_step) / penalize_len_below
+                    current_logits[0, eos_id] /= penalty_factor
+
+                # Apply temperature scaling
+                if temperature > 0.0:
+                    current_logits = current_logits / temperature
+
+                # Convert logits to probabilities
+                probs = torch.softmax(current_logits, dim=-1)
+
+                # Apply sampling strategies
+                if top_p and top_p > 0.0:
+                    # Nucleus sampling implementation
+                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = 0
+
+                    indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
+                    probs = probs.masked_fill(indices_to_remove, 0.0)
+                    probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
+
+                elif top_k and top_k > 0:
+                    # Top-k sampling implementation
+                    top_probs, top_indices = torch.topk(probs, min(top_k, probs.size(-1)))
+                    probs = torch.zeros_like(probs).scatter_(-1, top_indices, top_probs)
+                    probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
+
+                # Sample from the filtered distribution
+                if temperature > 0.0:
+                    idx_next = torch.multinomial(probs, num_samples=1)
+                else:
+                    idx_next = torch.argmax(probs, dim=-1, keepdim=True)
+
+                # Add the next token to the sequence
+                idx = torch.cat((idx, idx_next), dim=1)
+
+                # Check for end of sequence token and break if needed
+                if idx_next.item() == eos_id:
+                    output_text = token_ids_to_text(idx, tokenizer)
+                    if clean_the_text:
+                        # Clean the output
+                        # cleaned_text = clean_chat_output(output_text)
+                        cleaned_text = clean_text(output_text)
+                        if '<|eot_id|>' in cleaned_text:
+                            cleaned_text = cleaned_text.replace('<|eot_id|>','')
+                        # print("Generated text:\n", cleaned_text)

+                        return cleaned_text
+                    return output_text

+    # Not end of text token. terminate early since it exceeds max_new_tokens
+    output_text = token_ids_to_text(idx, tokenizer)
     if clean_the_text:
         # Clean the output
         # cleaned_text = clean_chat_output(output_text)
         cleaned_text = clean_text(output_text)
         if '<|eot_id|>' in cleaned_text:
             cleaned_text = cleaned_text.replace('<|eot_id|>','')
+        # print("Generated text:\n", cleaned_text)

         return cleaned_text
+    return output_text
 # ==================================================================================-
 # ==================================================================================-
 # ==================================================================================-

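The densest part of the new loop is the per-step filtering: temperature scaling, then nucleus (top-p) or top-k filtering, then a multinomial draw. The sketch below pulls just that path out into a standalone helper so it can be run on toy tensors; the function name sample_next_token and the random logits are illustrative only and do not exist in app.py.

```python
import torch

def sample_next_token(logits, temperature=0.7, top_k=50, top_p=None):
    """Minimal sketch of the sampling path in generate_chat_optimized:
    temperature -> softmax -> top-p (or top-k) filtering -> multinomial draw."""
    if temperature > 0.0:
        logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)

    if top_p and top_p > 0.0:
        # Nucleus sampling: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p.
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_remove = cumulative_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()  # shift mask right by one
        sorted_remove[..., 0] = False                             # always keep the top token
        remove = sorted_remove.scatter(dim=-1, index=sorted_indices, src=sorted_remove)
        probs = probs.masked_fill(remove, 0.0)
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
    elif top_k and top_k > 0:
        # Top-k: keep only the k most likely tokens.
        top_probs, top_indices = torch.topk(probs, min(top_k, probs.size(-1)))
        probs = torch.zeros_like(probs).scatter_(-1, top_indices, top_probs)
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)

    if temperature > 0.0:
        return torch.multinomial(probs, num_samples=1)
    return torch.argmax(probs, dim=-1, keepdim=True)

# Example with a fake 1 x vocab logit row; top_k=0 so the top-p branch runs.
print(sample_next_token(torch.randn(1, 10), temperature=0.7, top_k=0, top_p=0.9))
```

The right-shift of the removal mask keeps the first token whose cumulative probability crosses top_p, so at least one token always survives the filter, and the `+ 1e-8` guards the renormalization against an all-zero row.
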
@@ -1093,19 +1030,22 @@ clean_the_text=False,
     print_output=True

 def generate_text(prompt, max_new_tokens, top_k, top_p, temperature, repetition_penalty, penalize_len_below):
+    return generate_chat_optimized(
+        model=model,
         prompt=prompt,
         tokenizer=tokenizer,
         chat_tokenizer=chat_tokenizer,
         max_new_tokens=max_new_tokens,
+        context_size=context_length,
+        temperature=temperature,
         top_k=top_k,
         top_p=top_p,
+        eos_id=None,
         repetition_penalty=repetition_penalty,
+        penalize_len_below=penalize_len_below,
+        device=device,
+        batch_size=1,  # Added parameter
+        clean_the_text=clean_the_text
     )

|
1077 |
with gr.Row():
|
1078 |
with gr.Column():
|
1079 |
+
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.3, step=0.1, label="Temperature")
|
1080 |
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")
|
1081 |
with gr.Column():
|
1082 |
+
top_k = gr.Slider(minimum=0, maximum=100, value=5, step=1, label="Top K (set to 0 to use Top P)")
|
1083 |
top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, label="Top P (set above 0 to use instead of Top K)")
|
1084 |
|
1085 |
+
min_length = gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Minimum Length Penalty")
|
1086 |
generate_btn = gr.Button("Generate Text")
|
1087 |
|
1088 |
with gr.Column():
|
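The hunk above only defines the slider controls; the click wiring is outside this diff. A minimal sketch of how such a Blocks layout is typically connected follows. The prompt box, output box, Max New Tokens slider, the stub generator, and the .click call are assumptions for illustration, not part of the commit; the key point is that the min_length slider feeds the penalize_len_below argument of generate_text.

```python
import gradio as gr

def generate_text_stub(prompt, max_new_tokens, top_k, top_p,
                       temperature, repetition_penalty, penalize_len_below):
    # Stand-in for app.py's generate_text so the sketch runs without the model.
    return f"(would generate {int(max_new_tokens)} tokens for: {prompt})"

with gr.Blocks(title="Nepali GPT-2 Text Generator") as demo:
    prompt = gr.Textbox(label="Prompt")                                               # assumed control
    max_new_tokens = gr.Slider(minimum=10, maximum=500, value=150, step=10, label="Max New Tokens")  # assumed control
    with gr.Row():
        with gr.Column():
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.3, step=0.1, label="Temperature")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition Penalty")
        with gr.Column():
            top_k = gr.Slider(minimum=0, maximum=100, value=5, step=1, label="Top K (set to 0 to use Top P)")
            top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, label="Top P (set above 0 to use instead of Top K)")
            min_length = gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Minimum Length Penalty")
            generate_btn = gr.Button("Generate Text")
        with gr.Column():
            output = gr.Textbox(label="Generated Text")                               # assumed control

    # min_length maps to the penalize_len_below parameter of the generator.
    generate_btn.click(
        fn=generate_text_stub,
        inputs=[prompt, max_new_tokens, top_k, top_p, temperature, repetition_penalty, min_length],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()
```
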