Nick088 commited on
Commit
4c55de4
·
verified ·
1 Parent(s): 825d0c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -11,26 +11,25 @@ else:
11
  device = "cpu"
12
  print("Using CPU")
13
 
14
- tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
15
-
 
 
16
  @spaces.GPU()
17
  def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
 
 
 
 
18
  if model_precision_type == "fp16":
19
  dtype = torch.float16
20
  elif model_precision_type == "fp32":
21
  dtype = torch.float32
22
 
23
- model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=dtype)
24
- model.to(device)
25
 
26
  input_text = f"Expand the following prompt to add more detail: {your_prompt}"
27
- input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
28
-
29
- if seed == 0:
30
- seed = random.randint(1, 100000)
31
- torch.manual_seed(seed)
32
- else:
33
- torch.manual_seed(seed)
34
 
35
  outputs = model.generate(
36
  input_ids,
@@ -42,8 +41,7 @@ def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model
42
  top_k=top_k,
43
  )
44
 
45
- better_prompt = tokenizer.decode(outputs[0])
46
- better_prompt = better_prompt.replace("<pad>", "").replace("</s>", "")
47
  return better_prompt
48
 
49
 
 
11
  device = "cpu"
12
  print("Using CPU")
13
 
14
+ tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
15
+ model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1")
16
+ model.to(device)
17
+
18
  @spaces.GPU()
19
  def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
20
+ if seed == 0:
21
+ seed = random.randint(1, 2**32-1)
22
+ transformers.set_seed(seed)
23
+
24
  if model_precision_type == "fp16":
25
  dtype = torch.float16
26
  elif model_precision_type == "fp32":
27
  dtype = torch.float32
28
 
29
+ model.to(dtype)
 
30
 
31
  input_text = f"Expand the following prompt to add more detail: {your_prompt}"
32
+ input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device).to(dtype)
 
 
 
 
 
 
33
 
34
  outputs = model.generate(
35
  input_ids,
 
41
  top_k=top_k,
42
  )
43
 
44
+ better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
45
  return better_prompt
46
 
47