SLAPaper committed on
Commit
38f411b
·
verified ·
1 Parent(s): 08ea196

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -1,17 +1,19 @@
1
  import functools as ft
2
- import gradio as gr
3
 
 
4
  import torch
5
  import transformers
6
  from transformers import T5ForConditionalGeneration, T5Tokenizer
7
 
 
8
  tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
9
- "models/roborovski/superprompt-v1", local_files_only=True
10
  )
11
  model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
12
- "models/roborovski/superprompt-v1", local_files_only=True
13
  )
14
 
 
15
  @ft.lru_cache(maxsize=1024)
16
  def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
17
  transformers.set_seed(seed)
@@ -38,6 +40,7 @@ def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
38
 
39
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
40
 
 
41
  demo = gr.Interface(
42
  fn=super_prompt,
43
  inputs=["text", "slider", "slider", "text"],
 
1
  import functools as ft
 
2
 
3
+ import gradio as gr
4
  import torch
5
  import transformers
6
  from transformers import T5ForConditionalGeneration, T5Tokenizer
7
 
8
+
9
  tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
10
+ "roborovski/superprompt-v1"
11
  )
12
  model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
13
+ "roborovski/superprompt-v1"
14
  )
15
 
16
+
17
  @ft.lru_cache(maxsize=1024)
18
  def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
19
  transformers.set_seed(seed)
 
40
 
41
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
42
 
43
+
44
  demo = gr.Interface(
45
  fn=super_prompt,
46
  inputs=["text", "slider", "slider", "text"],