KantaHayashiAI committed (verified)
Commit 67f3195
Parent(s): 8af9f91

Update app.py

Files changed (1): app.py (+1, -20)
app.py CHANGED
@@ -24,7 +24,7 @@ MAX_INPUT_TOKEN_LENGTH = 32000
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-tokenizer = AutoTokenizer.from_pretrained("evabyte/EvaByte-SFT", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("EvaByte/EvaByte", trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained("evabyte/EvaByte-SFT", torch_dtype=torch.bfloat16, trust_remote_code=True).eval().to("cuda")
 
 @spaces.GPU(duration=120)
@@ -34,8 +34,6 @@ def generate(
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = [*chat_history, {"role": "user", "content": message}]
 
@@ -52,10 +50,7 @@ def generate(
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
-        top_k=top_k,
         temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
@@ -90,20 +85,6 @@ demo = gr.ChatInterface(
             step=0.05,
             value=0.9,
         ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
     ],
     stop_btn=None,
     examples=[
 
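For context, a minimal runnable sketch of the generate() path as it stands after this commit. Everything not visible in the diff above (the TextIteratorStreamer wiring, the chat-template tokenization, and names such as input_ids, streamer, and outputs) is assumed to follow the standard transformers streaming pattern; it is an assumption, not the Space's verbatim code.

```python
# Sketch of app.py's post-commit generation path. Parts outside the diff
# are assumptions based on the standard transformers streaming pattern.
from collections.abc import Iterator
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("EvaByte/EvaByte", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "evabyte/EvaByte-SFT", torch_dtype=torch.bfloat16, trust_remote_code=True
).eval().to("cuda")

def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> Iterator[str]:
    conversation = [*chat_history, {"role": "user", "content": message}]
    # Assumed: tokenize the conversation with the model's chat template.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # Assumed: a TextIteratorStreamer yields decoded text as it is generated.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )
    # model.generate runs on a worker thread so this generator can
    # consume the streamer and yield partial output to the UI.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
```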