# Spaces: Sleeping
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub repository id; the same id is used for both the weights and the tokenizer.
model_name = "Samurai719214/gptneo-mythology-storyteller"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on CUDA when PyTorch reports a usable GPU, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# (excerpt, generated_text) tuples, appended by generate_full_story.
conversation_history = []
def generate_full_story(excerpt: str) -> str:
    """Generate text from a story excerpt and record the exchange.

    The excerpt is tokenized with the module-level ``tokenizer``, fed to the
    module-level ``model`` for sampled generation (at most 200 new tokens),
    and the first returned sequence is decoded and returned. The
    ``(excerpt, generated_text)`` pair is appended to the module-level
    ``conversation_history`` list.

    Args:
        excerpt: The user-provided story excerpt. ``None`` and
            whitespace-only values are rejected with a prompt message.

    Returns:
        The decoded generated text, or ``"Please enter a valid story
        excerpt."`` when no usable excerpt was supplied.
    """
    # Treat a missing value like a blank one so the caller gets the friendly
    # message instead of an AttributeError from ``None.strip()``.
    if not excerpt or not excerpt.strip():
        return "Please enter a valid story excerpt."

    max_new_tokens = 200

    # Tokenize the user-provided excerpt.
    encoded_input = tokenizer(excerpt, return_tensors="pt")
    input_ids = encoded_input["input_ids"]
    attention_mask = encoded_input["attention_mask"]

    # Keep prompt + continuation within the model's position limit, when the
    # config exposes one. The most recent tokens are kept because they are the
    # ones the continuation follows on from.
    # NOTE(review): assumes both tensors are (batch, seq_len) — confirm.
    max_positions = getattr(model.config, "max_position_embeddings", None)
    if isinstance(max_positions, int):
        prompt_budget = max(max_positions - max_new_tokens, 1)
        input_ids = input_ids[:, -prompt_budget:]
        attention_mask = attention_mask[:, -prompt_budget:]

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Generate text with controlled sampling parameters.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.1,
        top_p=0.95,
        no_repeat_ngram_size=2,
        return_dict_in_generate=True,
    )

    # Decode the first (and only requested) sequence, dropping special tokens.
    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    # NOTE(review): this history is a module-level list, so it is shared by
    # every caller of this process and grows without bound — confirm that is
    # acceptable for the deployment.
    conversation_history.append((excerpt, generated_text))
    return generated_text
def get_conversation_history():
    """Return every recorded exchange as a single display string.

    Each ``(excerpt, generated_text)`` pair in ``conversation_history`` is
    rendered as a ``**User-**`` line followed by an ``**AI-**`` line, and the
    entries are separated by a blank line. When nothing has been recorded,
    a placeholder message is returned instead.
    """
    if conversation_history:
        transcript = []
        for prompt, reply in conversation_history:
            transcript.append(f"**User-** {prompt}\n**AI-** {reply}")
        return "\n\n".join(transcript)
    return "No conversations started."
def clear_conversation():
    """Empty ``conversation_history`` in place and return a status message.

    Returns ``"Conversations deleted!"`` after clearing, or
    ``"No conversations started."`` when the history was already empty.
    """
    if conversation_history:
        # Clear in place so every reference to the list sees the change.
        conversation_history.clear()
        return "Conversations deleted!"
    return "No conversations started."
# Build the Gradio interface.
# NOTE(review): the original indentation was lost, so which components sit
# inside each Row below is a best-effort reconstruction — confirm against the
# intended layout.
with gr.Blocks() as interface:
    gr.Markdown("# 🏺 Mythology Storyteller")
    gr.Markdown(
        "Enter a phrase from a chapter of your choice (please include Parv, key event, and section for a better answer). The model will generate the summary of the respective chapter."
    )

    # Excerpt input and generated output share one row.
    with gr.Row():
        user_input = gr.Textbox(
            lines=5,
            label="Incomplete story excerpt",
            placeholder="Enter an excerpt from the Mahabharata here...",
        )
        output_text = gr.Textbox(label="Chapter summary")

    generate_btn = gr.Button("Generate Story")
    generate_btn.click(
        fn=generate_full_story,
        inputs=user_input,
        outputs=output_text,
    )

    # Read-only box that both history buttons write their result into.
    with gr.Row():
        history_display = gr.Textbox(
            label="Conversation History",
            interactive=False,
        )

    show_history_btn = gr.Button("Show Conversation History")
    show_history_btn.click(
        fn=get_conversation_history,
        outputs=history_display,
    )

    clear_btn = gr.Button("Clear Conversation")
    clear_btn.click(
        fn=clear_conversation,
        outputs=history_display,
    )

# Launch the Gradio app.
interface.launch()