Lin0He committed on
Commit 411f507 · 1 Parent(s): 66168b3

Upload interface.py

Files changed (1)
  1. interface.py +84 -13
interface.py CHANGED
@@ -1,13 +1,84 @@
- import gradio as gr
-
- title = "GPT-J-6B"
- description = "Gradio Demo for GPT-J 6B, a transformer model trained using Ben Wang's Mesh Transformer JAX. 'GPT-J' refers to the class of model, while '6B' represents the number of trainable parameters. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
- article = "<p style='text-align: center'><a href='https://github.com/kingoflolz/mesh-transformer-jax' target='_blank'>GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model</a></p>"
-
- gr.Interface.load(
-     "huggingface/EleutherAI/gpt-j-6B",
-     inputs=gr.Textbox(lines=5, label="Input Text"),
-     title=title,
-     description=description,
-     article=article,
- ).launch()
+ '''
+ # upload model
+ import torch
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
+
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ model = torch.load('text_summary_4sets_2_550.pth', map_location=torch.device('mps'))
+
+ model.push_to_hub(repo_id="Lin0He/text-summary-gpt2-short")
+ tokenizer.push_to_hub(repo_id="Lin0He/text-summary-gpt2-short")
+ '''
+ import torch
+ import numpy as np
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ tokenizer = AutoTokenizer.from_pretrained("Lin0He/text-summary-gpt2-short")
+ # A causal-LM head is required so the forward pass exposes .logits
+ model = AutoModelForCausalLM.from_pretrained("Lin0He/text-summary-gpt2-short").to(device)
+
+ def topk(probs, n=9):
+     # The scores are initially softmaxed to convert to probabilities
+     probs = torch.softmax(probs, dim=-1)
+
+     # PyTorch has its own topk method, which we use here
+     tokensProb, topIx = torch.topk(probs, k=n)
+
+     # The new selection pool (9 choices) is normalized
+     tokensProb = tokensProb / torch.sum(tokensProb)
+
+     # Send to CPU for numpy handling
+     tokensProb = tokensProb.cpu().detach().numpy()
+
+     # Make a random choice from the pool based on the new prob distribution
+     choice = np.random.choice(n, 1, p=tokensProb)
+     tokenId = topIx[choice][0]
+
+     return int(tokenId)
+
+ def model_infer(model, tokenizer, review, max_length=30):
+     # Preprocess the init token (task designator)
+     review_encoded = tokenizer.encode(review)
+     result = review_encoded
+     initial_input = torch.tensor(review_encoded).unsqueeze(0).to(device)
+
+     with torch.set_grad_enabled(False):
+         # Feed the init token to the model
+         output = model(initial_input)
+
+         # Flatten the logits at the final time step
+         logits = output.logits[0, -1]
+
+         # Make a top-k choice and append to the result
+         choices = topk(logits)
+         result.append(choices)
+
+         # For max_length times:
+         for _ in range(max_length):
+             # Feed the current sequence to the model and make a choice
+             input_ids = torch.tensor(result).unsqueeze(0).to(device)
+             output = model(input_ids)
+             logits = output.logits[0, -1]
+             res_id = topk(logits)
+
+             # If the chosen token is EOS, return the result
+             if res_id == tokenizer.eos_token_id:
+                 return tokenizer.decode(result)
+             else:  # Append to the sequence
+                 result.append(res_id)
+
+     # If no EOS is generated, return after max_length tokens
+     return tokenizer.decode(result)
+
+ def predict(text):
+     result_text = []
+     for i in range(6):
+         # "TL;DR" (5 characters) is appended as the summarization task designator
+         summary = model_infer(model, tokenizer, text + "TL;DR").strip()
+         result_text.append(summary[len(text) + 5:])
+     # Of the six sampled candidates, return the median-length one
+     return sorted(result_text, key=len)[3]
+
+ '''
+ predictor = pipeline("summarization", model=model, tokenizer=tokenizer)
+ result = predictor("Input text for prediction")
+ print(result)
+ '''
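
A minimal sketch, not part of this commit, of how predict could be wired back into the Gradio interface this file previously launched; the gr.Interface arguments below (the outputs type and title) are assumptions mirroring the removed code, not anything this commit ships:

import gradio as gr  # hypothetical: mirrors the Gradio wiring removed in this diff

gr.Interface(
    fn=predict,  # predict as defined above in interface.py
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs="text",  # assumed output component
    title="text-summary-gpt2-short",  # assumed title
).launch()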