File size: 3,147 Bytes
05140a3
7eaefa4
 
 
69889de
7eaefa4
cf17b13
7eaefa4
 
cf17b13
7eaefa4
 
 
 
 
 
 
 
 
 
 
 
cf17b13
7eaefa4
 
 
 
 
 
 
 
cf17b13
7eaefa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b9061c
7eaefa4
 
 
 
 
 
7b9061c
7fe3a25
05140a3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import json
from transformers import pipeline
# Load the pipeline
# Built once, at import time, from the "manhan/GPT-Tour" model id and reused by
# getTour for every request. No task argument is passed, so the pipeline type
# is presumably inferred from the model's own configuration — TODO confirm.
tourModel = pipeline(model="manhan/GPT-Tour")

# Upper bound on generation retries per request, so that a request the model
# cannot satisfy ends instead of looping forever.
_MAX_TOUR_ATTEMPTS = 20


def _parse_whole_number(text):
    """Convert free-text numeric input such as "$52,000" or "34.0" to an int.

    Raises ValueError when the text is not a number at all.
    """
    cleaned = str(text).replace("$", "").replace(",", "").strip()
    try:
        return int(cleaned)
    except ValueError:
        # Accept decimal notation; fractions are truncated toward zero.
        return int(float(cleaned))


def _income_bracket(dollars):
    """Map an annual household income in dollars to its bracket label, or None if negative."""
    if dollars < 0:
        return None  # bad data: leave the attribute out of the prompt
    if dollars < 25000:
        return 'poor'
    if dollars < 50000:
        return 'low'
    if dollars < 75000:
        return 'medium'
    if dollars < 125000:
        return 'high'
    return 'affluent'


def getTour(income,size,years,sex,edu,wrk):
    """Generate a daily activity tour for one traveler with the GPT-Tour model.

    The person- and household-level attributes are bucketed into coarse
    labels, serialized into a JSON-like prompt, and handed to ``tourModel``.
    Generation is retried until the output contains a JSON list that ends with
    'Home' and agrees with the worker status, or until the retry limit is hit.

    Args:
        income: Annual household income in dollars (number or numeric text).
            Negative values are treated as bad data and the income attribute
            is omitted from the prompt.
        size: Household size in people (number or numeric text).
        years: Traveler age in years (number or numeric text).
        sex: Passed through to the prompt unchanged.
        edu: Passed through to the prompt unchanged.
        wrk: Worker status. 'yes' requires "Work" to appear in the generated
            list, 'no' forbids it; any other value disables the check.

    Returns:
        The parsed list of activities, or an empty list if no acceptable tour
        was produced within the retry limit.

    Raises:
        ValueError: If income, size or years cannot be read as a number.
    """
    # person = a dict with person-level and hh-level attributes.
    # Insertion order matters: it fixes the attribute order in the prompt.
    person = {}
    bracket = _income_bracket(_parse_whole_number(income))
    if bracket is not None:
        person['hh_inc'] = bracket

    hh_size = _parse_whole_number(size)
    if hh_size == 1:
        person['hh_size'] = 'single'
    elif hh_size == 2:
        person['hh_size'] = 'couple'
    elif hh_size <= 4:
        person['hh_size'] = 'small'
    else: # more than four people
        person['hh_size'] = 'large'

    age = _parse_whole_number(years)
    if age < 18:
        person['age_grp'] = 'child'
    elif age < 45:
        person['age_grp'] = 'younger'
    elif age < 65:
        person['age_grp'] = 'older'
    else:
        person['age_grp'] = 'senior'

    person['sex'] = sex
    person['edu'] = edu
    person['wrk'] = wrk

    # Drop the closing brace so the model continues the same JSON-like record.
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)

    for _ in range(_MAX_TOUR_ATTEMPTS):
        generated = tourModel(prompt, return_full_text=False, max_length=250, temperature=0.9)[0]['generated_text']
        # Take the first bracketed span; the closing bracket is searched for
        # only after the opening one so a stray ']' cannot produce a bad slice.
        start_pos = generated.find('[')
        end_pos = generated.find(']', start_pos) + 1 if start_pos != -1 else 0
        activity_list_str = generated[start_pos:end_pos] if end_pos else ''
        print(f"Extracted: '{activity_list_str}'")
        if not activity_list_str:
            print("Nothing extracted!")
            continue

        # Reject tours that contradict the stated worker status.
        mentions_work = 'Work' in activity_list_str
        if wrk == 'yes' and not mentions_work:
            continue # try again
        if wrk == 'no' and mentions_work:
            continue # try again

        try:
            activities = json.loads(activity_list_str)
        except json.JSONDecodeError as e:
            print("Error parsing activity list")
            print(e)
            continue

        # A usable tour is a non-empty list that brings the traveler back home.
        if isinstance(activities, list) and activities and activities[-1] == 'Home':
            return activities

    print(f"No valid tour generated after {_MAX_TOUR_ATTEMPTS} attempts")
    return []

# Input widgets, listed in the same order as getTour's positional parameters.
tour_inputs = [
    gr.inputs.Textbox(label="Annual Household Income (in dollars)"),
    gr.inputs.Textbox(label="Household Size (number of people)"),
    gr.inputs.Textbox(label="Traveler Age (years)"),
    gr.inputs.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
    gr.inputs.Dropdown(
        ["unknown", "grade school","highschool", "associates", "bachelors", "graduate"],
        label="Educational attainment level",
    ),
    gr.inputs.Dropdown(["unknown", "yes","no"], label="Worker status"),
]

# Wrap getTour in a form whose result is rendered as JSON, then serve it.
with gr.Interface(
    fn=getTour,
    inputs=tour_inputs,
    outputs="json",
    title="GPT-Tour",
    description="Generate a tour for a person",
    allow_flagging=False,
    allow_screenshot=False,
    allow_embedding=False,
) as iface:
    iface.launch()

# NOTE(review): this second interface, loaded from the hosted model, is only
# launched once the launch() call above has returned — confirm it is intended
# and not a leftover.
hosted_demo = gr.Interface.load("models/manhan/GPT-Tour")
hosted_demo.launch()