| 92 | |
| 93 | # save errors into a directory |
def plot_current_errors_save(self, epoch, counter_ratio, opt, errors, keys='+ALL', name='loss', to_plot=False):
    """Record the current error values and save their history into ``self.web_dir``.

    On the first call, the list (and order) of error names is fixed from
    ``errors.keys()``. Every later call must supply at least those names.

    For each selected error name ``k``, the x history is written with
    ``np.save`` to ``<web_dir>/k_x`` and the y history to ``<web_dir>/k_y``.
    ``np.save`` appends the ``.npy`` suffix. When ``to_plot`` is true, the
    selected curves are also drawn into ``<web_dir>/<name>.png``.

    Args:
        epoch: Epoch value. It is added to ``counter_ratio`` to form the x
            coordinate.
        counter_ratio: Offset added to ``epoch``. Presumably this is the
            fraction of the current epoch completed; confirm with callers.
        opt: Unused here; kept for interface compatibility.
        errors: Mapping from error name to a numeric value. Values are
            assumed to be scalars, so the history forms a 2-D array.
        keys: ``'+ALL'`` to save every recorded error. Otherwise, either a
            single error name (str) or an iterable of error names.
        name: Base file name (without extension) of the PNG plot.
        to_plot: If true, also render and save a matplotlib figure.

    Raises:
        KeyError: If ``errors`` lacks a name recorded on the first call.
        ValueError: If a requested key was never recorded.
    """
    if not hasattr(self, 'plot_data'):
        self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
    legend = self.plot_data['legend']
    self.plot_data['X'].append(epoch + counter_ratio)
    self.plot_data['Y'].append([errors[k] for k in legend])

    if keys == '+ALL':
        plot_keys = legend
    elif isinstance(keys, str):
        # A bare string is a single key name. Iterating it would yield characters.
        plot_keys = [keys]
    else:
        plot_keys = keys

    x = self.plot_data['X']
    # Build the (num_steps, num_errors) history once, not once per key.
    y_all = np.array(self.plot_data['Y'])

    f, ax = plt.subplots(1, 1) if to_plot else (None, None)
    try:
        for kname in plot_keys:
            y = y_all[:, legend.index(kname)]
            if to_plot:
                ax.plot(x, y, 'o-', label=kname)
            # Format only the file name. Formatting the joined path would fail,
            # or silently corrupt the path, if web_dir itself contains a '%'.
            np.save(os.path.join(self.web_dir, '%s_x' % kname), x)
            np.save(os.path.join(self.web_dir, '%s_y' % kname), y)

        if to_plot:
            ax.legend(loc=0, fontsize='small')
            ax.set_xlabel('epoch')
            ax.set_ylabel('Value')
            f.savefig(os.path.join(self.web_dir, '%s.png' % name))
    finally:
        # Always release the figure, even if saving raised.
        if f is not None:
            plt.close(f)
| 124 | |
| 125 | # errors: dictionary of error labels and values |
| 126 | def plot_current_errors(self, epoch, counter_ratio, opt, errors): |