outdir = '/media/Scratch/MLSS/'
filename = '../../Data/S8_pass_to_S9.csv'
batch_size = 4
name = 'S9_gender'
import os # interact with file system
import sys # Needed to load settings file
import platform # identify OS type
import shutil
from glob import glob
from tqdm.autonotebook import tqdm
import csv # read and write csv files
import transformers
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import pipeline
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
import evaluate
from datasets import load_dataset, Features, Value, Dataset, load_metric
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import numpy as np
transformers.logging.set_verbosity_error()
# Define tokenization for all models in this series
tokenizer = BertTokenizer.from_pretrained('yiyanghkust/finbert-pretrain')
def tokenize_function(observations):
    return tokenizer(observations["text"], padding="max_length", truncation=True, max_length=512)
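For reference, calling this helper on a single record returns the three fields the model expects; a quick illustration (the example sentence is made up, not drawn from the data):
example = tokenize_function({"text": "Markets rallied after the announcement."})
print(list(example.keys()))       # ['input_ids', 'token_type_ids', 'attention_mask']
print(len(example["input_ids"]))  # 512, because padding="max_length" with max_length=512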
def model_inference(encoded_dataset, name):
    # load the fine-tuned PyTorch model
    model = AutoModelForSequenceClassification.from_pretrained(outdir + 'FinBERT_' + name + '/Best_model')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    logits = []
    for i in tqdm(range(0, len(encoded_dataset))):
        input_ids = encoded_dataset[i]['input_ids'].unsqueeze(0)  # unsqueeze(0) casts the 1D tensor to a 2D batch of size 1, as BERT expects
        token_type_ids = encoded_dataset[i]['token_type_ids'].unsqueeze(0)
        attention_mask = encoded_dataset[i]['attention_mask'].unsqueeze(0)
        logits.append(model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).logits.cpu().detach().numpy().tolist()[0])
    # free GPU memory by deleting the model
    del model
    # return the logits as a dataframe, one column per class
    cols = ['log' + str(i) for i in range(0, len(logits[0]))]
    return pd.DataFrame(np.array(logits), columns=cols)
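The loop above scores one example at a time. As a sketch only (not the code used for the results below), the same inference could be batched with the `batch_size` set at the top of the notebook and wrapped in `torch.no_grad()` to skip gradient bookkeeping; `model_inference_batched` is a hypothetical name:
from torch.utils.data import DataLoader

def model_inference_batched(encoded_dataset, name, batch_size=batch_size):
    # hypothetical batched variant, assuming the dataset is already torch-formatted with fixed-length (512) sequences
    model = AutoModelForSequenceClassification.from_pretrained(outdir + 'FinBERT_' + name + '/Best_model')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    logits = []
    with torch.no_grad():
        for batch in tqdm(DataLoader(encoded_dataset, batch_size=batch_size)):
            batch = {k: v.to(device) for k, v in batch.items()}  # move the batch to the model's device
            logits.extend(model(**batch).logits.cpu().numpy().tolist())
    del model
    cols = ['log' + str(i) for i in range(len(logits[0]))]
    return pd.DataFrame(np.array(logits), columns=cols)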
df = pd.read_csv(outdir + name + '.csv.gz', encoding='utf-8')
dataset = Dataset.from_pandas(df)
encoded_dataset = dataset.map(tokenize_function)
Map: 100%|█████████████████████████████████████████████████████| 1195/1195 [00:11<00:00, 105.79 examples/s]
encoded_dataset
Dataset({
    features: ['text', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1195
})
encoded_dataset = encoded_dataset.remove_columns(['text', 'labels'])
# push encoded data to GPU
encoded_dataset.set_format("torch", device="cuda")
# pass data through the model
output = model_inference(encoded_dataset, name=name)
100%|██████████████████████████████████████████████████████████████████| 1195/1195 [00:20<00:00, 57.05it/s]
output = df.join(output)
encoded_dataset.set_format("torch", device="cpu")
output
|      | text | labels | log0 | log1 |
|------|------|--------|------|------|
| 0    | Who can blame an investor for taking to the bo... | 1 | 0.794654 | -0.469129 |
| 1    | Businesses are dumping office space at the fas... | 0 | 0.804681 | 0.202452 |
| 2    | WASHINGTON -- President George W. Bush's plan ... | 0 | 1.144736 | -0.799532 |
| 3    | FRANKFURT -- The European Central Bank, while ... | 1 | 0.791363 | -0.383891 |
| 4    | PARIS -- President Nicolas Sarkozy introduced ... | 0 | 1.267021 | -0.506859 |
| ...  | ... | ... | ... | ... |
| 1190 | Should John McCain eke out a victory on Tuesda... | 0 | 1.544418 | -1.026227 |
| 1191 | In times of economic crisis, with thoughts of ... | 1 | 0.832721 | -0.465457 |
| 1192 | J.P. Morgan Chase & Co. launched an ambitious ... | 1 | 1.524031 | -0.449714 |
| 1193 | The Morning Brief, a look at the day's biggest... | 0 | 1.613019 | -0.998035 |
| 1194 | In the legal battle over chemical company Hunt... | 0 | 1.347540 | -0.431128 |

1195 rows × 4 columns
output.rename(columns={'log0': 'logit_male', 'log1':'logit_female'}, inplace=True)
output['i_female'] = output.apply(lambda x: 0 if x['logit_male'] > x['logit_female'] else 1, axis=1)
output[output['i_female'] == 1]
|      | text | labels | logit_male | logit_female | i_female |
|------|------|--------|------------|--------------|----------|
| 23   | Merck & Co. Thursday said it halted developmen... | 0 | 0.410868 | 0.453517 | 1 |
| 34   | Amgen Inc. said a federal court granted the bi... | 1 | -0.144691 | 0.993399 | 1 |
| 36   | Citing overwhelming economic pressures and war... | 1 | 0.233885 | 0.375459 | 1 |
| 49   | Options traders hopped on railroad companies T... | 1 | -0.181216 | 0.967081 | 1 |
| 54   | The euro fell to a fresh one-year low against ... | 1 | 0.145984 | 0.390424 | 1 |
| ...  | ... | ... | ... | ... | ... |
| 1143 | A federal appeals court has upheld a decision ... | 0 | 0.115356 | 0.389280 | 1 |
| 1149 | Walgreen Co. plans to cut costs, overhaul stor... | 1 | -0.073960 | 0.877992 | 1 |
| 1152 | A combination of a treatment known as cognitiv... | 0 | 0.296025 | 0.333333 | 1 |
| 1157 | WASHINGTON -- The clouds around GlaxoSmithKlin... | 1 | -0.336796 | 0.735275 | 1 |
| 1158 | Cigna Corp., citing competitive and economic p... | 1 | -0.064404 | 0.586696 | 1 |

78 rows × 5 columns
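The `i_female` flag above is a hard argmax over the two logits. If soft scores are wanted downstream, a softmax turns the logits into a probability of the female class; a minimal sketch (the `p_female` column name is illustrative, not part of the original pipeline):
import torch.nn.functional as F

# convert the two logit columns into a probability of the 'female' class
logit_tensor = torch.tensor(output[['logit_male', 'logit_female']].values)
output['p_female'] = F.softmax(logit_tensor, dim=1)[:, 1].numpy()
output[['logit_male', 'logit_female', 'i_female', 'p_female']].head()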