import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load your hosted model and tokenizer from Hugging Face.
model_name = "Samurai719214/gptneo-mythology-storyteller"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
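# Note: from_pretrained downloads the weights on first use and reuses the local
# Hugging Face cache on subsequent runs.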

# Use GPU if available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
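# Put the model in eval mode so dropout is disabled during inference.
model.eval()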

# In-memory conversation history; it resets whenever the app restarts.
conversation_history = []

def generate_full_story(excerpt: str) -> str:
    """
    Given an incomplete story excerpt, generate the complete story including Parv, Key Event, Section, and continuation.
    """
    if not excerpt.strip():
        return "Please enter a valid story excerpt."
    
    # Tokenize the user-provided excerpt and move the tensors to the target device.
    encoded_input = tokenizer(excerpt, return_tensors="pt")
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
    
    # Generate text with controlled sampling parameters: a low temperature keeps
    # the output focused, top_p trims the sampling tail, and no_repeat_ngram_size
    # curbs verbatim repetition. pad_token_id is set explicitly because the
    # GPT-Neo tokenizer has no pad token of its own, which otherwise triggers a
    # warning from generate().
    with torch.no_grad():
        output = model.generate(
            encoded_input["input_ids"],
            attention_mask=encoded_input["attention_mask"],
            max_new_tokens=200,
            do_sample=True,
            temperature=0.1,
            top_p=0.95,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
        )
    
    # Decode the generated sequence. For a decoder-only model the decoded text
    # begins with the original excerpt, followed by the generated continuation.
    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    
    # Append to conversation history
    conversation_history.append((excerpt, generated_text))
    
    return generated_text

def get_conversation_history():
    """Return the conversation history as display-ready text."""
    if not conversation_history:
        return "No conversations yet."
    return "\n\n".join(f"User: {inp}\nAI: {out}" for inp, out in conversation_history)

def clear_conversation():
    """Clear the conversation history."""
    if not conversation_history:
        return "No conversations to clear."
    conversation_history.clear()
    return "Conversation history cleared!"

# Build the Gradio interface.
with gr.Blocks() as interface:
    gr.Markdown("# 🏺 Mythology Storyteller")
    gr.Markdown("Enter a phrase from a chapter of your choice (please include Parv, key event, and section for a better answer). The model will generate the summary of the respective chapter.")
    
    with gr.Row():
        user_input = gr.Textbox(lines=5, label="Incomplete story excerpt", placeholder="Enter an excerpt from the Mahabharata here...")

    output_text = gr.Textbox(label="Chapter summary")

    generate_btn = gr.Button("Generate Story")
    generate_btn.click(fn=generate_full_story, inputs=user_input, outputs=output_text)
    
    with gr.Row():
        history_display = gr.Textbox(label="Conversation History", interactive=False)

    show_history_btn = gr.Button("Show Conversation History")
    show_history_btn.click(fn=get_conversation_history, outputs=history_display)

    clear_btn = gr.Button("Clear Conversation")
    clear_btn.click(fn=clear_conversation, outputs=history_display)

# Launch the Gradio app.
interface.launch()
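
# Note: on Hugging Face Spaces this script is conventionally saved as app.py and
# launch() runs as-is; when running locally, pass share=True to launch() to get
# a temporary public link.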