import cache_manager  # local module bundled with the Space; not referenced directly below
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load model + tokenizer
model_name = "Samurai719214/gptneo-mythology-storyteller"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
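# Optional tweak (an assumption, not in the original file): if GPU memory is
# tight, the checkpoint can be loaded in half precision instead:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16
#   ).to(device)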
# Story generation with history; streams the reply back to the Chatbot in chunks
def generate_full_story_chunks(excerpt, history_state):
    history_state = history_state or []  # the Chatbot value can arrive as None
    if not excerpt or not excerpt.strip():
        history_state.append(("❌", "⚠️ Enter a story excerpt."))
        yield history_state, gr.update(visible=False), gr.update(interactive=True)
        return
    # Tokenize the excerpt and move tensors to the model's device
    inputs = tokenizer(excerpt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():  # inference only, no gradients needed
        output_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=400,
            do_sample=True,
            temperature=0.1,  # low temperature keeps sampling nearly greedy
            top_p=0.9,
            no_repeat_ngram_size=2,
        )
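    # Note: decoder-only models echo the prompt, so the decoded text below
    # begins with the excerpt itself; to display only the continuation, slice
    # off the prompt tokens first, e.g. output_ids[0][inputs["input_ids"].shape[1]:]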
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Append user input
    history_state.append(("You", excerpt))
    # Stream response in 200-character chunks, updating the last AI turn in place
    response = ""
    for i in range(0, len(generated_text), 200):
        response += generated_text[i:i + 200]
        if history_state and history_state[-1][0] == "AI":
            history_state[-1] = ("AI", response)
        else:
            history_state.append(("AI", response))
        yield history_state, gr.update(visible=False), gr.update(interactive=True)
# Clear conversation
def clear_history():
    return [], gr.update(interactive=False)

# Enable/disable generate button
def toggle_generate_button(text):
    return gr.update(interactive=bool(text.strip()))
# Build UI
with gr.Blocks() as demo:
    gr.Markdown("## 🏺 Mythology Storyteller")
    gr.Markdown("Enter a phrase from a chapter of your choice (include Parv, key event, and section for better results).")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(
                label="Incomplete story excerpt",
                placeholder="Enter an excerpt from the Mahabharata here...",
                lines=4,
            )
            # Collected but not yet passed to the model
            summary_input = gr.Textbox(
                label="Chapter summary (optional)",
                placeholder="Enter summary if available...",
                lines=2,
            )
            generate_btn = gr.Button("✨ Generate Story", interactive=False)
        with gr.Column():
            output_text = gr.Chatbot(
                label="Generated Story",
                height=400,
                placeholder="⚔️ Legends are being written...",
            )
            spinner = gr.Markdown("", visible=False)  # spinner placeholder
            clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)

    gr.Markdown("---")
    gr.Markdown("🔌 Use via API (see Hugging Face Inference docs).")
    # Toggle generate button when input changes
    user_input.change(
        fn=toggle_generate_button,
        inputs=user_input,
        outputs=generate_btn,
    )

    # Show/hide spinner around generation
    def show_spinner():
        return gr.update(value="⏳ Generating story...", visible=True)

    def hide_spinner():
        return gr.update(visible=False)
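    # Event chain: show the spinner, then stream the story (each yield from the
    # generator updates the Chatbot), then hide the spinner once the generator
    # is exhausted.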
    generate_btn.click(
        fn=show_spinner,
        inputs=None,
        outputs=spinner,
    ).then(
        fn=generate_full_story_chunks,
        inputs=[user_input, output_text],
        outputs=[output_text, spinner, clear_btn],
    ).then(
        fn=hide_spinner,
        inputs=None,
        outputs=spinner,
    )
    # Clear history
    clear_btn.click(
        fn=clear_history,
        inputs=None,
        outputs=[output_text, clear_btn],
    )
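    # Note: streaming from a generator relies on Gradio's event queue; recent
    # Gradio versions enable it by default, while older releases need an
    # explicit demo.queue() before launch().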
# Launch app
if __name__ == "__main__":
    demo.launch()