File size: 3,891 Bytes
36a15a7
1baa2a7
d0f3977
cc34dac
d0f3977
cc34dac
4ed952a
d0f3977
 
cc34dac
 
f004663
dbb9f19
 
657fa05
dbb9f19
00c64ce
cc34dac
 
 
 
 
b273e1a
cc34dac
 
 
 
b273e1a
cc34dac
 
d0f3977
cc34dac
b273e1a
cc34dac
dbb9f19
 
 
 
 
cc34dac
dbb9f19
 
 
 
 
00c64ce
cc34dac
 
 
dbb9f19
657fa05
00c64ce
657fa05
00c64ce
cc34dac
dbb9f19
d489bb6
 
00c64ce
cc34dac
d489bb6
 
 
 
 
 
 
 
 
 
 
 
dbb9f19
cc34dac
d489bb6
dbb9f19
d489bb6
00c64ce
 
d489bb6
00c64ce
dbb9f19
cc34dac
d489bb6
 
cc34dac
657fa05
 
 
 
 
 
 
00c64ce
 
 
 
 
 
 
d489bb6
00c64ce
 
 
 
d489bb6
dbb9f19
00c64ce
 
 
 
 
d489bb6
cc34dac
dbb9f19
d489bb6
 
 
657fa05
d489bb6
cc34dac
 
 
dbb9f19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import cache_manager
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned storyteller checkpoint and its tokenizer once, when the
# module is imported. The model is moved to the GPU if CUDA is available,
# otherwise it stays on the CPU.
model_name = "Samurai719214/gptneo-mythology-storyteller"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# Story generation with history
def generate_full_story_chunks(excerpt, history_state):
    """Generate a story from *excerpt* and stream it into the chat history.

    This is a generator used as a Gradio event handler. Every yield is a
    3-tuple matching the outputs ``[output_text, spinner, clear_btn]``:
    the updated chat history, an update for the spinner Markdown, and an
    update that enables the clear button.

    The ``gr.Chatbot`` this feeds is created without ``type="messages"``, so
    its history is a list of ``(user_message, bot_message)`` pairs. Each turn
    is therefore stored as ``(excerpt, story)``; the bot half starts as
    ``None`` and is filled in progressively.

    Args:
        excerpt: Text typed by the user; the model continues from it.
        history_state: Current chatbot history (list of pairs), or ``None``.

    Yields:
        ``(history, spinner_update, clear_button_update)`` tuples.
    """
    # Work on a copy so the incoming list is never mutated, and tolerate a
    # None value for an empty chatbot.
    history = list(history_state or [])

    if not excerpt or not excerpt.strip():
        history.append(("❌", "⚠️ Enter a story excerpt."))
        yield history, gr.update(visible=False), gr.update(interactive=True)
        return

    # Show the user's turn immediately, before the (possibly slow) generation.
    # gr.update() with no arguments leaves the spinner as it currently is.
    history.append((excerpt, None))
    yield history, gr.update(), gr.update(interactive=True)

    inputs = tokenizer(excerpt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Fall back to the EOS token for padding when the tokenizer defines no pad
    # token; passing it explicitly avoids generate() picking it with a warning.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=400,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        no_repeat_ngram_size=2,
        pad_token_id=pad_token_id,
    )

    # For a causal LM, generate() returns prompt + continuation, so the decoded
    # text is the full story including the user's excerpt.
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Emulate streaming: reveal the already-generated text 200 characters at a
    # time by growing the bot half of the last turn.
    chunk_size = 200
    for end in range(chunk_size, len(generated_text) + chunk_size, chunk_size):
        history[-1] = (excerpt, generated_text[:end])
        yield history, gr.update(visible=False), gr.update(interactive=True)

# Clear conversation
def clear_history():
    """Reset the chat to an empty history and disable the clear button.

    Returns:
        A ``(history, clear_button_update)`` pair for the outputs
        ``[output_text, clear_btn]``.
    """
    empty_history = []
    disable_clear = gr.update(interactive=False)
    return empty_history, disable_clear

# Enable/disable generate button
def toggle_generate_button(text):
    """Return a button update enabling "Generate" only for non-blank input.

    Args:
        text: Current value of the excerpt textbox. ``None`` is treated like
            an empty string instead of raising ``AttributeError``.

    Returns:
        A ``gr.update`` setting ``interactive`` to True iff *text* contains
        at least one non-whitespace character.
    """
    has_text = bool(text and text.strip())
    return gr.update(interactive=has_text)

# Build UI
# One row with two columns: inputs (excerpt, optional summary, generate button)
# on the left, and output (chat history, spinner, clear button) on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 🏺 Mythology Storyteller")
    gr.Markdown("Enter a phrase from a chapter of your choice (include Parv, key event, and section for better results).")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            user_input = gr.Textbox(
                label="Incomplete story excerpt",
                placeholder="Enter an excerpt from the Mahabharata here...",
                lines=4,
            )
            # NOTE(review): summary_input is not passed to any event handler
            # below, so text entered here currently has no effect on generation.
            summary_input = gr.Textbox(
                label="Chapter summary (optional)",
                placeholder="Enter summary if available...",
                lines=2,
            )
            # Starts disabled; the user_input.change handler below enables it
            # once the excerpt contains non-whitespace text.
            generate_btn = gr.Button("✨ Generate Story", interactive=False)

        # Right column: conversation output.
        with gr.Column():
            output_text = gr.Chatbot(
                label="Generated Story",
                height=400,
                placeholder="⚔️ Legends are being written..." 
            )
            spinner = gr.Markdown("", visible=False)  # spinner placeholder
            # Starts disabled; enabled by the generation handler's updates and
            # disabled again by clear_history.
            clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)

    gr.Markdown("---")
    gr.Markdown("🔌 Use via API (see Hugging Face Inference docs).")

    # Toggle generate button when input changes
    user_input.change(
        fn=toggle_generate_button,
        inputs=user_input,
        outputs=generate_btn,
    )

    # Show spinner when generating
    def show_spinner():
        """Return an update that makes the spinner Markdown visible with a status text."""
        return gr.update(value="⏳ Generating story...", visible=True)

    def hide_spinner():
        """Return an update that hides the spinner Markdown."""
        return gr.update(visible=False)

    # Generate click chain: show the spinner, run the generator handler (which
    # receives the current chatbot value as its history and streams updates
    # to the chatbot, spinner and clear button), then hide the spinner.
    generate_btn.click(
        fn=show_spinner,
        inputs=None,
        outputs=spinner,
    ).then(
        fn=generate_full_story_chunks,
        inputs=[user_input, output_text],
        outputs=[output_text, spinner, clear_btn],
    ).then(
        fn=hide_spinner,
        inputs=None,
        outputs=spinner,
    )

    # Clear history
    clear_btn.click(
        fn=clear_history,
        inputs=None,
        outputs=[output_text, clear_btn],
    )

# Launch app
def _launch_app():
    """Start the Gradio server hosting the ``demo`` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    # Serve only when executed as a script, not when imported.
    _launch_app()