SLAPaper committed on
Commit
dff15c2
·
verified ·
1 Parent(s): eb40ec9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -1
app.py CHANGED
@@ -1,3 +1,47 @@
 
1
  import gradio as gr
2
 
3
- gr.load("models/roborovski/superprompt-v1").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools as ft
2
  import gradio as gr
3
 
4
+ import torch
5
+ import transformers
6
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
7
+
8
# Load the SuperPrompt-v1 checkpoint from the local cache only — no network
# download is attempted at startup (local_files_only=True).
_CHECKPOINT = "models/roborovski/superprompt-v1"

tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
    _CHECKPOINT, local_files_only=True
)
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    _CHECKPOINT, local_files_only=True
)
14
+
15
@ft.lru_cache(maxsize=1024)
def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
    """Expand *text* into a more detailed prompt using the SuperPrompt T5 model.

    Args:
        text: The user prompt to expand.
        seed: RNG seed; sampling is seeded so repeated calls (and the
            lru_cache) are deterministic for the same arguments.
        max_new_tokens: Upper bound on the number of generated tokens;
            values <= 0 fall back to a default of 150.
        prompt: Optional instruction prefix. When empty, a default
            "Expand the following prompt..." instruction is used.

    Returns:
        The decoded model output with special tokens stripped.
    """
    # Gradio sliders deliver floats; normalize to int so lru_cache keys are
    # stable and transformers.set_seed receives a proper integer.
    seed = int(seed)
    max_new_tokens = int(max_new_tokens)

    transformers.set_seed(seed)

    if max_new_tokens <= 0:
        max_new_tokens = 150

    with torch.inference_mode():
        if prompt:
            input_text = f"{prompt} {text}"
        else:
            input_text = f"Expand the following prompt to add more detail: {text}"

        input_ids = tokenizer(input_text, return_tensors="pt").input_ids

        outputs = model.generate(
            input_ids,
            # Bug fix: the original passed max_length=max_new_tokens, which
            # caps the *total* sequence length (input + output) and silently
            # shrinks the generated text for longer inputs. max_new_tokens
            # counts only newly generated tokens, matching this parameter.
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
40
+
41
# Wire the generator into a simple Gradio UI: prompt text box, seed slider,
# max-new-tokens slider, and an optional instruction-prefix text box —
# matching super_prompt's (text, seed, max_new_tokens, prompt) signature.
input_components = ["text", "slider", "slider", "text"]
output_components = ["text"]

demo = gr.Interface(
    fn=super_prompt,
    inputs=input_components,
    outputs=output_components,
)

demo.launch()