aayushraina commited on
Commit
ba54c96
·
verified ·
1 Parent(s): 69497aa

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -1
app.py CHANGED
@@ -52,4 +52,61 @@ def load_model():
52
  print(f"Fallback failed: {e}")
53
  return None, None
54
 
55
- # Rest of the code remains the same...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  print(f"Fallback failed: {e}")
53
  return None, None
54
 
55
+ # Text generation function
56
+ def generate_text(prompt, max_length=500, temperature=0.8, top_k=40, top_p=0.9):
57
+ # Encode the input prompt
58
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
59
+
60
+ # Generate text
61
+ with torch.no_grad():
62
+ output = model.generate(
63
+ input_ids,
64
+ max_length=max_length,
65
+ temperature=temperature,
66
+ top_k=top_k,
67
+ top_p=top_p,
68
+ do_sample=True,
69
+ pad_token_id=tokenizer.eos_token_id,
70
+ num_return_sequences=1
71
+ )
72
+
73
+ # Decode and return the generated text
74
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
75
+ return generated_text
76
+
77
+ # Load model and tokenizer globally
78
+ print("Loading model and tokenizer...")
79
+ model, tokenizer = load_model()
80
+ print("Model loaded successfully!")
81
+
82
+ # Create Gradio interface
83
+ demo = gr.Interface(
84
+ fn=generate_text,
85
+ inputs=[
86
+ gr.Textbox(label="Enter your prompt", placeholder="Start your text here...", lines=2),
87
+ gr.Slider(minimum=10, maximum=1000, value=500, step=10, label="Maximum Length"),
88
+ gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
89
+ gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k"),
90
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p"),
91
+ ],
92
+ outputs=gr.Textbox(label="Generated Text", lines=10),
93
+ title="Shakespeare-style Text Generator",
94
+ description="""Generate Shakespeare-style text using a fine-tuned GPT-2 model.
95
+
96
+ Parameters:
97
+ - Temperature: Higher values make the output more random, lower values more focused
98
+ - Top-k: Number of highest probability vocabulary tokens to keep for top-k filtering
99
+ - Top-p: Cumulative probability for nucleus sampling
100
+ """,
101
+ examples=[
102
+ ["First Citizen:", 500, 0.8, 40, 0.9],
103
+ ["To be, or not to be,", 500, 0.8, 40, 0.9],
104
+ ["Friends, Romans, countrymen,", 500, 0.8, 40, 0.9],
105
+ ["O Romeo, Romeo,", 500, 0.8, 40, 0.9],
106
+ ["Now is the winter of our discontent", 500, 0.8, 40, 0.9]
107
+ ]
108
+ )
109
+
110
+ # Launch the app
111
+ if __name__ == "__main__":
112
+ demo.launch()