Shriti09 commited on
Commit
0fb0cc9
·
verified ·
1 Parent(s): db8b843

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -1,18 +1,12 @@
1
  import torch
2
  import gradio as gr
3
- from model import CustomLLM
4
- from transformers import GPT2Tokenizer
5
 
6
  class ModelLoader:
7
  def __init__(self):
8
  # Load config
9
- self.config = {
10
- "vocab_size": 50257, # Update with your actual values
11
- "hidden_size": 768,
12
- "num_hidden_layers": 12,
13
- "rms_norm_eps": 1e-6
14
- }
15
-
16
  # Instantiate model
17
  self.model = CustomLLM(self.config)
18
 
@@ -22,7 +16,7 @@ class ModelLoader:
22
  self.model.eval()
23
 
24
  # Load tokenizer
25
- self.tokenizer = GPT2Tokenizer.from_pretrained('tokenizer/')
26
  self.tokenizer.pad_token = self.tokenizer.eos_token
27
 
28
  def generate(self, prompt, max_new_tokens=100, temperature=0.9, top_k=50, top_p=0.95):
@@ -36,7 +30,7 @@ class ModelLoader:
36
  temperature=temperature,
37
  top_k=top_k,
38
  top_p=top_p,
39
- eos_token_id=self.tokenizer.eos_token_id,
40
  pad_token_id=self.tokenizer.pad_token_id
41
  )
42
 
@@ -60,4 +54,4 @@ interface = gr.Interface(
60
  description="Generate text using your custom-trained LLM"
61
  )
62
 
63
- interface.launch()
 
1
  import torch
2
  import gradio as gr
3
+ from model import CustomLLM, CustomConfig
4
+ from transformers import AutoTokenizer
5
 
6
  class ModelLoader:
7
  def __init__(self):
8
  # Load config
9
+ self.config = CustomConfig()
 
 
 
 
 
 
10
  # Instantiate model
11
  self.model = CustomLLM(self.config)
12
 
 
16
  self.model.eval()
17
 
18
  # Load tokenizer
19
+ self.tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
20
  self.tokenizer.pad_token = self.tokenizer.eos_token
21
 
22
  def generate(self, prompt, max_new_tokens=100, temperature=0.9, top_k=50, top_p=0.95):
 
30
  temperature=temperature,
31
  top_k=top_k,
32
  top_p=top_p,
33
+ eos_token_id=None,
34
  pad_token_id=self.tokenizer.pad_token_id
35
  )
36
 
 
54
  description="Generate text using your custom-trained LLM"
55
  )
56
 
57
+ interface.launch(share=True)