Update app.py
Browse files
app.py
CHANGED
@@ -10,15 +10,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
10 |
trained_model.to(device)
|
11 |
untrained_model.to(device)
|
12 |
|
13 |
-
def generate(commentary_text):
|
14 |
-
# Generate text using the
|
15 |
input_ids = trained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
|
16 |
-
trained_output = trained_model.generate(input_ids, max_length=
|
17 |
trained_text = trained_tokenizer.decode(trained_output[0], skip_special_tokens=True)
|
18 |
|
19 |
-
# Generate text using the
|
20 |
input_ids = untrained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
|
21 |
-
untrained_output = untrained_model.generate(input_ids, max_length=
|
22 |
untrained_text = untrained_tokenizer.decode(untrained_output[0], skip_special_tokens=True)
|
23 |
|
24 |
return trained_text, untrained_text
|
@@ -26,10 +26,17 @@ def generate(commentary_text):
|
|
26 |
# Create Gradio interface
|
27 |
iface = gr.Interface(
|
28 |
fn=generate,
|
29 |
-
inputs=
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
title="GPT-2 Text Generation",
|
32 |
-
description="start writing a cricket commentary and GPT-2 will continue it using both a
|
33 |
)
|
34 |
|
35 |
# Launch the app
|
|
|
10 |
trained_model.to(device)
|
11 |
untrained_model.to(device)
|
12 |
|
13 |
+
def generate(commentary_text, max_length, temperature):
|
14 |
+
# Generate text using the finetuned model
|
15 |
input_ids = trained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
|
16 |
+
trained_output = trained_model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=False, temperature=temperature)
|
17 |
trained_text = trained_tokenizer.decode(trained_output[0], skip_special_tokens=True)
|
18 |
|
19 |
+
# Generate text using the base model
|
20 |
input_ids = untrained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
|
21 |
+
untrained_output = untrained_model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=False,temperature=temperature)
|
22 |
untrained_text = untrained_tokenizer.decode(untrained_output[0], skip_special_tokens=True)
|
23 |
|
24 |
return trained_text, untrained_text
|
|
|
26 |
# Create Gradio interface
|
27 |
iface = gr.Interface(
|
28 |
fn=generate,
|
29 |
+
inputs=[
|
30 |
+
gr.inputs.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
|
31 |
+
gr.inputs.Slider(minimum=10, maximum=100, default=50, label="Max Length"),
|
32 |
+
gr.inputs.Slider(minimum=0.1, maximum=1.0, default=0.7, label="Temperature")
|
33 |
+
],
|
34 |
+
outputs=[
|
35 |
+
gr.outputs.Textbox(label="commentary generation from finetuned GPT2 Model"),
|
36 |
+
gr.outputs.Textbox(label="commentary generation from base GPT2 Model")
|
37 |
+
],
|
38 |
title="GPT-2 Text Generation",
|
39 |
+
description="start writing a cricket commentary and GPT-2 will continue it using both a finetuned and base model."
|
40 |
)
|
41 |
|
42 |
# Launch the app
|