durrani commited on
Commit
be02fc4
·
1 Parent(s): 2086dfb
Files changed (1) hide show
  1. app.py +14 -32
app.py CHANGED
@@ -1,38 +1,20 @@
1
- import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
 
4
- def predict_rooms(new_students, new_temperature):
5
- # Load the model and tokenizer
6
- model_name = "AI" # Replace with the name or path of the model you want to use
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
- # Convert the input to tokens
11
- inputs = tokenizer.encode_plus(
12
- "Number of students: {}, Temperature: {}".format(new_students, new_temperature),
13
- padding="max_length",
14
- truncation=True,
15
- max_length=64,
16
- return_tensors="pt"
17
- )
18
 
19
- # Make the prediction
20
- with torch.no_grad():
21
- outputs = model(**inputs)
22
- logits = outputs.logits
23
- predicted_rooms = torch.argmax(logits, dim=1).item()
24
 
25
- return predicted_rooms
 
 
 
26
 
27
- def greet(name):
28
- return "Hello " + name + "!"
29
-
30
- iface = gr.Interface(
31
- fn=[predict_rooms, greet],
32
- inputs=[["number", "number"], "text"],
33
- outputs=["number", "text"],
34
- title="Room Prediction",
35
- description="Predict the number of rooms based on the number of students and temperature, and greet the user."
36
- )
37
-
38
- iface.launch()
 
1
+ import torch
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
 
4
+ # Load tokenizer and model
5
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
6
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=1)
 
 
7
 
8
+ # Input text
9
+ text = "I want to book a flight from New York to London on July 1st."
 
 
 
 
 
 
10
 
11
+ # Tokenize input text
12
+ inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
 
 
 
13
 
14
+ # Predict flight ticket
15
+ with torch.no_grad():
16
+ logits = model(**inputs).logits
17
+ prediction = torch.sigmoid(logits).item()
18
 
19
+ # Display prediction
20
+ print("Flight ticket prediction:", prediction)