aayushraina committed
Commit 15e2a67 · verified · 1 Parent(s): 028bd08

Upload 2 files

Files changed (2):
  1. app.py +37 -26
  2. requirements.txt +2 -6
app.py CHANGED
@@ -2,40 +2,51 @@ import gradio as gr
 import torch
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
-# Load model and tokenizer from Hugging Face
+# Load model and tokenizer
 def load_model():
-    model_name = "aayushraina/gpt2shakespeare"
-    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
-    model = GPT2LMHeadModel.from_pretrained(model_name)
-    model.eval()
-    return model, tokenizer
+    try:
+        # Load the fine-tuned model
+        model = GPT2LMHeadModel.from_pretrained("aayushraina/gpt2shakespeare")
+        # Use the base GPT-2 tokenizer
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        model.eval()
+        print("Model and tokenizer loaded successfully!")
+        return model, tokenizer
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        return None, None
 
 # Text generation function
 def generate_text(prompt, max_length=500, temperature=0.8, top_k=40, top_p=0.9):
-    # Encode the input prompt
-    input_ids = tokenizer.encode(prompt, return_tensors='pt')
-
-    # Generate text
-    with torch.no_grad():
-        output = model.generate(
-            input_ids,
-            max_length=max_length,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id,
-            num_return_sequences=1
-        )
-
-    # Decode and return the generated text
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-    return generated_text
+    if model is None or tokenizer is None:
+        return "Error: Model not loaded properly"
+
+    try:
+        # Encode the input prompt
+        input_ids = tokenizer.encode(prompt, return_tensors='pt')
+
+        # Generate text
+        with torch.no_grad():
+            output = model.generate(
+                input_ids,
+                max_length=max_length,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                num_return_sequences=1
+            )
+
+        # Decode and return the generated text
+        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+        return generated_text
+    except Exception as e:
+        return f"Error during generation: {str(e)}"
 
 # Load model and tokenizer globally
 print("Loading model and tokenizer...")
 model, tokenizer = load_model()
-print("Model loaded successfully!")
 
 # Create Gradio interface
 demo = gr.Interface(
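
The diff is truncated at the opening `gr.Interface(` call, so the actual interface definition is not part of this commit view. Purely as a sketch, the continuation below is inferred from `generate_text`'s signature; every label, slider range, and default here is an assumption, not the real contents of app.py:

# Hypothetical continuation of the truncated gr.Interface(...) call.
# All widget labels, ranges, and the title are assumptions; only the
# five-parameter mapping to generate_text comes from the diff above.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", lines=3),
        gr.Slider(50, 1000, value=500, step=10, label="Max length"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(1, 100, value=40, step=1, label="Top-k"),
        gr.Slider(0.05, 1.0, value=0.9, step=0.05, label="Top-p"),
    ],
    outputs=gr.Textbox(label="Generated text"),
    title="GPT-2 Shakespeare",  # hypothetical title, not in the diff
)

if __name__ == "__main__":
    demo.launch()

One design choice visible in the diff itself: the tokenizer is now loaded from the base `gpt2` checkpoint rather than from `aayushraina/gpt2shakespeare`, the usual workaround when a fine-tuned checkpoint is uploaded without its tokenizer files.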
requirements.txt CHANGED
@@ -1,7 +1,3 @@
-wandb
-tiktoken
 torch>=2.0.0
-numpy>=1.24.0
-tqdm
-transformers
-gradio>=4.0.0
+transformers>=4.30.0
+gradio>=4.0.0
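
The commit drops the training-only dependencies (wandb, tiktoken, numpy, tqdm) and pins transformers, leaving only the three packages the Space needs at inference time. A local smoke test, sketched below and not part of the commit, would confirm the trimmed pins are sufficient by mirroring app.py's load-and-generate path:

# smoke_test.py — a hypothetical local check, not part of this commit.
# Exercises the same calls app.py makes at startup to verify that
# torch, transformers, and gradio's transitive deps are all it needs.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("aayushraina/gpt2shakespeare")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer.encode("Shall I compare thee", return_tensors="pt")
with torch.no_grad():
    output = model.generate(
        input_ids,
        max_length=40,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))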