|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
# Hugging Face Hub ID of the fine-tuned GPT-Neo mythology storyteller model.
model_name = "Samurai719214/gptneo-mythology-storyteller"

# Load the causal-LM weights and the matching tokenizer from the Hub
# (downloads on first run, then served from the local cache).
model = AutoModelForCausalLM.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)


# Prefer GPU inference when CUDA is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

model.to(device)


# In-memory log of (user_excerpt, generated_text) pairs for this process.
# NOTE: not persisted — cleared on restart and shared by all users of the app.
conversation_history = []
|
|
|
def generate_full_story(excerpt: str) -> str:
    """
    Given an incomplete story excerpt, generate the complete story including
    Parv, Key Event, Section, and continuation.

    Parameters
    ----------
    excerpt : str
        User-provided story fragment used as the generation prompt.

    Returns
    -------
    str
        The decoded model output (prompt plus continuation), or a validation
        message when the input is empty/whitespace-only.
    """
    # Guard clause: reject empty or whitespace-only input.
    if not excerpt.strip():
        return "Please enter a valid story excerpt."

    # Tokenize the prompt and move all tensors to the model's device.
    encoded_input = tokenizer(excerpt, return_tensors="pt")
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

    # BUG FIX: the original passed temperature=0 together with do_sample=True,
    # which transformers rejects (temperature must be a strictly positive
    # float when sampling), crashing every generation call. Use a standard
    # sampling temperature instead.
    # torch.no_grad() skips autograd bookkeeping — this is pure inference.
    with torch.no_grad():
        output = model.generate(
            encoded_input["input_ids"],
            attention_mask=encoded_input["attention_mask"],
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            no_repeat_ngram_size=2,
            return_dict_in_generate=True,
        )

    # Decode the full sequence (prompt + continuation) to plain text.
    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    # Record the exchange so get_conversation_history() can display it.
    conversation_history.append((excerpt, generated_text))

    return generated_text
|
|
|
def get_conversation_history():
    """Return all stored exchanges as a markdown string, or a placeholder."""
    if not conversation_history:
        return "No conversations started."
    rendered = []
    for user_text, model_text in conversation_history:
        rendered.append(f"**User:** {user_text}\n**AI:** {model_text}")
    return "\n\n".join(rendered)
|
|
|
def clear_conversation():
    """Empty the conversation history and report what happened."""
    if conversation_history:
        conversation_history.clear()
        return "Conversations deleted!"
    return "No conversations started."
|
|
|
|
|
# Build the Gradio UI. Component creation order matters inside Blocks —
# it determines the page layout — so the original ordering is preserved.
with gr.Blocks() as interface:
    gr.Markdown("# 🏺 Mythology Storyteller")
    gr.Markdown("Enter a phrase from a chapter of your choice. The model will generate the summary of the respective chapter.")

    with gr.Row():
        excerpt_box = gr.Textbox(
            lines=5,
            label="Incomplete story excerpt",
            placeholder="Enter an excerpt from the Mahabharata here...",
        )

    summary_box = gr.Textbox(label="Chapter summary")

    # Wire the generate button to the model call.
    generate_button = gr.Button("Generate Story")
    generate_button.click(fn=generate_full_story, inputs=excerpt_box, outputs=summary_box)

    with gr.Row():
        history_box = gr.Textbox(label="Conversation History", interactive=False)

    # Show and clear actions both write their status into the history box.
    history_button = gr.Button("Show Conversation History")
    history_button.click(fn=get_conversation_history, outputs=history_box)

    clear_button = gr.Button("Clear Conversation")
    clear_button.click(fn=clear_conversation, outputs=history_box)


interface.launch()
|
|